utils.py

import json

import numpy as np
from tqdm import tqdm

from lib.dbengine import DBEngine


def load_data(sql_paths, table_paths, use_small=False):
    if not isinstance(sql_paths, list):
        sql_paths = [sql_paths]
    if not isinstance(table_paths, list):
        table_paths = [table_paths]
    sql_data = []
    table_data = {}
    for sql_path in sql_paths:
        print("Loading data from %s" % sql_path)
        with open(sql_path, encoding='utf-8') as inf:
            for idx, line in enumerate(inf):
                if use_small and idx >= 1000:
                    break
                sql_data.append(json.loads(line.strip()))
    for table_path in table_paths:
        print("Loading data from %s" % table_path)
        with open(table_path, encoding='utf-8') as inf:
            for line in inf:
                tab = json.loads(line.strip())
                table_data[tab['id']] = tab
    # keep only the questions whose table is actually present
    ret_sql_data = [sql for sql in sql_data if sql['table_id'] in table_data]
    return ret_sql_data, table_data
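
# Usage sketch (paths follow the data/ layout assumed by load_dataset below):
#     sql_data, table_data = load_data('data/dev.json', 'data/dev.tables.json')
#     print('%d questions over %d tables' % (len(sql_data), len(table_data)))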


def load_dataset(toy=False, use_small=False, mode='train'):
    print("Loading dataset")
    dev_sql, dev_table = load_data('data/dev.json', 'data/dev.tables.json', use_small=use_small)
    dev_db = 'data/dev.db'
    if mode == 'train':
        train_sql, train_table = load_data('data/train.json', 'data/train.tables.json', use_small=use_small)
        train_db = 'data/train.db'
        return train_sql, train_table, train_db, dev_sql, dev_table, dev_db
    elif mode == 'test':
        test_sql, test_table = load_data('data/test.json', 'data/test.tables.json', use_small=use_small)
        test_db = 'data/test.db'
        return dev_sql, dev_table, dev_db, test_sql, test_table, test_db
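
# Usage sketch for the train split (use_small caps each question file at 1000 lines):
#     train_sql, train_table, train_db, dev_sql, dev_table, dev_db = \
#         load_dataset(mode='train', use_small=True)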


def to_batch_seq(sql_data, table_data, idxes, st, ed, ret_vis_data=False):
    q_seq = []        # char-level question tokens
    col_seq = []      # char-level column-name tokens per table
    col_num = []      # number of headers per table
    ans_seq = []      # ground-truth tuples consumed by the loss
    gt_cond_seq = []  # ground-truth WHERE conditions
    vis_seq = []      # raw (question, headers) pairs for display
    sel_num_seq = []  # number of selected columns per question
    for i in range(st, ed):
        sql = sql_data[idxes[i]]
        sel_num = len(sql['sql']['sel'])
        sel_num_seq.append(sel_num)
        conds_num = len(sql['sql']['conds'])
        q_seq.append(list(sql['question']))
        col_seq.append([list(header) for header in table_data[sql['table_id']]['header']])
        col_num.append(len(table_data[sql['table_id']]['header']))
        ans_seq.append((
            len(sql['sql']['agg']),
            sql['sql']['sel'],
            sql['sql']['agg'],
            conds_num,
            tuple(x[0] for x in sql['sql']['conds']),  # condition columns
            tuple(x[1] for x in sql['sql']['conds']),  # condition operators
            sql['sql']['cond_conn_op'],                # AND/OR connector
        ))
        gt_cond_seq.append(sql['sql']['conds'])
        vis_seq.append((sql['question'], table_data[sql['table_id']]['header']))
    if ret_vis_data:
        return q_seq, sel_num_seq, col_seq, col_num, ans_seq, gt_cond_seq, vis_seq
    return q_seq, sel_num_seq, col_seq, col_num, ans_seq, gt_cond_seq
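
# For a single example whose SQL is (values illustrative)
#     {'sel': [0], 'agg': [5], 'cond_conn_op': 0, 'conds': [[1, 2, '3']]}
# the corresponding ans_seq entry is
#     (1, [0], [5], 1, (1,), (2,), 0)
# i.e. (agg count, sel cols, agg ops, cond count, cond cols, cond ops, connector).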


def to_batch_seq_test(sql_data, table_data, idxes, st, ed):
    q_seq = []
    col_seq = []
    col_num = []
    raw_seq = []
    table_ids = []
    for i in range(st, ed):
        sql = sql_data[idxes[i]]
        q_seq.append(list(sql['question']))
        col_seq.append([list(header) for header in table_data[sql['table_id']]['header']])
        col_num.append(len(table_data[sql['table_id']]['header']))
        raw_seq.append(sql['question'])
        table_ids.append(sql['table_id'])
    return q_seq, col_seq, col_num, raw_seq, table_ids
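
# Usage sketch inside an inference loop (the batch size of 64 is illustrative):
#     perm = list(range(len(test_sql)))
#     q_seq, col_seq, col_num, raw_q, tids = \
#         to_batch_seq_test(test_sql, test_table, perm, 0, min(64, len(perm)))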


def to_batch_query(sql_data, idxes, st, ed):
    query_gt = []
    table_ids = []
    for i in range(st, ed):
        query_gt.append(sql_data[idxes[i]]['sql'])
        table_ids.append(sql_data[idxes[i]]['table_id'])
    return query_gt, table_ids
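
# query_gt entries are the raw data['sql'] dicts, so predictions can be
# compared against them field by field (sel, agg, conds, cond_conn_op).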


def epoch_train(model, optimizer, batch_size, sql_data, table_data):
    model.train()
    perm = np.random.permutation(len(sql_data))
    cum_loss = 0.0
    for st in tqdm(range(len(sql_data) // batch_size + 1)):
        ed = min((st + 1) * batch_size, len(perm))
        st = st * batch_size
        if st >= ed:  # skip the empty tail batch when batch_size divides the data
            break
        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq = \
            to_batch_seq(sql_data, table_data, perm, st, ed)
        # q_seq: char-level question tokens
        # gt_sel_num: number of selected columns / aggregation functions
        # col_seq: char-level column-name tokens
        # col_num: number of headers in each table
        # ans_seq: (agg count, sel cols, agg ops, cond count, cond cols, cond ops, connector)
        # gt_cond_seq: ground truth of the WHERE conditions
        gt_where_seq = model.generate_gt_where_seq_test(q_seq, gt_cond_seq)
        gt_sel_seq = [x[1] for x in ans_seq]
        score = model.forward(q_seq, col_seq, col_num, gt_where=gt_where_seq,
                              gt_cond=gt_cond_seq, gt_sel=gt_sel_seq,
                              gt_sel_num=gt_sel_num)
        # score: (sel_num_score, sel_col_score, sel_agg_score, cond_score, cond_rela_score)
        loss = model.loss(score, ans_seq, gt_where_seq)
        cum_loss += loss.data.cpu().numpy() * (ed - st)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return cum_loss / len(sql_data)
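
# Usage sketch (assumes `model` implements forward/loss/generate_gt_where_seq_test
# and `optimizer` wraps model.parameters(); batch size 64 and num_epochs are illustrative):
#     for epoch in range(num_epochs):
#         avg_loss = epoch_train(model, optimizer, 64, train_sql, train_table)
#         print('epoch %d: average loss %.4f' % (epoch, avg_loss))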


def epoch_exec_acc(model, batch_size, sql_data, table_data, db_path):
    engine = DBEngine(db_path)
    model.eval()
    perm = list(range(len(sql_data)))
    tot_acc_num = 0.0
    for st in tqdm(range(len(sql_data) // batch_size + 1)):
        ed = min((st + 1) * batch_size, len(perm))
        st = st * batch_size
        if st >= ed:
            break
        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
        raw_q_seq = [x[0] for x in raw_data]
        raw_col_seq = [x[1] for x in raw_data]
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        # NOTE: index 2 of ans_seq is the agg list; epoch_train passes index 1 (the sel list)
        gt_sel_seq = [x[2] for x in ans_seq]
        score = model.forward(q_seq, col_seq, col_num, gt_sel=gt_sel_seq)
        pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq, raw_col_seq)
        for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'], sql_gt['conds'])
            try:
                ret_pred = engine.execute(tid, sql_pred['sel'], sql_pred['agg'], sql_pred['conds'])
            except Exception:
                ret_pred = None
            tot_acc_num += (ret_gt == ret_pred)
    return tot_acc_num / len(sql_data)
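
# Usage sketch (dev_sql, dev_table, dev_db come from load_dataset):
#     exec_acc = epoch_exec_acc(model, 64, dev_sql, dev_table, dev_db)
#     print('execution accuracy: %.4f' % exec_acc)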


def predict_test(model, batch_size, sql_data, table_data, db_path, output_path):
    engine = DBEngine(db_path)  # kept for signature parity with epoch_exec_acc; unused here
    model.eval()
    perm = list(range(len(sql_data)))
    with open(output_path, 'w', encoding='utf-8') as fw:
        for st in tqdm(range(len(sql_data) // batch_size + 1)):
            ed = min((st + 1) * batch_size, len(perm))
            st = st * batch_size
            if st >= ed:
                break
            q_seq, col_seq, col_num, raw_q_seq, table_ids = \
                to_batch_seq_test(sql_data, table_data, perm, st, ed)
            score = model.forward(q_seq, col_seq, col_num)
            sql_preds = model.gen_query(score, q_seq, col_seq, raw_q_seq)
            for sql_pred in sql_preds:
                fw.write(json.dumps(sql_pred, ensure_ascii=False) + '\n')
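
# Usage sketch ('output/predict.jsonl' is an illustrative path; one JSON
# prediction is written per line, in input order):
#     predict_test(model, 64, test_sql, test_table, test_db, 'output/predict.jsonl')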


def epoch_acc(model, batch_size, sql_data, table_data):
    model.eval()
    perm = list(range(len(sql_data)))
    badcase = 0
    one_acc_num, tot_acc_num = 0.0, 0.0
    for st in tqdm(range(len(sql_data) // batch_size + 1)):
        ed = min((st + 1) * batch_size, len(perm))
        st = st * batch_size
        if st >= ed:
            break
        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, raw_data = \
            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
        # raw_data: original (question, headers) pairs
        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
        # query_gt: ground-truth SQL (data['sql']) with sel, agg, and
        # conds as (col, op, value) triples
        raw_q_seq = [x[0] for x in raw_data]  # original question strings
        try:
            score = model.forward(q_seq, col_seq, col_num)
            pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq)
            # per-component and whole-query error counts for this batch
            one_err, tot_err = model.check_acc(raw_data, pred_queries, query_gt)
        except Exception:
            badcase += 1
            print('badcase', badcase)
            continue
        one_acc_num += (ed - st - one_err)
        tot_acc_num += (ed - st - tot_err)
    return one_acc_num / len(sql_data), tot_acc_num / len(sql_data)
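
# Usage sketch (the two returns follow model.check_acc's two error counts,
# typically per-component accuracy and whole-query accuracy):
#     one_acc, tot_acc = epoch_acc(model, 64, dev_sql, dev_table)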


def load_word_emb(file_name):
    print('Loading word embedding from %s' % file_name)
    ret = {}
    with open(file_name, encoding='utf-8') as inf:
        for line in inf:
            info = line.strip().split(' ')
            word = info[0]
            if word not in ret:
                # keep the first vector seen for each token
                ret[word] = np.asarray(info[1:], dtype=np.float32)
    return ret
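

# Minimal smoke test: a sketch that assumes the data/ layout above exists locally;
# 'path/to/char_emb.txt' is a placeholder for a whitespace-separated embedding file.
if __name__ == '__main__':
    dev_sql, dev_table = load_data('data/dev.json', 'data/dev.tables.json', use_small=True)
    print('%d dev questions, %d dev tables' % (len(dev_sql), len(dev_table)))
    # word_emb = load_word_emb('path/to/char_emb.txt')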