@@ -12,20 +12,20 @@ def load_data(sql_paths, table_paths, use_small=False):
     table_data = {}
 
     for SQL_PATH in sql_paths:
-        print "Loading data from %s" % SQL_PATH
         with open(SQL_PATH) as inf:
             for idx, line in enumerate(inf):
                 sql = json.loads(line.strip())
                 if use_small and idx >= 1000:
                     break
                 sql_data.append(sql)
+        print "Loaded %d records from %s" % (len(sql_data), SQL_PATH)
 
     for TABLE_PATH in table_paths:
-        print "Loading data from %s" % TABLE_PATH
         with open(TABLE_PATH) as inf:
             for line in inf:
                 tab = json.loads(line.strip())
                 table_data[tab[u'id']] = tab
+        print "Loaded %d records from %s" % (len(table_data), TABLE_PATH)
 
     ret_sql_data = []
     for sql in sql_data:
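The reworked loader reports counts after each file instead of announcing each path up front. For orientation, a hedged sketch of how it would be invoked; the file names are placeholders, and the `(sql_data, table_data)` return value is an assumption based on the filtering loop that follows the hunk:

```python
# Hypothetical call: each path holds newline-delimited JSON, with tables
# keyed by u'id' so questions can be joined to their table later.
sql_data, table_data = load_data(['data/train.json'], ['data/train.tables.json'])

# use_small=True stops after 1000 examples per SQL file, useful for smoke tests.
small_sql, small_tables = load_data(['data/train.json'], ['data/train.tables.json'],
                                    use_small=True)
```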
@@ -131,34 +131,7 @@ def epoch_train(model, optimizer, batch_size, sql_data, table_data):
         optimizer.step()
     return cum_loss / len(sql_data)
 
-def epoch_exec_acc(model, batch_size, sql_data, table_data, db_path):
-    engine = DBEngine(db_path)
-    model.eval()
-    perm = list(range(len(sql_data)))
-    tot_acc_num = 0.0
-    for st in tqdm(range(len(sql_data)//batch_size+1)):
-        ed = (st+1)*batch_size if (st+1)*batch_size < len(perm) else len(perm)
-        st = st * batch_size
-        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, raw_data = \
-            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
-        raw_q_seq = [x[0] for x in raw_data]
-        raw_col_seq = [x[1] for x in raw_data]
-        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
-        gt_sel_seq = [x[2] for x in ans_seq]
-        score = model.forward(q_seq, col_seq, col_num, gt_sel=gt_sel_seq)
-        pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq, raw_col_seq)
-
-        for idx, (sql_gt, sql_pred, tid) in enumerate(zip(query_gt, pred_queries, table_ids)):
-            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'], sql_gt['conds'])
-            try:
-                ret_pred = engine.execute(tid, sql_pred['sel'], sql_pred['agg'], sql_pred['conds'])
-            except:
-                ret_pred = None
-            tot_acc_num += (ret_gt == ret_pred)
-    return tot_acc_num / len(sql_data)
-
-def predict_test(model, batch_size, sql_data, table_data, db_path, output_path):
-    engine = DBEngine(db_path)
+def predict_test(model, batch_size, sql_data, table_data, output_path):
     model.eval()
     perm = list(range(len(sql_data)))
     fw = open(output_path,'w')
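Dropping `db_path` from `predict_test` means its call sites shrink by one argument. A minimal sketch of the updated call, with placeholder paths and batch size:

```python
# Before: predict_test(model, 64, test_sql, test_table, 'data/test.db', 'output.json')
# After: no DBEngine is built here; the function only writes predicted SQL to output_path.
predict_test(model, 64, test_sql, test_table, 'output.json')
```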
@@ -172,11 +145,12 @@ def predict_test(model, batch_size, sql_data, table_data, db_path, output_path):
         fw.writelines(json.dumps(sql_pred,ensure_ascii=False).encode('utf-8')+'\n')
     fw.close()
 
-def epoch_acc(model, batch_size, sql_data, table_data):
+def epoch_acc(model, batch_size, sql_data, table_data, db_path):
+    engine = DBEngine(db_path)
     model.eval()
     perm = list(range(len(sql_data)))
     badcase = 0
-    one_acc_num, tot_acc_num = 0.0, 0.0
+    one_acc_num, tot_acc_num, ex_acc_num = 0.0, 0.0, 0.0
     for st in tqdm(range(len(sql_data)//batch_size+1)):
         ed = (st+1)*batch_size if (st+1)*batch_size < len(perm) else len(perm)
         st = st * batch_size
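The new `db_path` parameter implies a `DBEngine` whose `execute` accepts the decomposed query, including the condition connector passed as a fifth argument in the next hunk. A sketch of the assumed interface only, not the project's real class:

```python
class DBEngine(object):
    """Assumed contract: execute a query given its decomposed form."""
    def __init__(self, fdb):
        self.fdb = fdb  # path to the SQLite database file

    def execute(self, table_id, select_index, aggregation_index,
                conditions, condition_relation):
        # Would build SELECT <agg>(<col>) FROM <table> WHERE <conditions>
        # joined by condition_relation (AND/OR) and return the result set.
        raise NotImplementedError
```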
@@ -203,7 +177,16 @@ def epoch_acc(model, batch_size, sql_data, table_data):
             continue
         one_acc_num += (ed-st-one_err)
         tot_acc_num += (ed-st-tot_err)
-    return one_acc_num / len(sql_data), tot_acc_num / len(sql_data),
+
+        # Execution accuracy: compare results of the gold and predicted queries
+        for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
+            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'], sql_gt['conds'], sql_gt['cond_conn_op'])
+            try:
+                ret_pred = engine.execute(tid, sql_pred['sel'], sql_pred['agg'], sql_pred['conds'], sql_pred['cond_conn_op'])
+            except Exception:
+                ret_pred = None
+            ex_acc_num += (ret_gt == ret_pred)
+    return one_acc_num / len(sql_data), tot_acc_num / len(sql_data), ex_acc_num / len(sql_data)
 
 
 def load_word_emb(file_name):
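With execution accuracy folded into `epoch_acc`, evaluation code unpacks three numbers and must now supply the database path. A hedged usage sketch; the variable names and path are placeholders:

```python
one_acc, tot_acc, ex_acc = epoch_acc(model, 64, val_sql, val_table, 'data/val.db')
print 'component acc: %.3f, logic-form acc: %.3f, execution acc: %.3f' % (
    one_acc, tot_acc, ex_acc)
```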