waynesun 5 years ago
parent
commit
f01fa37d85
2 changed files with 8 additions and 8 deletions
  1. 2 2
      README.md
  2. 6 6
      sqlnet/utils.py

+ 2 - 2
README.md

@@ -18,7 +18,7 @@ The difference between SQLNet and this baseline model is, Select-Number and Wher
 
 ## Start to train
 
-Firstly, download the provided datasets at ~/data_nl2sql/, which includes train.json, train.tables.json, dev.json, dev.tables.json and char_embedding.
+Firstly, download the provided datasets at ~/data_nl2sql/, which should include train.json, train.tables.json, dev.json, dev.tables.json and char_embedding, and divide them in following structure.
 ```
 ├── data
 │ ├── train
@@ -29,7 +29,7 @@ Firstly, download the provided datasets at ~/data_nl2sql/, which includes train.
 │ │ ├── dev.tables.json
 ├── char_embedding
 ```
-
+and then
 ```
 mkdir ~/nl2sql
 cd ~/nl2sql/

+ 6 - 6
sqlnet/utils.py

@@ -36,15 +36,15 @@ def load_data(sql_paths, table_paths, use_small=False):
 
 def load_dataset(toy=False, use_small=False, mode='train'):
     print "Loading dataset"
-    dev_sql, dev_table = load_data('data/dev.json', 'data/dev.tables.json', use_small=use_small)
-    dev_db = 'data/dev.db'
+    dev_sql, dev_table = load_data('data/dev/dev.json', 'data/dev/dev.tables.json', use_small=use_small)
+    dev_db = 'data/dev/dev.db'
     if mode == 'train':
-        train_sql, train_table = load_data('data/train.json', 'data/train.tables.json', use_small=use_small)
-        train_db = 'data/train.db'
+        train_sql, train_table = load_data('data/train/train.json', 'data/train/train.tables.json', use_small=use_small)
+        train_db = 'data/train/train.db'
         return train_sql, train_table, train_db, dev_sql, dev_table, dev_db
     elif mode == 'test':
-        test_sql, test_table = load_data('data/test.json', 'data/test.tables.json', use_small=use_small)
-        test_db = 'data/test.db'
+        test_sql, test_table = load_data('data/test/test.json', 'data/test/test.tables.json', use_small=use_small)
+        test_db = 'data/test/test.db'
         return dev_sql, dev_table, dev_db, test_sql, test_table, test_db
 
 def to_batch_seq(sql_data, table_data, idxes, st, ed, ret_vis_data=False):