liuyuqi 4 years ago
parent
commit
94e8a3833c
9 changed files with 846 additions and 764 deletions
  1. .gitignore (+2, -0)
  2. .vscode/settings.json (+5, -0)
  3. README.cn.md (+16, -0)
  4. preview.py (+18, -0)
  5. requirements.txt (+6, -0)
  6. sqlnet/model/sqlnet.py (+419, -419)
  7. sqlnet/utils.py (+227, -205)
  8. test.py (+40, -40)
  9. train.py (+113, -100)

+ 2 - 0
.gitignore

@@ -0,0 +1,2 @@
+/data
+*.pyc

+ 5 - 0
.vscode/settings.json

@@ -0,0 +1,5 @@
+{
+    "python.pythonPath": "/opt/anaconda3/envs/py36-ai/bin/python",
+    "python.linting.pylintEnabled": true,
+    "python.linting.enabled": true
+}

+ 16 - 0
README.cn.md

@@ -0,0 +1,16 @@
+## nl2sql: translating natural-language questions into SQL
+
+
+https://drive.google.com/open?id=10RtAom_D4zOp_w5OsYLTtLd2TEuC7C1X
+
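+The paths hard-coded in `sqlnet/utils.py` and `test.py` expect the data to be laid out roughly as follows (a sketch inferred from this commit, not an official layout):
+
+```
+data/
+    char_embedding.json      # loaded by load_word_emb (path used in test.py)
+    train/train.json         # with train.tables.json and train.db
+    val/val.json             # with val.tables.json and val.db
+    test/test.json           # with test.tables.json and test.db
+```
+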
+Training this model takes a fair amount of compute: one epoch takes roughly 7 to 8 minutes on a K80, and the baseline model trains for 100 epochs by default. Zhuiyi (追一) reports that, averaged over several runs, the Logic Form Accuracy on the validation set is 27.5%, i.e. the probability that every clause of the SQL query is predicted correctly is 27.5%.
+
+https://mp.weixin.qq.com/s?__biz=MzA3MzI4MjgzMw==&mid=2650764672&idx=2&sn=bb8d8bb191bfb4b4093e874193019724&chksm=871ab1feb06d38e8a06dea5eb37bec8b386ebce5da23864a8b2d114bba96fc18521c531ab357&scene=0&xtrack=1#rd
+
+
+WikiSQL benchmark: https://github.com/salesforce/WikiSQL
+
+```
+CUDA_VISIBLE_DEVICES=0 python train.py --ca --gpu --bs 128
+```
+
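+To predict on the test set after training, a sketch based on the arguments in `test.py` (it loads its checkpoint from `saved_model/best_model`; the output file name here is just an example):
+
+```
+CUDA_VISIBLE_DEVICES=0 python test.py --ca --gpu --output_dir predict.json
+```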

+ 18 - 0
preview.py

@@ -0,0 +1,18 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+'''
+@File    :   preview.py
+@Time    :   2019/06/25 11:24:20
+@Author  :   Liuyuqi 
+@Version :   1.0
+@Contact :   liuyuqi.gov@msn.cn
+@License :   (C)Copyright 2019
+@Desc    :   Data preview
+'''
+from sqlnet.utils import *
+from sqlnet.model.sqlnet import SQLNet
+
+if __name__ == "__main__":
+    train_sql, train_table, train_db, dev_sql, dev_table, dev_db = load_dataset(
+            use_small=True)
+    print(train_sql)

+ 6 - 0
requirements.txt

@@ -0,0 +1,6 @@
+Babel==2.7.0
+numpy==1.13.1
+pylint==2.3.1
+records==0.5.3
+torch==1.1.0
+tqdm==4.32.2

+ 419 - 419
sqlnet/model/sqlnet.py

@@ -1,419 +1,419 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.autograd import Variable
-import numpy as np
-from sqlnet.model.modules.word_embedding import WordEmbedding
-from sqlnet.model.modules.aggregator_predict import AggPredictor
-from sqlnet.model.modules.selection_predict import SelPredictor
-from sqlnet.model.modules.sqlnet_condition_predict import SQLNetCondPredictor
-from sqlnet.model.modules.select_number import SelNumPredictor
-from sqlnet.model.modules.where_relation import WhereRelationPredictor
-
-
-class SQLNet(nn.Module):
-    def __init__(self, word_emb, N_word, N_h=100, N_depth=2,
-            gpu=False, use_ca=True, trainable_emb=False):
-        super(SQLNet, self).__init__()
-        self.use_ca = use_ca
-        self.trainable_emb = trainable_emb
-
-        self.gpu = gpu
-        self.N_h = N_h
-        self.N_depth = N_depth
-
-        self.max_col_num = 45
-        self.max_tok_num = 200
-        self.SQL_TOK = ['<UNK>', '<END>', 'WHERE', 'AND', 'OR', '==', '>', '<', '!=', '<BEG>']
-        self.COND_OPS = ['>', '<', '==', '!=']
-
-        # Word embedding
-        self.embed_layer = WordEmbedding(word_emb, N_word, gpu, self.SQL_TOK, our_model=True, trainable=trainable_emb)
-
-        # Predict the number of selected columns
-        self.sel_num = SelNumPredictor(N_word, N_h, N_depth, use_ca=use_ca)
-
-        #Predict which columns are selected
-        self.sel_pred = SelPredictor(N_word, N_h, N_depth, self.max_tok_num, use_ca=use_ca)
-
-        #Predict aggregation functions of corresponding selected columns
-        self.agg_pred = AggPredictor(N_word, N_h, N_depth, use_ca=use_ca)
-
-        #Predict number of conditions, condition columns, condition operations and condition values
-        self.cond_pred = SQLNetCondPredictor(N_word, N_h, N_depth, self.max_col_num, self.max_tok_num, use_ca, gpu)
-
-        # Predict condition relationship, like 'and', 'or'
-        self.where_rela_pred = WhereRelationPredictor(N_word, N_h, N_depth, use_ca=use_ca)
-
-
-        self.CE = nn.CrossEntropyLoss()
-        self.softmax = nn.Softmax(dim=-1)
-        self.log_softmax = nn.LogSoftmax()
-        self.bce_logit = nn.BCEWithLogitsLoss()
-        if gpu:
-            self.cuda()
-
-    def generate_gt_where_seq_test(self, q, gt_cond_seq):
-        ret_seq = []
-        for cur_q, ans in zip(q, gt_cond_seq):
-            temp_q = u"".join(cur_q)
-            cur_q = [u'<BEG>'] + cur_q + [u'<END>']
-            record = []
-            record_cond = []
-            for cond in ans:
-                if cond[2] not in temp_q:
-                    record.append((False, cond[2]))
-                else:
-                    record.append((True, cond[2]))
-            for idx, item in enumerate(record):
-                temp_ret_seq = []
-                if item[0]:
-                    temp_ret_seq.append(0)
-                    temp_ret_seq.extend(list(range(temp_q.index(item[1])+1,temp_q.index(item[1])+len(item[1])+1)))
-                    temp_ret_seq.append(len(cur_q)-1)
-                else:
-                    temp_ret_seq.append([0,len(cur_q)-1])
-                record_cond.append(temp_ret_seq)
-            ret_seq.append(record_cond)
-        return ret_seq
-
-    def forward(self, q, col, col_num, gt_where = None, gt_cond=None, reinforce=False, gt_sel=None, gt_sel_num=None):
-        B = len(q)
-
-        sel_num_score = None
-        agg_score = None
-        sel_score = None
-        cond_score = None
-        #Predict aggregator
-        if self.trainable_emb:
-            x_emb_var, x_len = self.agg_embed_layer.gen_x_batch(q, col)
-            col_inp_var, col_name_len, col_len = self.agg_embed_layer.gen_col_batch(col)
-            max_x_len = max(x_len)
-            agg_score = self.agg_pred(x_emb_var, x_len, col_inp_var,
-                    col_name_len, col_len, col_num, gt_sel=gt_sel)
-
-            x_emb_var, x_len = self.sel_embed_layer.gen_x_batch(q, col)
-            col_inp_var, col_name_len, col_len = self.sel_embed_layer.gen_col_batch(col)
-            max_x_len = max(x_len)
-            sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var,
-                    col_name_len, col_len, col_num)
-
-            x_emb_var, x_len = self.cond_embed_layer.gen_x_batch(q, col)
-            col_inp_var, col_name_len, col_len = self.cond_embed_layer.gen_col_batch(col)
-            max_x_len = max(x_len)
-            cond_score = self.cond_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num, gt_where, gt_cond, reinforce=reinforce)
-            where_rela_score = None
-        else:
-            x_emb_var, x_len = self.embed_layer.gen_x_batch(q, col)
-            col_inp_var, col_name_len, col_len = self.embed_layer.gen_col_batch(col)
-            sel_num_score = self.sel_num(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num)
-            # x_emb_var: embedding of each question
-            # x_len: length of each question
-            # col_inp_var: embedding of each header
-            # col_name_len: length of each header
-            # col_len: number of headers in each table, array type
-            # col_num: number of headers in each table, list type
-            if gt_sel_num:
-                pr_sel_num = gt_sel_num
-            else:
-                pr_sel_num = np.argmax(sel_num_score.data.cpu().numpy(), axis=1)
-            sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num)
-
-            if gt_sel:
-                pr_sel = gt_sel
-            else:
-                num = np.argmax(sel_num_score.data.cpu().numpy(), axis=1)
-                sel = sel_score.data.cpu().numpy()
-                pr_sel = [list(np.argsort(-sel[b])[:num[b]]) for b in range(len(num))]
-            agg_score = self.agg_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num, gt_sel=pr_sel, gt_sel_num=pr_sel_num)
-
-            where_rela_score = self.where_rela_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num)
-
-            cond_score = self.cond_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num, gt_where, gt_cond, reinforce=reinforce)
-
-        return (sel_num_score, sel_score, agg_score, cond_score, where_rela_score)
-
-    def loss(self, score, truth_num, gt_where):
-        sel_num_score, sel_score, agg_score, cond_score, where_rela_score = score
-
-        B = len(truth_num)
-        loss = 0
-
-        # Evaluate select number
-        # sel_num_truth = map(lambda x:x[0], truth_num)
-        sel_num_truth = [x[0] for x in truth_num]
-        sel_num_truth = torch.from_numpy(np.array(sel_num_truth))
-        if self.gpu:
-            sel_num_truth = Variable(sel_num_truth.cuda())
-        else:
-            sel_num_truth = Variable(sel_num_truth)
-        loss += self.CE(sel_num_score, sel_num_truth)
-
-        # Evaluate select column
-        T = len(sel_score[0])
-        truth_prob = np.zeros((B,T), dtype=np.float32)
-        for b in range(B):
-            truth_prob[b][list(truth_num[b][1])] = 1
-        data = torch.from_numpy(truth_prob)
-        if self.gpu:
-            sel_col_truth_var = Variable(data.cuda())
-        else:
-            sel_col_truth_var = Variable(data)
-        sigm = nn.Sigmoid()
-        sel_col_prob = sigm(sel_score)
-        bce_loss = -torch.mean(
-            3*(sel_col_truth_var * torch.log(sel_col_prob+1e-10)) +
-            (1-sel_col_truth_var) * torch.log(1-sel_col_prob+1e-10)
-        )
-        loss += bce_loss
-
-        # Evaluate select aggregation
-        for b in range(len(truth_num)):
-            data = torch.from_numpy(np.array(truth_num[b][2]))
-            if self.gpu:
-                sel_agg_truth_var = Variable(data.cuda())
-            else:
-                sel_agg_truth_var = Variable(data)
-            sel_agg_pred = agg_score[b, :len(truth_num[b][1])]
-            loss += (self.CE(sel_agg_pred, sel_agg_truth_var)) / len(truth_num)
-
-        cond_num_score, cond_col_score, cond_op_score, cond_str_score = cond_score
-
-        # Evaluate the number of conditions
-        # cond_num_truth = map(lambda x:x[3], truth_num)
-        cond_num_truth = [x[3] for x in truth_num]
-        data = torch.from_numpy(np.array(cond_num_truth))
-        if self.gpu:
-            try:
-                cond_num_truth_var = Variable(data.cuda())
-            except:
-                print ("cond_num_truth_var error")
-                print (data)
-                exit(0)
-        else:
-            cond_num_truth_var = Variable(data)
-        loss += self.CE(cond_num_score, cond_num_truth_var)
-
-        # Evaluate the columns of conditions
-        T = len(cond_col_score[0])
-        truth_prob = np.zeros((B, T), dtype=np.float32)
-        for b in range(B):
-            if len(truth_num[b][4]) > 0:
-                truth_prob[b][list(truth_num[b][4])] = 1
-        data = torch.from_numpy(truth_prob)
-        if self.gpu:
-            cond_col_truth_var = Variable(data.cuda())
-        else:
-            cond_col_truth_var = Variable(data)
-
-        sigm = nn.Sigmoid()
-        cond_col_prob = sigm(cond_col_score)
-        bce_loss = -torch.mean(
-            3*(cond_col_truth_var * torch.log(cond_col_prob+1e-10)) +
-            (1-cond_col_truth_var) * torch.log(1-cond_col_prob+1e-10) )
-        loss += bce_loss
-
-        # Evaluate the operator of conditions
-        for b in range(len(truth_num)):
-            if len(truth_num[b][5]) == 0:
-                continue
-            data = torch.from_numpy(np.array(truth_num[b][5]))
-            if self.gpu:
-                cond_op_truth_var = Variable(data.cuda())
-            else:
-                cond_op_truth_var = Variable(data)
-            cond_op_pred = cond_op_score[b, :len(truth_num[b][5])]
-            try:
-                loss += (self.CE(cond_op_pred, cond_op_truth_var) / len(truth_num))
-            except:
-                print (cond_op_pred)
-                print (cond_op_truth_var)
-                exit(0)
-
-        #Evaluate the strings of conditions
-        for b in range(len(gt_where)):
-            for idx in range(len(gt_where[b])):
-                cond_str_truth = gt_where[b][idx]
-                if len(cond_str_truth) == 1:
-                    continue
-                data = torch.from_numpy(np.array(cond_str_truth[1:]))
-                if self.gpu:
-                    cond_str_truth_var = Variable(data.cuda())
-                else:
-                    cond_str_truth_var = Variable(data)
-                str_end = len(cond_str_truth)-1
-                cond_str_pred = cond_str_score[b, idx, :str_end]
-                loss += (self.CE(cond_str_pred, cond_str_truth_var) \
-                        / (len(gt_where) * len(gt_where[b])))
-
-        # Evaluate condition relationship, and / or
-        # where_rela_truth = map(lambda x:x[6], truth_num)
-        where_rela_truth = [x[6] for x in truth_num]
-        data = torch.from_numpy(np.array(where_rela_truth))
-        if self.gpu:
-            try:
-                where_rela_truth = Variable(data.cuda())
-            except:
-                print ("where_rela_truth error")
-                print (data)
-                exit(0)
-        else:
-            where_rela_truth = Variable(data)
-        loss += self.CE(where_rela_score, where_rela_truth)
-        return loss
-
-    def check_acc(self, vis_info, pred_queries, gt_queries):
-        def gen_cond_str(conds, header):
-            if len(conds) == 0:
-                return 'None'
-            cond_str = []
-            for cond in conds:
-                cond_str.append(header[cond[0]] + ' ' +
-                    self.COND_OPS[cond[1]] + ' ' + unicode(cond[2]).lower())
-            return 'WHERE ' + ' AND '.join(cond_str)
-
-        tot_err = sel_num_err = agg_err = sel_err = 0.0
-        cond_num_err = cond_col_err = cond_op_err = cond_val_err = cond_rela_err = 0.0
-        for b, (pred_qry, gt_qry) in enumerate(zip(pred_queries, gt_queries)):
-            good = True
-            sel_pred, agg_pred, where_rela_pred = pred_qry['sel'], pred_qry['agg'], pred_qry['cond_conn_op']
-            sel_gt, agg_gt, where_rela_gt = gt_qry['sel'], gt_qry['agg'], gt_qry['cond_conn_op']
-
-            if where_rela_gt != where_rela_pred:
-                good = False
-                cond_rela_err += 1
-
-            if len(sel_pred) != len(sel_gt):
-                good = False
-                sel_num_err += 1
-
-            pred_sel_dict = {k:v for k,v in zip(list(sel_pred), list(agg_pred))}
-            gt_sel_dict = {k:v for k,v in zip(sel_gt, agg_gt)}
-            if set(sel_pred) != set(sel_gt):
-                good = False
-                sel_err += 1
-            agg_pred = [pred_sel_dict[x] for x in sorted(pred_sel_dict.keys())]
-            agg_gt = [gt_sel_dict[x] for x in sorted(gt_sel_dict.keys())]
-            if agg_pred != agg_gt:
-                good = False
-                agg_err += 1
-
-            cond_pred = pred_qry['conds']
-            cond_gt = gt_qry['conds']
-            if len(cond_pred) != len(cond_gt):
-                good = False
-                cond_num_err += 1
-            else:
-                cond_op_pred, cond_op_gt = {}, {}
-                cond_val_pred, cond_val_gt = {}, {}
-                for p, g in zip(cond_pred, cond_gt):
-                    cond_op_pred[p[0]] = p[1]
-                    cond_val_pred[p[0]] = p[2]
-                    cond_op_gt[g[0]] = g[1]
-                    cond_val_gt[g[0]] = g[2]
-
-                if set(cond_op_pred.keys()) != set(cond_op_gt.keys()):
-                    cond_col_err += 1
-                    good=False
-
-                where_op_pred = [cond_op_pred[x] for x in sorted(cond_op_pred.keys())]
-                where_op_gt = [cond_op_gt[x] for x in sorted(cond_op_gt.keys())]
-                if where_op_pred != where_op_gt:
-                    cond_op_err += 1
-                    good=False
-
-                where_val_pred = [cond_val_pred[x] for x in sorted(cond_val_pred.keys())]
-                where_val_gt = [cond_val_gt[x] for x in sorted(cond_val_gt.keys())]
-                if where_val_pred != where_val_gt:
-                    cond_val_err += 1
-                    good=False
-
-            if not good:
-                tot_err += 1
-
-        return np.array((sel_num_err, sel_err, agg_err, cond_num_err, cond_col_err, cond_op_err, cond_val_err , cond_rela_err)), tot_err
-
-
-    def gen_query(self, score, q, col, raw_q, reinforce=False, verbose=False):
-        """
-        :param score:
-        :param q: token-questions
-        :param col: token-headers
-        :param raw_q: original question sequence
-        :return:
-        """
-        def merge_tokens(tok_list, raw_tok_str):
-            tok_str = raw_tok_str# .lower()
-            alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789$('
-            special = {'-LRB-':'(',
-                    '-RRB-':')',
-                    '-LSB-':'[',
-                    '-RSB-':']',
-                    '``':'"',
-                    '\'\'':'"',
-                    '--':u'\u2013'}
-            ret = ''
-            double_quote_appear = 0
-            for raw_tok in tok_list:
-                if not raw_tok:
-                    continue
-                tok = special.get(raw_tok, raw_tok)
-                if tok == '"':
-                    double_quote_appear = 1 - double_quote_appear
-                if len(ret) == 0:
-                    pass
-                elif len(ret) > 0 and ret + ' ' + tok in tok_str:
-                    ret = ret + ' '
-                elif len(ret) > 0 and ret + tok in tok_str:
-                    pass
-                elif tok == '"':
-                    if double_quote_appear:
-                        ret = ret + ' '
-                # elif tok[0] not in alphabet:
-                #     pass
-                elif (ret[-1] not in ['(', '/', u'\u2013', '#', '$', '&']) \
-                        and (ret[-1] != '"' or not double_quote_appear):
-                    ret = ret + ' '
-                ret = ret + tok
-            return ret.strip()
-
-        sel_num_score, sel_score, agg_score, cond_score, where_rela_score = score
-        # [64,4,6], [64,14], ..., [64,4]
-        sel_num_score = sel_num_score.data.cpu().numpy()
-        sel_score = sel_score.data.cpu().numpy()
-        agg_score = agg_score.data.cpu().numpy()
-        where_rela_score = where_rela_score.data.cpu().numpy()
-        ret_queries = []
-        B = len(agg_score)
-        cond_num_score,cond_col_score,cond_op_score,cond_str_score =\
-            [x.data.cpu().numpy() for x in cond_score]
-        for b in range(B):
-            cur_query = {}
-            cur_query['sel'] = []
-            cur_query['agg'] = []
-            sel_num = np.argmax(sel_num_score[b])
-            max_col_idxes = np.argsort(-sel_score[b])[:sel_num]
-            # find the most-probable columns' indexes
-            max_agg_idxes = np.argsort(-agg_score[b])[:sel_num]
-            cur_query['sel'].extend([int(i) for i in max_col_idxes])
-            cur_query['agg'].extend([i[0] for i in max_agg_idxes])
-            cur_query['cond_conn_op'] = np.argmax(where_rela_score[b])
-            cur_query['conds'] = []
-            cond_num = np.argmax(cond_num_score[b])
-            all_toks = ['<BEG>'] + q[b] + ['<END>']
-            max_idxes = np.argsort(-cond_col_score[b])[:cond_num]
-            for idx in range(cond_num):
-                cur_cond = []
-                cur_cond.append(max_idxes[idx]) # where-col
-                cur_cond.append(np.argmax(cond_op_score[b][idx])) # where-op
-                cur_cond_str_toks = []
-                for str_score in cond_str_score[b][idx]:
-                    str_tok = np.argmax(str_score[:len(all_toks)])
-                    str_val = all_toks[str_tok]
-                    if str_val == '<END>':
-                        break
-                    cur_cond_str_toks.append(str_val)
-                cur_cond.append(merge_tokens(cur_cond_str_toks, raw_q[b]))
-                cur_query['conds'].append(cur_cond)
-            ret_queries.append(cur_query)
-        return ret_queries
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from sqlnet.model.modules.word_embedding import WordEmbedding
+from sqlnet.model.modules.aggregator_predict import AggPredictor
+from sqlnet.model.modules.selection_predict import SelPredictor
+from sqlnet.model.modules.sqlnet_condition_predict import SQLNetCondPredictor
+from sqlnet.model.modules.select_number import SelNumPredictor
+from sqlnet.model.modules.where_relation import WhereRelationPredictor
+
+# SQLNet model definition
+class SQLNet(nn.Module):
+    def __init__(self, word_emb, N_word, N_h=100, N_depth=2,
+            gpu=False, use_ca=True, trainable_emb=False):
+        super(SQLNet, self).__init__()
+        self.use_ca = use_ca
+        self.trainable_emb = trainable_emb
+
+        self.gpu = gpu
+        self.N_h = N_h
+        self.N_depth = N_depth
+
+        self.max_col_num = 45
+        self.max_tok_num = 200
+        self.SQL_TOK = ['<UNK>', '<END>', 'WHERE', 'AND', 'OR', '==', '>', '<', '!=', '<BEG>']
+        self.COND_OPS = ['>', '<', '==', '!=']
+
+        # Word embedding
+        self.embed_layer = WordEmbedding(word_emb, N_word, gpu, self.SQL_TOK, our_model=True, trainable=trainable_emb)
+
+        # Predict the number of selected columns
+        self.sel_num = SelNumPredictor(N_word, N_h, N_depth, use_ca=use_ca)
+
+        #Predict which columns are selected
+        self.sel_pred = SelPredictor(N_word, N_h, N_depth, self.max_tok_num, use_ca=use_ca)
+
+        #Predict aggregation functions of corresponding selected columns
+        self.agg_pred = AggPredictor(N_word, N_h, N_depth, use_ca=use_ca)
+
+        #Predict number of conditions, condition columns, condition operations and condition values
+        self.cond_pred = SQLNetCondPredictor(N_word, N_h, N_depth, self.max_col_num, self.max_tok_num, use_ca, gpu)
+
+        # Predict condition relationship, like 'and', 'or'
+        self.where_rela_pred = WhereRelationPredictor(N_word, N_h, N_depth, use_ca=use_ca)
+
+
+        self.CE = nn.CrossEntropyLoss()
+        self.softmax = nn.Softmax(dim=-1)
+        self.log_softmax = nn.LogSoftmax()
+        self.bce_logit = nn.BCEWithLogitsLoss()
+        if gpu:
+            self.cuda()
+
+    def generate_gt_where_seq_test(self, q, gt_cond_seq):
+        ret_seq = []
+        for cur_q, ans in zip(q, gt_cond_seq):
+            temp_q = u"".join(cur_q)
+            cur_q = [u'<BEG>'] + cur_q + [u'<END>']
+            record = []
+            record_cond = []
+            for cond in ans:
+                if cond[2] not in temp_q:
+                    record.append((False, cond[2]))
+                else:
+                    record.append((True, cond[2]))
+            for idx, item in enumerate(record):
+                temp_ret_seq = []
+                if item[0]:
+                    temp_ret_seq.append(0)
+                    temp_ret_seq.extend(list(range(temp_q.index(item[1])+1,temp_q.index(item[1])+len(item[1])+1)))
+                    temp_ret_seq.append(len(cur_q)-1)
+                else:
+                    temp_ret_seq.append([0,len(cur_q)-1])
+                record_cond.append(temp_ret_seq)
+            ret_seq.append(record_cond)
+        return ret_seq
+
+    def forward(self, q, col, col_num, gt_where = None, gt_cond=None, reinforce=False, gt_sel=None, gt_sel_num=None):
+        B = len(q)
+
+        sel_num_score = None
+        agg_score = None
+        sel_score = None
+        cond_score = None
+        #Predict aggregator
+        if self.trainable_emb:
+            x_emb_var, x_len = self.agg_embed_layer.gen_x_batch(q, col)
+            col_inp_var, col_name_len, col_len = self.agg_embed_layer.gen_col_batch(col)
+            max_x_len = max(x_len)
+            agg_score = self.agg_pred(x_emb_var, x_len, col_inp_var,
+                    col_name_len, col_len, col_num, gt_sel=gt_sel)
+
+            x_emb_var, x_len = self.sel_embed_layer.gen_x_batch(q, col)
+            col_inp_var, col_name_len, col_len = self.sel_embed_layer.gen_col_batch(col)
+            max_x_len = max(x_len)
+            sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var,
+                    col_name_len, col_len, col_num)
+
+            x_emb_var, x_len = self.cond_embed_layer.gen_x_batch(q, col)
+            col_inp_var, col_name_len, col_len = self.cond_embed_layer.gen_col_batch(col)
+            max_x_len = max(x_len)
+            cond_score = self.cond_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num, gt_where, gt_cond, reinforce=reinforce)
+            where_rela_score = None
+        else:
+            x_emb_var, x_len = self.embed_layer.gen_x_batch(q, col)
+            col_inp_var, col_name_len, col_len = self.embed_layer.gen_col_batch(col)
+            sel_num_score = self.sel_num(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num)
+            # x_emb_var: embedding of each question
+            # x_len: length of each question
+            # col_inp_var: embedding of each header
+            # col_name_len: length of each header
+            # col_len: number of headers in each table, array type
+            # col_num: number of headers in each table, list type
+            if gt_sel_num:
+                pr_sel_num = gt_sel_num
+            else:
+                pr_sel_num = np.argmax(sel_num_score.data.cpu().numpy(), axis=1)
+            sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num)
+
+            if gt_sel:
+                pr_sel = gt_sel
+            else:
+                num = np.argmax(sel_num_score.data.cpu().numpy(), axis=1)
+                sel = sel_score.data.cpu().numpy()
+                pr_sel = [list(np.argsort(-sel[b])[:num[b]]) for b in range(len(num))]
+            agg_score = self.agg_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num, gt_sel=pr_sel, gt_sel_num=pr_sel_num)
+
+            where_rela_score = self.where_rela_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num)
+
+            cond_score = self.cond_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num, gt_where, gt_cond, reinforce=reinforce)
+
+        return (sel_num_score, sel_score, agg_score, cond_score, where_rela_score)
+
+    def loss(self, score, truth_num, gt_where):
+        sel_num_score, sel_score, agg_score, cond_score, where_rela_score = score
+
+        B = len(truth_num)
+        loss = 0
+
+        # Evaluate select number
+        # sel_num_truth = map(lambda x:x[0], truth_num)
+        sel_num_truth = [x[0] for x in truth_num]
+        sel_num_truth = torch.from_numpy(np.array(sel_num_truth))
+        if self.gpu:
+            sel_num_truth = Variable(sel_num_truth.cuda())
+        else:
+            sel_num_truth = Variable(sel_num_truth)
+        loss += self.CE(sel_num_score, sel_num_truth)
+
+        # Evaluate select column
+        T = len(sel_score[0])
+        truth_prob = np.zeros((B,T), dtype=np.float32)
+        for b in range(B):
+            truth_prob[b][list(truth_num[b][1])] = 1
+        data = torch.from_numpy(truth_prob)
+        if self.gpu:
+            sel_col_truth_var = Variable(data.cuda())
+        else:
+            sel_col_truth_var = Variable(data)
+        sigm = nn.Sigmoid()
+        sel_col_prob = sigm(sel_score)
+        bce_loss = -torch.mean(
+            3*(sel_col_truth_var * torch.log(sel_col_prob+1e-10)) +
+            (1-sel_col_truth_var) * torch.log(1-sel_col_prob+1e-10)
+        )
+        loss += bce_loss
+
+        # Evaluate select aggregation
+        for b in range(len(truth_num)):
+            data = torch.from_numpy(np.array(truth_num[b][2]))
+            if self.gpu:
+                sel_agg_truth_var = Variable(data.cuda())
+            else:
+                sel_agg_truth_var = Variable(data)
+            sel_agg_pred = agg_score[b, :len(truth_num[b][1])]
+            loss += (self.CE(sel_agg_pred, sel_agg_truth_var)) / len(truth_num)
+
+        cond_num_score, cond_col_score, cond_op_score, cond_str_score = cond_score
+
+        # Evaluate the number of conditions
+        # cond_num_truth = map(lambda x:x[3], truth_num)
+        cond_num_truth = [x[3] for x in truth_num]
+        data = torch.from_numpy(np.array(cond_num_truth))
+        if self.gpu:
+            try:
+                cond_num_truth_var = Variable(data.cuda())
+            except:
+                print ("cond_num_truth_var error")
+                print (data)
+                exit(0)
+        else:
+            cond_num_truth_var = Variable(data)
+        loss += self.CE(cond_num_score, cond_num_truth_var)
+
+        # Evaluate the columns of conditions
+        T = len(cond_col_score[0])
+        truth_prob = np.zeros((B, T), dtype=np.float32)
+        for b in range(B):
+            if len(truth_num[b][4]) > 0:
+                truth_prob[b][list(truth_num[b][4])] = 1
+        data = torch.from_numpy(truth_prob)
+        if self.gpu:
+            cond_col_truth_var = Variable(data.cuda())
+        else:
+            cond_col_truth_var = Variable(data)
+
+        sigm = nn.Sigmoid()
+        cond_col_prob = sigm(cond_col_score)
+        bce_loss = -torch.mean(
+            3*(cond_col_truth_var * torch.log(cond_col_prob+1e-10)) +
+            (1-cond_col_truth_var) * torch.log(1-cond_col_prob+1e-10) )
+        loss += bce_loss
+
+        # Evaluate the operator of conditions
+        for b in range(len(truth_num)):
+            if len(truth_num[b][5]) == 0:
+                continue
+            data = torch.from_numpy(np.array(truth_num[b][5]))
+            if self.gpu:
+                cond_op_truth_var = Variable(data.cuda())
+            else:
+                cond_op_truth_var = Variable(data)
+            cond_op_pred = cond_op_score[b, :len(truth_num[b][5])]
+            try:
+                loss += (self.CE(cond_op_pred, cond_op_truth_var) / len(truth_num))
+            except:
+                print (cond_op_pred)
+                print (cond_op_truth_var)
+                exit(0)
+
+        #Evaluate the strings of conditions
+        for b in range(len(gt_where)):
+            for idx in range(len(gt_where[b])):
+                cond_str_truth = gt_where[b][idx]
+                if len(cond_str_truth) == 1:
+                    continue
+                data = torch.from_numpy(np.array(cond_str_truth[1:]))
+                if self.gpu:
+                    cond_str_truth_var = Variable(data.cuda())
+                else:
+                    cond_str_truth_var = Variable(data)
+                str_end = len(cond_str_truth)-1
+                cond_str_pred = cond_str_score[b, idx, :str_end]
+                loss += (self.CE(cond_str_pred, cond_str_truth_var) \
+                        / (len(gt_where) * len(gt_where[b])))
+
+        # Evaluate condition relationship, and / or
+        # where_rela_truth = map(lambda x:x[6], truth_num)
+        where_rela_truth = [x[6] for x in truth_num]
+        data = torch.from_numpy(np.array(where_rela_truth))
+        if self.gpu:
+            try:
+                where_rela_truth = Variable(data.cuda())
+            except:
+                print ("where_rela_truth error")
+                print (data)
+                exit(0)
+        else:
+            where_rela_truth = Variable(data)
+        loss += self.CE(where_rela_score, where_rela_truth)
+        return loss
+
+    def check_acc(self, vis_info, pred_queries, gt_queries):
+        def gen_cond_str(conds, header):
+            if len(conds) == 0:
+                return 'None'
+            cond_str = []
+            for cond in conds:
+                cond_str.append(header[cond[0]] + ' ' +
+                    self.COND_OPS[cond[1]] + ' ' + unicode(cond[2]).lower())
+            return 'WHERE ' + ' AND '.join(cond_str)
+
+        tot_err = sel_num_err = agg_err = sel_err = 0.0
+        cond_num_err = cond_col_err = cond_op_err = cond_val_err = cond_rela_err = 0.0
+        for b, (pred_qry, gt_qry) in enumerate(zip(pred_queries, gt_queries)):
+            good = True
+            sel_pred, agg_pred, where_rela_pred = pred_qry['sel'], pred_qry['agg'], pred_qry['cond_conn_op']
+            sel_gt, agg_gt, where_rela_gt = gt_qry['sel'], gt_qry['agg'], gt_qry['cond_conn_op']
+
+            if where_rela_gt != where_rela_pred:
+                good = False
+                cond_rela_err += 1
+
+            if len(sel_pred) != len(sel_gt):
+                good = False
+                sel_num_err += 1
+
+            pred_sel_dict = {k:v for k,v in zip(list(sel_pred), list(agg_pred))}
+            gt_sel_dict = {k:v for k,v in zip(sel_gt, agg_gt)}
+            if set(sel_pred) != set(sel_gt):
+                good = False
+                sel_err += 1
+            agg_pred = [pred_sel_dict[x] for x in sorted(pred_sel_dict.keys())]
+            agg_gt = [gt_sel_dict[x] for x in sorted(gt_sel_dict.keys())]
+            if agg_pred != agg_gt:
+                good = False
+                agg_err += 1
+
+            cond_pred = pred_qry['conds']
+            cond_gt = gt_qry['conds']
+            if len(cond_pred) != len(cond_gt):
+                good = False
+                cond_num_err += 1
+            else:
+                cond_op_pred, cond_op_gt = {}, {}
+                cond_val_pred, cond_val_gt = {}, {}
+                for p, g in zip(cond_pred, cond_gt):
+                    cond_op_pred[p[0]] = p[1]
+                    cond_val_pred[p[0]] = p[2]
+                    cond_op_gt[g[0]] = g[1]
+                    cond_val_gt[g[0]] = g[2]
+
+                if set(cond_op_pred.keys()) != set(cond_op_gt.keys()):
+                    cond_col_err += 1
+                    good=False
+
+                where_op_pred = [cond_op_pred[x] for x in sorted(cond_op_pred.keys())]
+                where_op_gt = [cond_op_gt[x] for x in sorted(cond_op_gt.keys())]
+                if where_op_pred != where_op_gt:
+                    cond_op_err += 1
+                    good=False
+
+                where_val_pred = [cond_val_pred[x] for x in sorted(cond_val_pred.keys())]
+                where_val_gt = [cond_val_gt[x] for x in sorted(cond_val_gt.keys())]
+                if where_val_pred != where_val_gt:
+                    cond_val_err += 1
+                    good=False
+
+            if not good:
+                tot_err += 1
+
+        return np.array((sel_num_err, sel_err, agg_err, cond_num_err, cond_col_err, cond_op_err, cond_val_err , cond_rela_err)), tot_err
+
+
+    def gen_query(self, score, q, col, raw_q, reinforce=False, verbose=False):
+        """
+        :param score:
+        :param q: token-questions
+        :param col: token-headers
+        :param raw_q: original question sequence
+        :return:
+        """
+        def merge_tokens(tok_list, raw_tok_str):
+            tok_str = raw_tok_str# .lower()
+            alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789$('
+            special = {'-LRB-':'(',
+                    '-RRB-':')',
+                    '-LSB-':'[',
+                    '-RSB-':']',
+                    '``':'"',
+                    '\'\'':'"',
+                    '--':u'\u2013'}
+            ret = ''
+            double_quote_appear = 0
+            for raw_tok in tok_list:
+                if not raw_tok:
+                    continue
+                tok = special.get(raw_tok, raw_tok)
+                if tok == '"':
+                    double_quote_appear = 1 - double_quote_appear
+                if len(ret) == 0:
+                    pass
+                elif len(ret) > 0 and ret + ' ' + tok in tok_str:
+                    ret = ret + ' '
+                elif len(ret) > 0 and ret + tok in tok_str:
+                    pass
+                elif tok == '"':
+                    if double_quote_appear:
+                        ret = ret + ' '
+                # elif tok[0] not in alphabet:
+                #     pass
+                elif (ret[-1] not in ['(', '/', u'\u2013', '#', '$', '&']) \
+                        and (ret[-1] != '"' or not double_quote_appear):
+                    ret = ret + ' '
+                ret = ret + tok
+            return ret.strip()
+
+        sel_num_score, sel_score, agg_score, cond_score, where_rela_score = score
+        # [64,4,6], [64,14], ..., [64,4]
+        sel_num_score = sel_num_score.data.cpu().numpy()
+        sel_score = sel_score.data.cpu().numpy()
+        agg_score = agg_score.data.cpu().numpy()
+        where_rela_score = where_rela_score.data.cpu().numpy()
+        ret_queries = []
+        B = len(agg_score)
+        cond_num_score,cond_col_score,cond_op_score,cond_str_score =\
+            [x.data.cpu().numpy() for x in cond_score]
+        for b in range(B):
+            cur_query = {}
+            cur_query['sel'] = []
+            cur_query['agg'] = []
+            sel_num = np.argmax(sel_num_score[b])
+            max_col_idxes = np.argsort(-sel_score[b])[:sel_num]
+            # find the most-probable columns' indexes
+            max_agg_idxes = np.argsort(-agg_score[b])[:sel_num]
+            cur_query['sel'].extend([int(i) for i in max_col_idxes])
+            cur_query['agg'].extend([i[0] for i in max_agg_idxes])
+            cur_query['cond_conn_op'] = np.argmax(where_rela_score[b])
+            cur_query['conds'] = []
+            cond_num = np.argmax(cond_num_score[b])
+            all_toks = ['<BEG>'] + q[b] + ['<END>']
+            max_idxes = np.argsort(-cond_col_score[b])[:cond_num]
+            for idx in range(cond_num):
+                cur_cond = []
+                cur_cond.append(max_idxes[idx]) # where-col
+                cur_cond.append(np.argmax(cond_op_score[b][idx])) # where-op
+                cur_cond_str_toks = []
+                for str_score in cond_str_score[b][idx]:
+                    str_tok = np.argmax(str_score[:len(all_toks)])
+                    str_val = all_toks[str_tok]
+                    if str_val == '<END>':
+                        break
+                    cur_cond_str_toks.append(str_val)
+                cur_cond.append(merge_tokens(cur_cond_str_toks, raw_q[b]))
+                cur_query['conds'].append(cur_cond)
+            ret_queries.append(cur_query)
+        return ret_queries

+ 227 - 205
sqlnet/utils.py

@@ -1,205 +1,227 @@
-import json
-from sqlnet.lib.dbengine import DBEngine
-import numpy as np
-from tqdm import tqdm
-
-def load_data(sql_paths, table_paths, use_small=False):
-    if not isinstance(sql_paths, list):
-        sql_paths = (sql_paths, )
-    if not isinstance(table_paths, list):
-        table_paths = (table_paths, )
-    sql_data = []
-    table_data = {}
-
-    for SQL_PATH in sql_paths:
-        with open(SQL_PATH, encoding='utf-8') as inf:
-            for idx, line in enumerate(inf):
-                sql = json.loads(line.strip())
-                if use_small and idx >= 1000:
-                    break
-                sql_data.append(sql)
-        print ("Loaded %d data from %s" % (len(sql_data), SQL_PATH))
-
-    for TABLE_PATH in table_paths:
-        with open(TABLE_PATH, encoding='utf-8') as inf:
-            for line in inf:
-                tab = json.loads(line.strip())
-                table_data[tab[u'id']] = tab
-        print ("Loaded %d data from %s" % (len(table_data), TABLE_PATH))
-
-    ret_sql_data = []
-    for sql in sql_data:
-        if sql[u'table_id'] in table_data:
-            ret_sql_data.append(sql)
-
-    return ret_sql_data, table_data
-
-def load_dataset(toy=False, use_small=False, mode='train'):
-    print ("Loading dataset")
-    dev_sql, dev_table = load_data('data/val/val.json', 'data/val/val.tables.json', use_small=use_small)
-    dev_db = 'data/val/val.db'
-    if mode == 'train':
-        train_sql, train_table = load_data('data/train/train.json', 'data/train/train.tables.json', use_small=use_small)
-        train_db = 'data/train/train.db'
-        return train_sql, train_table, train_db, dev_sql, dev_table, dev_db
-    elif mode == 'test':
-        test_sql, test_table = load_data('data/test/test.json', 'data/test/test.tables.json', use_small=use_small)
-        test_db = 'data/test/test.db'
-        return dev_sql, dev_table, dev_db, test_sql, test_table, test_db
-
-def to_batch_seq(sql_data, table_data, idxes, st, ed, ret_vis_data=False):
-    q_seq = []
-    col_seq = []
-    col_num = []
-    ans_seq = []
-    gt_cond_seq = []
-    vis_seq = []
-    sel_num_seq = []
-    for i in range(st, ed):
-        sql = sql_data[idxes[i]]
-        sel_num = len(sql['sql']['sel'])
-        sel_num_seq.append(sel_num)
-        conds_num = len(sql['sql']['conds'])
-        q_seq.append([char for char in sql['question']])
-        col_seq.append([[char for char in header] for header in table_data[sql['table_id']]['header']])
-        col_num.append(len(table_data[sql['table_id']]['header']))
-        ans_seq.append(
-            (
-            len(sql['sql']['agg']),
-            sql['sql']['sel'],
-            sql['sql']['agg'],
-            conds_num,
-            tuple(x[0] for x in sql['sql']['conds']),
-            tuple(x[1] for x in sql['sql']['conds']),
-            sql['sql']['cond_conn_op'],
-            ))
-        gt_cond_seq.append(sql['sql']['conds'])
-        vis_seq.append((sql['question'], table_data[sql['table_id']]['header']))
-    if ret_vis_data:
-        return q_seq, sel_num_seq, col_seq, col_num, ans_seq, gt_cond_seq, vis_seq
-    else:
-        return q_seq, sel_num_seq, col_seq, col_num, ans_seq, gt_cond_seq
-
-def to_batch_seq_test(sql_data, table_data, idxes, st, ed):
-    q_seq = []
-    col_seq = []
-    col_num = []
-    raw_seq = []
-    table_ids = []
-    for i in range(st, ed):
-        sql = sql_data[idxes[i]]
-        q_seq.append([char for char in sql['question']])
-        col_seq.append([[char for char in header] for header in table_data[sql['table_id']]['header']])
-        col_num.append(len(table_data[sql['table_id']]['header']))
-        raw_seq.append(sql['question'])
-        table_ids.append(sql_data[idxes[i]]['table_id'])
-    return q_seq, col_seq, col_num, raw_seq, table_ids
-
-def to_batch_query(sql_data, idxes, st, ed):
-    query_gt = []
-    table_ids = []
-    for i in range(st, ed):
-        sql_data[idxes[i]]['sql']['conds'] = sql_data[idxes[i]]['sql']['conds']
-        query_gt.append(sql_data[idxes[i]]['sql'])
-        table_ids.append(sql_data[idxes[i]]['table_id'])
-    return query_gt, table_ids
-
-def epoch_train(model, optimizer, batch_size, sql_data, table_data):
-    model.train()
-    perm=np.random.permutation(len(sql_data))
-    perm = list(range(len(sql_data)))
-    cum_loss = 0.0
-    for st in tqdm(range(len(sql_data)//batch_size+1)):
-        ed = (st+1)*batch_size if (st+1)*batch_size < len(perm) else len(perm)
-        st = st * batch_size
-        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq = to_batch_seq(sql_data, table_data, perm, st, ed)
-        # q_seq: char-based sequence of question
-        # gt_sel_num: number of selected columns and aggregation functions
-        # col_seq: char-based column name
-        # col_num: number of headers in one table
-        # ans_seq: (sel, number of conds, sel list in conds, op list in conds)
-        # gt_cond_seq: ground truth of conds
-        gt_where_seq = model.generate_gt_where_seq_test(q_seq, gt_cond_seq)
-        gt_sel_seq = [x[1] for x in ans_seq]
-        score = model.forward(q_seq, col_seq, col_num, gt_where=gt_where_seq, gt_cond=gt_cond_seq, gt_sel=gt_sel_seq, gt_sel_num=gt_sel_num)
-        # sel_num_score, sel_col_score, sel_agg_score, cond_score, cond_rela_score
-
-        # compute loss
-        loss = model.loss(score, ans_seq, gt_where_seq)
-        cum_loss += loss.data.cpu().numpy()*(ed - st)
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
-    return cum_loss / len(sql_data)
-
-def predict_test(model, batch_size, sql_data, table_data, output_path):
-    model.eval()
-    perm = list(range(len(sql_data)))
-    fw = open(output_path,'w')
-    for st in tqdm(range(len(sql_data)//batch_size+1)):
-        ed = (st+1)*batch_size if (st+1)*batch_size < len(perm) else len(perm)
-        st = st * batch_size
-        q_seq, col_seq, col_num, raw_q_seq, table_ids = to_batch_seq_test(sql_data, table_data, perm, st, ed)
-        score = model.forward(q_seq, col_seq, col_num)
-        sql_preds = model.gen_query(score, q_seq, col_seq, raw_q_seq)
-        for sql_pred in sql_preds:
-            fw.writelines(json.dumps(sql_pred,ensure_ascii=False).encode('utf-8')+'\n')
-    fw.close()
-
-def epoch_acc(model, batch_size, sql_data, table_data, db_path):
-    engine = DBEngine(db_path)
-    model.eval()
-    perm = list(range(len(sql_data)))
-    badcase = 0
-    one_acc_num, tot_acc_num, ex_acc_num = 0.0, 0.0, 0.0
-    for st in tqdm(range(len(sql_data)//batch_size+1)):
-        ed = (st+1)*batch_size if (st+1)*batch_size < len(perm) else len(perm)
-        st = st * batch_size
-        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, raw_data = \
-            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
-        # q_seq: char-based sequence of question
-        # gt_sel_num: number of selected columns and aggregation functions, new added field
-        # col_seq: char-based column name
-        # col_num: number of headers in one table
-        # ans_seq: (sel, number of conds, sel list in conds, op list in conds)
-        # gt_cond_seq: ground truth of conditions
-        # raw_data: ori question, headers, sql
-        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
-        # query_gt: ground truth of sql, data['sql'], containing sel, agg, conds:{sel, op, value}
-        raw_q_seq = [x[0] for x in raw_data] # original question
-        try:
-            score = model.forward(q_seq, col_seq, col_num)
-            pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq)
-            # generate predicted format
-            one_err, tot_err = model.check_acc(raw_data, pred_queries, query_gt)
-        except:
-            badcase += 1
-            print ('badcase', badcase)
-            continue
-        one_acc_num += (ed-st-one_err)
-        tot_acc_num += (ed-st-tot_err)
-
-        # Execution Accuracy
-        for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
-            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'], sql_gt['conds'], sql_gt['cond_conn_op'])
-            try:
-                ret_pred = engine.execute(tid, sql_pred['sel'], sql_pred['agg'], sql_pred['conds'], sql_pred['cond_conn_op'])
-            except:
-                ret_pred = None
-            ex_acc_num += (ret_gt == ret_pred)
-    return one_acc_num / len(sql_data), tot_acc_num / len(sql_data), ex_acc_num / len(sql_data)
-
-
-def load_word_emb(file_name):
-    print ('Loading word embedding from %s'%file_name)
-    f = open(file_name)
-    ret = json.load(f)
-    f.close()
-    # ret = {}
-    # with open(file_name, encoding='latin') as inf:
-    #     ret = json.load(inf)
-    #     for idx, line in enumerate(inf):
-    #         info = line.strip().split(' ')
-    #         if info[0].lower() not in ret:
-    #             ret[info[0]] = np.array([float(x) for x in info[1:]])
-    return ret
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+'''
+@File    :   utils.py
+@Time    :   2019/06/25 04:01:19
+@Author  :   Liuyuqi 
+@Version :   1.0
+@Contact :   liuyuqi.gov@msn.cn
+@License :   (C)Copyright 2019
+@Desc    :   Data loading, batching, training and evaluation utilities
+'''
+
+import json
+from sqlnet.lib.dbengine import DBEngine
+import numpy as np
+from tqdm import tqdm
+
+def load_data(sql_paths, table_paths, use_small=False):
+    '''
+    Load question/SQL records and table schemas from JSON-lines files.
+    '''
+    if not isinstance(sql_paths, list):
+        sql_paths = (sql_paths, )
+    if not isinstance(table_paths, list):
+        table_paths = (table_paths, )
+    sql_data = []
+    table_data = {}
+
+    for SQL_PATH in sql_paths:
+        with open(SQL_PATH, encoding='utf-8') as inf:
+            for idx, line in enumerate(inf):
+                sql = json.loads(line.strip())
+                if use_small and idx >= 1000:
+                    break
+                sql_data.append(sql)
+        print ("Loaded %d data from %s" % (len(sql_data), SQL_PATH))
+
+    for TABLE_PATH in table_paths:
+        with open(TABLE_PATH, encoding='utf-8') as inf:
+            for line in inf:
+                tab = json.loads(line.strip())
+                table_data[tab[u'id']] = tab
+        print ("Loaded %d data from %s" % (len(table_data), TABLE_PATH))
+
+    ret_sql_data = []
+    for sql in sql_data:
+        if sql[u'table_id'] in table_data:
+            ret_sql_data.append(sql)
+
+    return ret_sql_data, table_data
+
+def load_dataset(toy=False, use_small=False, mode='train'):
+    print ("Loading dataset")
+    # dev_sql: list, dev_table : dict
+    dev_sql, dev_table = load_data('data/val/val.json', 'data/val/val.tables.json', use_small=use_small)
+    dev_db = 'data/val/val.db'
+    if mode == 'train':
+        train_sql, train_table = load_data('data/train/train.json', 'data/train/train.tables.json', use_small=use_small)
+        train_db = 'data/train/train.db'
+        return train_sql, train_table, train_db, dev_sql, dev_table, dev_db
+    elif mode == 'test':
+        test_sql, test_table = load_data('data/test/test.json', 'data/test/test.tables.json', use_small=use_small)
+        test_db = 'data/test/test.db'
+        return dev_sql, dev_table, dev_db, test_sql, test_table, test_db
+
+def to_batch_seq(sql_data, table_data, idxes, st, ed, ret_vis_data=False):
+    q_seq = []
+    col_seq = []
+    col_num = []
+    ans_seq = []
+    gt_cond_seq = []
+    vis_seq = []
+    sel_num_seq = []
+    for i in range(st, ed):
+        sql = sql_data[idxes[i]]
+        sel_num = len(sql['sql']['sel'])
+        sel_num_seq.append(sel_num)
+        conds_num = len(sql['sql']['conds'])
+        q_seq.append([char for char in sql['question']])
+        col_seq.append([[char for char in header] for header in table_data[sql['table_id']]['header']])
+        col_num.append(len(table_data[sql['table_id']]['header']))
+        ans_seq.append(
+            (
+            len(sql['sql']['agg']),
+            sql['sql']['sel'],
+            sql['sql']['agg'],
+            conds_num,
+            tuple(x[0] for x in sql['sql']['conds']),
+            tuple(x[1] for x in sql['sql']['conds']),
+            sql['sql']['cond_conn_op'],
+            ))
+        gt_cond_seq.append(sql['sql']['conds'])
+        vis_seq.append((sql['question'], table_data[sql['table_id']]['header']))
+    if ret_vis_data:
+        return q_seq, sel_num_seq, col_seq, col_num, ans_seq, gt_cond_seq, vis_seq
+    else:
+        return q_seq, sel_num_seq, col_seq, col_num, ans_seq, gt_cond_seq
+
+def to_batch_seq_test(sql_data, table_data, idxes, st, ed):
+    q_seq = []
+    col_seq = []
+    col_num = []
+    raw_seq = []
+    table_ids = []
+    for i in range(st, ed):
+        sql = sql_data[idxes[i]]
+        q_seq.append([char for char in sql['question']])
+        col_seq.append([[char for char in header] for header in table_data[sql['table_id']]['header']])
+        col_num.append(len(table_data[sql['table_id']]['header']))
+        raw_seq.append(sql['question'])
+        table_ids.append(sql_data[idxes[i]]['table_id'])
+    return q_seq, col_seq, col_num, raw_seq, table_ids
+
+def to_batch_query(sql_data, idxes, st, ed):
+    query_gt = []
+    table_ids = []
+    for i in range(st, ed):
+        sql_data[idxes[i]]['sql']['conds'] = sql_data[idxes[i]]['sql']['conds']
+        query_gt.append(sql_data[idxes[i]]['sql'])
+        table_ids.append(sql_data[idxes[i]]['table_id'])
+    return query_gt, table_ids
+
+def epoch_train(model, optimizer, batch_size, sql_data, table_data):
+    '''
+    Train the model for one epoch and return the average loss.
+    model, optimizer: the SQLNet model and its optimizer
+    batch_size: examples per batch (e.g. 16)
+    '''
+    model.train()
+    perm=np.random.permutation(len(sql_data))
+    perm = list(range(len(sql_data)))
+    cum_loss = 0.0
+    for st in tqdm(range(len(sql_data)//batch_size+1)): # e.g. 41522 examples / batch_size 16 -> 2596 batches
+        ed = (st+1)*batch_size if (st+1)*batch_size < len(perm) else len(perm)
+        st = st * batch_size
+        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq = to_batch_seq(sql_data, table_data, perm, st, ed)
+        # q_seq: char-based sequence of question
+        # gt_sel_num: number of selected columns and aggregation functions
+        # col_seq: char-based column name
+        # col_num: number of headers in one table
+        # ans_seq: (sel, number of conds, sel list in conds, op list in conds)
+        # gt_cond_seq: ground truth of conds
+        gt_where_seq = model.generate_gt_where_seq_test(q_seq, gt_cond_seq)
+        gt_sel_seq = [x[1] for x in ans_seq]
+        score = model.forward(q_seq, col_seq, col_num, gt_where=gt_where_seq, gt_cond=gt_cond_seq, gt_sel=gt_sel_seq, gt_sel_num=gt_sel_num)
+        # sel_num_score, sel_col_score, sel_agg_score, cond_score, cond_rela_score
+
+        # compute loss
+        loss = model.loss(score, ans_seq, gt_where_seq)
+        cum_loss += loss.data.cpu().numpy()*(ed - st)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+    return cum_loss / len(sql_data)
+
+def predict_test(model, batch_size, sql_data, table_data, output_path):
+    model.eval()
+    perm = list(range(len(sql_data)))
+    fw = open(output_path,'w')
+    for st in tqdm(range(len(sql_data)//batch_size+1)):
+        ed = (st+1)*batch_size if (st+1)*batch_size < len(perm) else len(perm)
+        st = st * batch_size
+        q_seq, col_seq, col_num, raw_q_seq, table_ids = to_batch_seq_test(sql_data, table_data, perm, st, ed)
+        score = model.forward(q_seq, col_seq, col_num)
+        sql_preds = model.gen_query(score, q_seq, col_seq, raw_q_seq)
+        for sql_pred in sql_preds:
+            fw.write(json.dumps(sql_pred, ensure_ascii=False) + '\n')  # file is opened in text mode, so write str rather than bytes
+    fw.close()
+
+def epoch_acc(model, batch_size, sql_data, table_data, db_path):
+    engine = DBEngine(db_path)
+    model.eval()
+    perm = list(range(len(sql_data)))
+    badcase = 0
+    one_acc_num, tot_acc_num, ex_acc_num = 0.0, 0.0, 0.0
+    for st in tqdm(range(len(sql_data)//batch_size+1)):
+        ed = (st+1)*batch_size if (st+1)*batch_size < len(perm) else len(perm)
+        st = st * batch_size
+        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, raw_data = \
+            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
+        # q_seq: char-based sequence of question
+        # gt_sel_num: number of selected columns and aggregation functions, new added field
+        # col_seq: char-based column name
+        # col_num: number of headers in one table
+        # ans_seq: (sel, number of conds, sel list in conds, op list in conds)
+        # gt_cond_seq: ground truth of conditions
+        # raw_data: ori question, headers, sql
+        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
+        # query_gt: ground truth of sql, data['sql'], containing sel, agg, conds:{sel, op, value}
+        raw_q_seq = [x[0] for x in raw_data] # original question
+        try:
+            score = model.forward(q_seq, col_seq, col_num)
+            pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq)
+            # generate predicted format
+            one_err, tot_err = model.check_acc(raw_data, pred_queries, query_gt)
+        except:
+            badcase += 1
+            print ('badcase', badcase)
+            continue
+        one_acc_num += (ed-st-one_err)
+        tot_acc_num += (ed-st-tot_err)
+
+        # Execution Accuracy
+        for sql_gt, sql_pred, tid in zip(query_gt, pred_queries, table_ids):
+            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'], sql_gt['conds'], sql_gt['cond_conn_op'])
+            try:
+                ret_pred = engine.execute(tid, sql_pred['sel'], sql_pred['agg'], sql_pred['conds'], sql_pred['cond_conn_op'])
+            except:
+                ret_pred = None
+            ex_acc_num += (ret_gt == ret_pred)
+    return one_acc_num / len(sql_data), tot_acc_num / len(sql_data), ex_acc_num / len(sql_data)
+
+
+def load_word_emb(file_name):
+    print ('Loading word embedding from %s'%file_name)
+    f = open(file_name)
+    ret = json.load(f)
+    f.close()
+    # ret = {}
+    # with open(file_name, encoding='latin') as inf:
+    #     ret = json.load(inf)
+    #     for idx, line in enumerate(inf):
+    #         info = line.strip().split(' ')
+    #         if info[0].lower() not in ret:
+    #             ret[info[0]] = np.array([float(x) for x in info[1:]])
+    return ret

+ 40 - 40
test.py

@@ -1,40 +1,40 @@
-import torch
-from sqlnet.utils import *
-from sqlnet.model.sqlnet import SQLNet
-import argparse
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--gpu', action='store_true', help='Whether use gpu')
-    parser.add_argument('--toy', action='store_true', help='Small batchsize for fast debugging.')
-    parser.add_argument('--ca', action='store_true', help='Whether use column attention.')
-    parser.add_argument('--train_emb', action='store_true', help='Use trained word embedding for SQLNet.')
-    parser.add_argument('--output_dir', type=str, default='', help='Output path of prediction result')
-    args = parser.parse_args()
-
-    n_word=300
-    if args.toy:
-        use_small=True
-        gpu=args.gpu
-        batch_size=16
-    else:
-        use_small=False
-        gpu=args.gpu
-        batch_size=64
-
-    dev_sql, dev_table, dev_db, test_sql, test_table, test_db = load_dataset(use_small=use_small, mode='test')
-
-    word_emb = load_word_emb('data/char_embedding')
-    model = SQLNet(word_emb, N_word=n_word, use_ca=args.ca, gpu=gpu, trainable_emb=args.train_emb)
-
-    model_path = 'saved_model/best_model'
-    print ("Loading from %s" % model_path)
-    model.load_state_dict(torch.load(model_path))
-    print ("Loaded model from %s" % model_path)
-
-    dev_acc = epoch_acc(model, batch_size, dev_sql, dev_table, dev_db)
-    print ('Dev Logic Form Accuracy: %.3f, Execution Accuracy: %.3f' % (dev_acc[1], dev_acc[2]))
-
-    print ("Start to predict test set")
-    predict_test(model, batch_size, test_sql, test_table, args.output_dir)
-    print ("Output path of prediction result is %s" % args.output_dir)
+import torch
+from sqlnet.utils import *
+from sqlnet.model.sqlnet import SQLNet
+import argparse
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--gpu', action='store_true', help='Whether to use GPU')
+    parser.add_argument('--toy', action='store_true', help='Small batch size for fast debugging.')
+    parser.add_argument('--ca', action='store_true', help='Whether to use column attention.')
+    parser.add_argument('--train_emb', action='store_true', help='Use trained word embedding for SQLNet.')
+    parser.add_argument('--output_dir', type=str, default='', help='Output path of prediction result')
+    args = parser.parse_args()
+
+    n_word=300
+    if args.toy:
+        use_small=True
+        gpu=args.gpu
+        batch_size=16
+    else:
+        use_small=False
+        gpu=args.gpu
+        batch_size=64
+
+    dev_sql, dev_table, dev_db, test_sql, test_table, test_db = load_dataset(use_small=use_small, mode='test')
+
+    word_emb = load_word_emb('data/char_embedding.json')
+    model = SQLNet(word_emb, N_word=n_word, use_ca=args.ca, gpu=gpu, trainable_emb=args.train_emb)
+
+    model_path = 'saved_model/best_model'
+    print ("Loading from %s" % model_path)
+    model.load_state_dict(torch.load(model_path))
+    print ("Loaded model from %s" % model_path)
+
+    dev_acc = epoch_acc(model, batch_size, dev_sql, dev_table, dev_db)
+    print ('Dev Logic Form Accuracy: %.3f, Execution Accuracy: %.3f' % (dev_acc[1], dev_acc[2]))
+
+    print ("Start to predict test set")
+    predict_test(model, batch_size, test_sql, test_table, args.output_dir)
+    print ("Output path of prediction result is %s" % args.output_dir)

+ 113 - 100
train.py

@@ -1,100 +1,113 @@
-import torch
-from sqlnet.utils import *
-from sqlnet.model.sqlnet import SQLNet
-
-import argparse
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--bs', type=int, default=16, help='Batch size')
-    parser.add_argument('--epoch', type=int, default=100, help='Epoch number')
-    parser.add_argument('--gpu', action='store_true', help='Whether use gpu to train')
-    parser.add_argument('--toy', action='store_true', help='If set, use small data for fast debugging')
-    parser.add_argument('--ca', action='store_true', help='Whether use column attention')
-    parser.add_argument('--train_emb', action='store_true', help='Train word embedding for SQLNet')
-    parser.add_argument('--restore', action='store_true', help='Whether restore trained model')
-    parser.add_argument('--logdir', type=str, default='', help='Path of save experiment log')
-    args = parser.parse_args()
-
-    n_word=300
-    if args.toy:
-        use_small=True
-        gpu=args.gpu
-        batch_size=16
-    else:
-        use_small=False
-        gpu=args.gpu
-        batch_size=args.bs
-    learning_rate = 1e-3
-
-    # load dataset
-    train_sql, train_table, train_db, dev_sql, dev_table, dev_db = load_dataset(use_small=use_small)
-
-    word_emb = load_word_emb('data/char_embedding.json')
-    model = SQLNet(word_emb, N_word=n_word, use_ca=args.ca, gpu=gpu, trainable_emb=args.train_emb)
-    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)
-
-    if args.restore:
-        model_path= 'saved_model/best_model'
-        print ("Loading trained model from %s" % model_path)
-        model.load_state_dict(torch.load(model_path))
-
-    # used to record best score of each sub-task
-    best_sn, best_sc, best_sa, best_wn, best_wc, best_wo, best_wv, best_wr = 0, 0, 0, 0, 0, 0, 0, 0
-    best_sn_idx, best_sc_idx, best_sa_idx, best_wn_idx, best_wc_idx, best_wo_idx, best_wv_idx, best_wr_idx = 0, 0, 0, 0, 0, 0, 0, 0
-    best_lf, best_lf_idx = 0.0, 0
-    best_ex, best_ex_idx = 0.0, 0
-
-    print ("#"*20+"  Star to Train  " + "#"*20)
-    for i in range(args.epoch):
-        print ('Epoch %d'%(i+1))
-        # train on the train dataset
-        train_loss = epoch_train(model, optimizer, batch_size, train_sql, train_table)
-        # evaluate on the dev dataset
-        dev_acc = epoch_acc(model, batch_size, dev_sql, dev_table, dev_db)
-        # accuracy of each sub-task
-        print ('Sel-Num: %.3f, Sel-Col: %.3f, Sel-Agg: %.3f, W-Num: %.3f, W-Col: %.3f, W-Op: %.3f, W-Val: %.3f, W-Rel: %.3f'%(
-            dev_acc[0][0], dev_acc[0][1], dev_acc[0][2], dev_acc[0][3], dev_acc[0][4], dev_acc[0][5], dev_acc[0][6], dev_acc[0][7]))
-        # save the best model
-        if dev_acc[1] > best_lf:
-            best_lf = dev_acc[1]
-            best_lf_idx = i + 1
-            torch.save(model.state_dict(), 'saved_model/best_model')
-        if dev_acc[2] > best_ex:
-            best_ex = dev_acc[2]
-            best_ex_idx = i + 1
-
-        # record the best score of each sub-task
-        if True:
-            if dev_acc[0][0] > best_sn:
-                best_sn = dev_acc[0][0]
-                best_sn_idx = i+1
-            if dev_acc[0][1] > best_sc:
-                best_sc = dev_acc[0][1]
-                best_sc_idx = i+1
-            if dev_acc[0][2] > best_sa:
-                best_sa = dev_acc[0][2]
-                best_sa_idx = i+1
-            if dev_acc[0][3] > best_wn:
-                best_wn = dev_acc[0][3]
-                best_wn_idx = i+1
-            if dev_acc[0][4] > best_wc:
-                best_wc = dev_acc[0][4]
-                best_wc_idx = i+1
-            if dev_acc[0][5] > best_wo:
-                best_wo = dev_acc[0][5]
-                best_wo_idx = i+1
-            if dev_acc[0][6] > best_wv:
-                best_wv = dev_acc[0][6]
-                best_wv_idx = i+1
-            if dev_acc[0][7] > best_wr:
-                best_wr = dev_acc[0][7]
-                best_wr_idx = i+1
-        print ('Train loss = %.3f' % train_loss)
-        print ('Dev Logic Form Accuracy: %.3f, Execution Accuracy: %.3f' % (dev_acc[1], dev_acc[2]))
-        print ('Best Logic Form: %.3f at epoch %d' % (best_lf, best_lf_idx))
-        print ('Best Execution: %.3f at epoch %d' % (best_ex, best_ex_idx))
-        if (i+1) % 10 == 0:
-            print ('Best val acc: %s\nOn epoch individually %s'%(
-                    (best_sn, best_sc, best_sa, best_wn, best_wc, best_wo, best_wv),
-                    (best_sn_idx, best_sc_idx, best_sa_idx, best_wn_idx, best_wc_idx, best_wo_idx, best_wv_idx)))
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+'''
+@File    :   train.py
+@Time    :   2019/06/25 04:00:53
+@Author  :   Liuyuqi 
+@Version :   1.0
+@Contact :   liuyuqi.gov@msn.cn
+@License :   (C)Copyright 2019
+@Desc    :   Train the SQLNet model
+'''
+
+import torch
+from sqlnet.utils import *
+from sqlnet.model.sqlnet import SQLNet
+
+import argparse
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--bs', type=int, default=16, help='Batch size')
+    parser.add_argument('--epoch', type=int, default=100, help='Epoch number')
+    parser.add_argument('--gpu', action='store_true', help='Whether to use GPU for training')
+    parser.add_argument('--toy', action='store_true', help='If set, use small data for fast debugging')
+    parser.add_argument('--ca', action='store_true', help='Whether to use column attention')
+    parser.add_argument('--train_emb', action='store_true', help='Train word embedding for SQLNet')
+    parser.add_argument('--restore', action='store_true', help='Whether to restore a trained model')
+    parser.add_argument('--logdir', type=str, default='', help='Path to save experiment logs')
+    args = parser.parse_args()
+
+    n_word=300
+    if args.toy:
+        use_small=True
+        gpu=args.gpu
+        batch_size=16
+    else:
+        use_small=False
+        gpu=args.gpu
+        batch_size=args.bs
+    learning_rate = 1e-3
+
+    # load the training and dev datasets
+    train_sql, train_table, train_db, dev_sql, dev_table, dev_db = load_dataset(use_small=use_small)
+
+    # word_emb is a dict loaded from the character-embedding JSON file
+    word_emb = load_word_emb('data/char_embedding.json')
+    model = SQLNet(word_emb, N_word=n_word, use_ca=args.ca, gpu=gpu, trainable_emb=args.train_emb)
+    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)
+
+    if args.restore:
+        model_path= 'saved_model/best_model'
+        print ("Loading trained model from %s" % model_path)
+        model.load_state_dict(torch.load(model_path))
+
+    # used to record best score of each sub-task
+    best_sn, best_sc, best_sa, best_wn, best_wc, best_wo, best_wv, best_wr = 0, 0, 0, 0, 0, 0, 0, 0
+    best_sn_idx, best_sc_idx, best_sa_idx, best_wn_idx, best_wc_idx, best_wo_idx, best_wv_idx, best_wr_idx = 0, 0, 0, 0, 0, 0, 0, 0
+    best_lf, best_lf_idx = 0.0, 0
+    best_ex, best_ex_idx = 0.0, 0
+
+    print ("#"*20+"  Star to Train  " + "#"*20)
+    for i in range(args.epoch):  # default: 100 epochs
+        print ('Epoch %d'%(i+1))
+        # train on the train dataset
+        train_loss = epoch_train(model, optimizer, batch_size, train_sql, train_table)
+        # evaluate on the dev dataset
+        dev_acc = epoch_acc(model, batch_size, dev_sql, dev_table, dev_db)
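+        # dev_acc is (per-sub-task accuracies, logic-form accuracy, execution accuracy)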
+        # accuracy of each sub-task
+        print ('Sel-Num: %.3f, Sel-Col: %.3f, Sel-Agg: %.3f, W-Num: %.3f, W-Col: %.3f, W-Op: %.3f, W-Val: %.3f, W-Rel: %.3f'%(
+            dev_acc[0][0], dev_acc[0][1], dev_acc[0][2], dev_acc[0][3], dev_acc[0][4], dev_acc[0][5], dev_acc[0][6], dev_acc[0][7]))
+        # save the best model
+        if dev_acc[1] > best_lf:
+            best_lf = dev_acc[1]
+            best_lf_idx = i + 1
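+            # Note: the saved_model/ directory must already exist; torch.save does not create it.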
+            torch.save(model.state_dict(), 'saved_model/best_model')
+        if dev_acc[2] > best_ex:
+            best_ex = dev_acc[2]
+            best_ex_idx = i + 1
+
+        # record the best score of each sub-task
+        if True:
+            if dev_acc[0][0] > best_sn:
+                best_sn = dev_acc[0][0]
+                best_sn_idx = i+1
+            if dev_acc[0][1] > best_sc:
+                best_sc = dev_acc[0][1]
+                best_sc_idx = i+1
+            if dev_acc[0][2] > best_sa:
+                best_sa = dev_acc[0][2]
+                best_sa_idx = i+1
+            if dev_acc[0][3] > best_wn:
+                best_wn = dev_acc[0][3]
+                best_wn_idx = i+1
+            if dev_acc[0][4] > best_wc:
+                best_wc = dev_acc[0][4]
+                best_wc_idx = i+1
+            if dev_acc[0][5] > best_wo:
+                best_wo = dev_acc[0][5]
+                best_wo_idx = i+1
+            if dev_acc[0][6] > best_wv:
+                best_wv = dev_acc[0][6]
+                best_wv_idx = i+1
+            if dev_acc[0][7] > best_wr:
+                best_wr = dev_acc[0][7]
+                best_wr_idx = i+1
+        print ('Train loss = %.3f' % train_loss)
+        print ('Dev Logic Form Accuracy: %.3f, Execution Accuracy: %.3f' % (dev_acc[1], dev_acc[2]))
+        print ('Best Logic Form: %.3f at epoch %d' % (best_lf, best_lf_idx))
+        print ('Best Execution: %.3f at epoch %d' % (best_ex, best_ex_idx))
+        if (i+1) % 10 == 0:
+            print ('Best val acc: %s\nOn epoch individually %s'%(
+                    (best_sn, best_sc, best_sa, best_wn, best_wc, best_wo, best_wv, best_wr),
+                    (best_sn_idx, best_sc_idx, best_sa_idx, best_wn_idx, best_wc_idx, best_wo_idx, best_wv_idx, best_wr_idx)))