Browse Source

support execution accuracy

waynesun 4 years ago
parent
commit
02373c9c25
5 changed files with 78 additions and 62 deletions
  1. 12 1
      README.md
  2. 39 25
      sqlnet/lib/dbengine.py
  3. 16 33
      sqlnet/utils.py
  4. 4 1
      test.py
  5. 7 2
      train.py

+ 12 - 1
README.md

@@ -1,3 +1,7 @@
+***** New June 18th, 2019 *****
+
+This release supports execution accuracy, which is computed by executing the predicted SQL and comparing its result with that of the ground-truth SQL. It requires records==0.5.3 to be installed before running.
+
 ## Introduction
 
 This baseline method is developed and refined based on <a href="https://github.com/xiaojunxu/SQLNet">code</a> of <a href="https://arxiv.org/abs/1711.04436">SQLNet</a>, which is a baseline model in <a href="https://github.com/salesforce/WikiSQL">WikiSQL</a>.
@@ -14,6 +18,7 @@ The difference between SQLNet and this baseline model is, Select-Number and Wher
 
  - Python 2.7
  - torch 1.0.1
+ - records 0.5.3
  - tqdm
 
 ## Start to train
@@ -22,11 +27,17 @@ Firstly, download the provided datasets at ~/data_nl2sql/, which should include
 ```
 ├── data_nl2sql
 │ ├── train
+│ │ ├── train.db
 │ │ ├── train.json
 │ │ ├── train.tables.json
 │ ├── val
+│ │ ├── val.db
 │ │ ├── val.json
 │ │ ├── val.tables.json
+│ ├── test
+│ │ ├── test.db
+│ │ ├── test.json
+│ │ ├── test.tables.json
 │ ├── char_embedding
 ```
 and then
@@ -53,7 +64,7 @@ while the first parameter 0 means gpu number, the second parameter means the out
 
 ## Experiment result
 
-We have run experiments several times, achiving avegrage 27.5% logic form accuracy on the val dataset.
+We have run experiments several times, achieving an average of 27.5% logic form accuracy on the val dataset, with a batch size of 128.
 
 
 ## Experiment analysis

+ 39 - 25
sqlnet/lib/dbengine.py

@@ -6,47 +6,61 @@ from babel.numbers import parse_decimal, NumberFormatError
 schema_re = re.compile(r'\((.+)\)')
 num_re = re.compile(r'[-+]?\d*\.\d+|\d+')
 
-agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
-cond_ops = ['=', '>', '<', 'OP']
+agg_dict = {0:"", 1:"AVG", 2:"MAX", 3:"MIN", 4:"COUNT", 5:"SUM"}
+cond_op_dict = {0:">", 1:"<", 2:"==", 3:"!="}
+cond_rela_dict = {0:"and",1:"or",-1:""}
 
 class DBEngine:
 
     def __init__(self, fdb):
-        #fdb = 'data/test.db'
         self.db = records.Database('sqlite:///{}'.format(fdb))
+        self.conn = self.db.get_connection()
 
-    def execute_query(self, table_id, query, *args, **kwargs):
-        return self.execute(table_id, query.sel_index, query.agg_index, query.conditions, *args, **kwargs)
+    def execute(self, table_id, select_index, aggregation_index, conditions, condition_relation, lower=True):
+        if not table_id.startswith('Table'):
+            table_id = 'Table_{}'.format(table_id.replace('-', '_'))
+        wr = ""
+        if condition_relation == 1 or condition_relation == 0:
+            wr = " AND "
+        elif condition_relation == 2:
+            wr = " OR "
 
-    def execute(self, table_id, select_index, aggregation_index, conditions, lower=True):
-        if not table_id.startswith('table'):
-            table_id = 'table_{}'.format(table_id.replace('-', '_'))
-        table_info = self.db.query('SELECT sql from sqlite_master WHERE tbl_name = :name', name=table_id).all()[0].sql.replace('\n','')
+        table_info = self.conn.query('SELECT sql from sqlite_master WHERE tbl_name = :name', name=table_id).all()[0].sql
         schema_str = schema_re.findall(table_info)[0]
         schema = {}
-        for tup in schema_str.split(', '):
-            c, t = tup.split()
+        for tup in schema_str.split(','):
+            c, t = tup.split(' ')
             schema[c] = t
-        select = 'col{}'.format(select_index)
-        agg = agg_ops[aggregation_index]
-        if agg:
-            select = '{}({})'.format(agg, select)
+
+        tmp = ""
+        for sel, agg in zip(select_index, aggregation_index):
+            select_str = 'col_{}'.format(sel+1)
+            agg_str = agg_dict[agg]
+            if agg:
+                tmp += '{}({}),'.format(agg_str, select_str)
+            else:
+                tmp += '({}),'.format(select_str)
+        tmp = tmp[:-1]
+
         where_clause = []
         where_map = {}
         for col_index, op, val in conditions:
-            if lower and (isinstance(val, str) or isinstance(val, unicode)):
+            if lower and (isinstance(val, str) or isinstance(val, str)):
                 val = val.lower()
-            if schema['col{}'.format(col_index)] == 'real' and not isinstance(val, (int, float)):
+            if schema['col_{}'.format(col_index+1)] == 'real' and not isinstance(val, (int, float)):
                 try:
-                    val = float(parse_decimal(val))
+                    val = float(parse_decimal(val, locale='en_US'))
                 except NumberFormatError as e:
-                    val = float(num_re.findall(val)[0])
-            where_clause.append('col{} {} :col{}'.format(col_index, cond_ops[op], col_index))
-            where_map['col{}'.format(col_index)] = val
+                    try:
+                        val = float(num_re.findall(val)[0]) # need to understand and debug this part.
+                    except:
+                        # Although column is of number, selected one is not number. Do nothing in this case.
+                        pass
+            where_clause.append('col_{} {} :col_{}'.format(col_index+1, cond_op_dict[op], col_index+1))
+            where_map['col_{}'.format(col_index+1)] = val
         where_str = ''
         if where_clause:
-            where_str = 'WHERE ' + ' AND '.join(where_clause)
-        query = 'SELECT {} AS result FROM {} {}'.format(select, table_id, where_str)
-        #print query
-        out = self.db.query(query, **where_map)
+            where_str = 'WHERE ' + wr.join(where_clause)
+        query = 'SELECT {} AS result FROM {} {}'.format(tmp, table_id, where_str)
+        out = self.conn.query(query, **where_map)
         return [o.result for o in out]

+ 16 - 33
sqlnet/utils.py

@@ -12,20 +12,20 @@ def load_data(sql_paths, table_paths, use_small=False):
     table_data = {}
 
     for SQL_PATH in sql_paths:
-        print "Loading data from %s" % SQL_PATH
         with open(SQL_PATH) as inf:
             for idx, line in enumerate(inf):
                 sql = json.loads(line.strip())
                 if use_small and idx >= 1000:
                     break
                 sql_data.append(sql)
+        print "Loaded %d data from %s" % (len(sql_data), SQL_PATH)
 
     for TABLE_PATH in table_paths:
-        print "Loading data from %s" % TABLE_PATH
         with open(TABLE_PATH) as inf:
             for line in inf:
                 tab = json.loads(line.strip())
                 table_data[tab[u'id']] = tab
+        print "Loaded %d data from %s" % (len(table_data), TABLE_PATH)
 
     ret_sql_data = []
     for sql in sql_data:
@@ -131,34 +131,7 @@ def epoch_train(model, optimizer, batch_size, sql_data, table_data):
         optimizer.step()
     return cum_loss / len(sql_data)
 
-def epoch_exec_acc(model, batch_size, sql_data, table_data, db_path):
-    engine = DBEngine(db_path)
-    model.eval()
-    perm = list(range(len(sql_data)))
-    tot_acc_num = 0.0
-    for st in tqdm(range(len(sql_data)//batch_size+1)):
-        ed = (st+1)*batch_size if (st+1)*batch_size < len(perm) else len(perm)
-        st = st * batch_size
-        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, raw_data = \
-            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
-        raw_q_seq = [x[0] for x in raw_data]
-        raw_col_seq = [x[1] for x in raw_data]
-        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
-        gt_sel_seq = [x[2] for x in ans_seq]
-        score = model.forward(q_seq, col_seq, col_num, gt_sel=gt_sel_seq)
-        pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq, raw_col_seq)
-
-        for idx, (sql_gt, sql_pred, tid) in enumerate(zip(query_gt, pred_queries, table_ids)):
-            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'], sql_gt['conds'])
-            try:
-                ret_pred = engine.execute(tid, sql_pred['sel'], sql_pred['agg'], sql_pred['conds'])
-            except:
-                ret_pred = None
-            tot_acc_num += (ret_gt == ret_pred)
-    return tot_acc_num / len(sql_data)
-
-def predict_test(model, batch_size, sql_data, table_data, db_path, output_path):
-    engine = DBEngine(db_path)
+def predict_test(model, batch_size, sql_data, table_data, output_path):
     model.eval()
     perm = list(range(len(sql_data)))
     fw = open(output_path,'w')
@@ -172,11 +145,12 @@ def predict_test(model, batch_size, sql_data, table_data, db_path, output_path):
             fw.writelines(json.dumps(sql_pred,ensure_ascii=False).encode('utf-8')+'\n')
     fw.close()
 
-def epoch_acc(model, batch_size, sql_data, table_data):
+def epoch_acc(model, batch_size, sql_data, table_data, db_path):
+    engine = DBEngine(db_path)
     model.eval()
     perm = list(range(len(sql_data)))
     badcase = 0
-    one_acc_num, tot_acc_num = 0.0, 0.0
+    one_acc_num, tot_acc_num, ex_acc_num = 0.0, 0.0, 0.0
     for st in tqdm(range(len(sql_data)//batch_size+1)):
         ed = (st+1)*batch_size if (st+1)*batch_size < len(perm) else len(perm)
         st = st * batch_size
@@ -203,7 +177,16 @@ def epoch_acc(model, batch_size, sql_data, table_data):
             continue
         one_acc_num += (ed-st-one_err)
         tot_acc_num += (ed-st-tot_err)
-    return one_acc_num / len(sql_data), tot_acc_num / len(sql_data),
+
+        # Execution Accuracy
+        for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
+            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'], sql_gt['conds'], sql_gt['cond_conn_op'])
+            try:
+                ret_pred = engine.execute(tid, sql_pred['sel'], sql_pred['agg'], sql_pred['conds'], sql_pred['cond_conn_op'])
+            except:
+                ret_pred = None
+            ex_acc_num += (ret_gt == ret_pred)
+    return one_acc_num / len(sql_data), tot_acc_num / len(sql_data), ex_acc_num / len(sql_data)
 
 
 def load_word_emb(file_name):

+ 4 - 1
test.py

@@ -32,6 +32,9 @@ if __name__ == '__main__':
     model.load_state_dict(torch.load(model_path))
     print "Loaded model from %s" % model_path
 
+    dev_acc = epoch_acc(model, batch_size, dev_sql, dev_table, dev_db)
+    print 'Dev Logic Form Accuracy: %.3f, Execution Accuracy: %.3f' % (dev_acc[1], dev_acc[2])
+
     print "Start to predict test set"
-    predict_test(model, batch_size, test_sql, test_table, test_db, args.output_dir)
+    predict_test(model, batch_size, test_sql, test_table, args.output_dir)
     print "Output path of prediction result is %s" % args.output_dir

+ 7 - 2
train.py

@@ -43,6 +43,7 @@ if __name__ == '__main__':
     best_sn, best_sc, best_sa, best_wn, best_wc, best_wo, best_wv, best_wr = 0, 0, 0, 0, 0, 0, 0, 0
     best_sn_idx, best_sc_idx, best_sa_idx, best_wn_idx, best_wc_idx, best_wo_idx, best_wv_idx, best_wr_idx = 0, 0, 0, 0, 0, 0, 0, 0
     best_lf, best_lf_idx = 0.0, 0
+    best_ex, best_ex_idx = 0.0, 0
 
     print "#"*20+"  Star to Train  " + "#"*20
     for i in range(args.epoch):
@@ -50,7 +51,7 @@ if __name__ == '__main__':
         # train on the train dataset
         train_loss = epoch_train(model, optimizer, batch_size, train_sql, train_table)
         # evaluate on the dev dataset
-        dev_acc = epoch_acc(model, batch_size, dev_sql, dev_table)
+        dev_acc = epoch_acc(model, batch_size, dev_sql, dev_table, dev_db)
         # accuracy of each sub-task
         print 'Sel-Num: %.3f, Sel-Col: %.3f, Sel-Agg: %.3f, W-Num: %.3f, W-Col: %.3f, W-Op: %.3f, W-Val: %.3f, W-Rel: %.3f'%(
             dev_acc[0][0], dev_acc[0][1], dev_acc[0][2], dev_acc[0][3], dev_acc[0][4], dev_acc[0][5], dev_acc[0][6], dev_acc[0][7])
@@ -59,6 +60,9 @@ if __name__ == '__main__':
             best_lf = dev_acc[1]
             best_lf_idx = i + 1
             torch.save(model.state_dict(), 'saved_model/best_model')
+        if dev_acc[2] > best_ex:
+            best_ex = dev_acc[2]
+            best_ex_idx = i + 1
 
         # record the best score of each sub-task
         if True:
@@ -87,8 +91,9 @@ if __name__ == '__main__':
                 best_wr = dev_acc[0][7]
                 best_wr_idx = i+1
         print 'Train loss = %.3f' % train_loss
-        print 'Dev Logic Form: %.3f' % dev_acc[1]
+        print 'Dev Logic Form Accuracy: %.3f, Execution Accuracy: %.3f' % (dev_acc[1], dev_acc[2])
         print 'Best Logic Form: %.3f at epoch %d' % (best_lf, best_lf_idx)
+        print 'Best Execution: %.3f at epoch %d' % (best_ex, best_ex_idx)
         if (i+1) % 10 == 0:
             print 'Best val acc: %s\nOn epoch individually %s'%(
                     (best_sn, best_sc, best_sa, best_wn, best_wc, best_wo, best_wv),