waynesun 4 years ago
commit
9a9c2a4c7f

+ 29 - 0
LICENSE

@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) 2017, Xiaojun Xu, Chang Liu and Dawn Song
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+  list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+  this list of conditions and the following disclaimer in the documentation
+  and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+  contributors may be used to endorse or promote products derived from
+  this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

+ 52 - 0
README.md

@@ -0,0 +1,52 @@
+## Introduction
+
+This baseline method is developed and refined based on <a href="https://arxiv.org/abs/1711.04436">SQLNet</a>, a baseline model for <a href="https://github.com/salesforce/WikiSQL">WikiSQL</a>.
+
+The model decouples the task of generating a whole SQL query into several sub-tasks, including select-number, select-column, select-aggregation, condition-number, condition-column, and so on.
+
+A simplified model structure is shown below; for implementation details, please refer to the original <a href="https://arxiv.org/abs/1711.04436">paper</a>.
+
+<div align="middle"><img src="https://github.com/ZhuiyiTechnology/nl2sql_baseline/blob/master/img/detailed_structure.png" width="80%"></div>
+
+The difference between SQLNet and this baseline model is that the Select-Number and Where-Relationship sub-tasks are added to better fit this Chinese NL2SQL dataset.
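+
+To make these sub-tasks concrete, below is a small illustrative sketch of a decomposed query, following the keys used in `sqlnet/model/sqlnet.py` (`sel`, `agg`, `cond_conn_op`, `conds`). The concrete indexes and values are made-up examples rather than the official dataset specification.
+
+```
+# Illustrative only -- the index encodings here are assumptions, not the dataset spec.
+label = {
+    'sel': [1, 3],          # select-column: indexes of the selected headers (so select-number = 2)
+    'agg': [0, 5],          # select-aggregation: one aggregator id per selected column
+    'cond_conn_op': 1,      # where-relationship connecting the conditions (e.g. AND / OR)
+    'conds': [              # condition-number = 2; each entry is [column, operator, value]
+        [2, 0, u'2018'],
+        [6, 2, u'北京'],
+    ],
+}
+```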
+
+## Dependencies
+
+ - Python 2.7
+ - torch 1.0.1
+ - tqdm
+
+## Start to train
+
+First, download the provided dataset to ~/data_nl2sql/, including train.json, train.tables.json, dev.json, dev.tables.json, and char_embedding. Then set up the repository and start training:
+```
+mkdir ~/nl2sql
+cd ~/nl2sql/
+git clone https://github.com/ZhuiyiTechnology/nl2sql_baseline.git
+cp ~/data_nl2sql/* ~/nl2sql/data
+sh ./start_train.py 0 128
+```
+where the first parameter (0) is the GPU id and the second is the batch size.
+
+## Start to evaluate
+
+To evaluate on dev.json or test.json, make sure a trained model is ready, then run
+```
+sh ./start_test.py 0 pred_example
+```
+where the first parameter (0) is the GPU id and the second is the output path for the predictions.
+
+## Experiment results
+
+We have run the experiments several times, achieving an average logical form accuracy of 27.5% on the dev set.
+
+We found that the main challenges of this dataset include poor condition-value prediction, select and condition columns that are not mentioned in the NL question, inconsistent representation of the condition relationship between the NL question and the SQL, and so on. None of these challenges is solved by existing baseline and SOTA models.
+
+Correspondingly, this baseline model achieves only 77% accuracy on condition columns and 62% accuracy on condition values even on the training set, which contestants should pay attention to.
+
+## Related resources
+
+ - https://github.com/salesforce/WikiSQL
+ - https://yale-lily.github.io/spider
+ - <a href="https://arxiv.org/pdf/1804.08338.pdf">Semantic Parsing with Syntax- and Table-Aware SQL Generation</a>

BIN
img/detailed_structure.png


BIN
img/structure.png


BIN
img/wikisql_example.png


+ 0 - 0
saved_model/.keep


+ 0 - 0
saved_model/best_model


+ 0 - 0
sqlnet/__init__.py


+ 0 - 0
sqlnet/lib/__init__.py


+ 52 - 0
sqlnet/lib/dbengine.py

@@ -0,0 +1,52 @@
+import records
+import re
+from babel.numbers import parse_decimal, NumberFormatError
+
+
+schema_re = re.compile(r'\((.+)\)')
+num_re = re.compile(r'[-+]?\d*\.\d+|\d+')
+
+agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
+cond_ops = ['=', '>', '<', 'OP']
+
+class DBEngine:
+
+    def __init__(self, fdb):
+        #fdb = 'data/test.db'
+        self.db = records.Database('sqlite:///{}'.format(fdb))
+
+    def execute_query(self, table_id, query, *args, **kwargs):
+        return self.execute(table_id, query.sel_index, query.agg_index, query.conditions, *args, **kwargs)
+
+    def execute(self, table_id, select_index, aggregation_index, conditions, lower=True):
+        if not table_id.startswith('table'):
+            table_id = 'table_{}'.format(table_id.replace('-', '_'))
+        table_info = self.db.query('SELECT sql from sqlite_master WHERE tbl_name = :name', name=table_id).all()[0].sql.replace('\n','')
+        schema_str = schema_re.findall(table_info)[0]
+        schema = {}
+        for tup in schema_str.split(', '):
+            c, t = tup.split()
+            schema[c] = t
+        select = 'col{}'.format(select_index)
+        agg = agg_ops[aggregation_index]
+        if agg:
+            select = '{}({})'.format(agg, select)
+        where_clause = []
+        where_map = {}
+        for col_index, op, val in conditions:
+            if lower and (isinstance(val, str) or isinstance(val, unicode)):
+                val = val.lower()
+            if schema['col{}'.format(col_index)] == 'real' and not isinstance(val, (int, float)):
+                try:
+                    val = float(parse_decimal(val))
+                except NumberFormatError as e:
+                    val = float(num_re.findall(val)[0])
+            where_clause.append('col{} {} :col{}'.format(col_index, cond_ops[op], col_index))
+            where_map['col{}'.format(col_index)] = val
+        where_str = ''
+        if where_clause:
+            where_str = 'WHERE ' + ' AND '.join(where_clause)
+        query = 'SELECT {} AS result FROM {} {}'.format(select, table_id, where_str)
+        #print query
+        out = self.db.query(query, **where_map)
+        return [o.result for o in out]
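
A minimal usage sketch for the `DBEngine` class above (not part of the commit). It assumes a WikiSQL-style SQLite file whose tables are named `table_<id>` and whose columns are `col0`, `col1`, ...; the path, table id, and condition values below are illustrative, and the `records` and `babel` packages must be installed.

```
# Hedged example -- the path, table id, and values below are assumptions.
from sqlnet.lib.dbengine import DBEngine

engine = DBEngine('data/val.db')               # path to a WikiSQL-style SQLite file (assumed)
result = engine.execute(
    table_id='1-1000181-1',                    # rewritten internally to table_1_1000181_1
    select_index=2,                            # SELECT col2 ...
    aggregation_index=3,                       # ... wrapped as COUNT(col2)
    conditions=[(0, 0, 'some value')],         # WHERE col0 = 'some value'
)
print(result)                                  # list of values from the "result" column
```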

+ 0 - 0
sqlnet/model/__init__.py


+ 0 - 0
sqlnet/model/modules/__init__.py


+ 56 - 0
sqlnet/model/modules/aggregator_predict.py

@@ -0,0 +1,56 @@
+import json
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from net_utils import run_lstm, col_name_encode
+
+
+
+class AggPredictor(nn.Module):
+    def __init__(self, N_word, N_h, N_depth, use_ca):
+        super(AggPredictor, self).__init__()
+        self.use_ca = use_ca
+
+        self.agg_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
+                num_layers=N_depth, batch_first=True,
+                dropout=0.3, bidirectional=True)
+        if use_ca:
+            print "Using column attention on aggregator predicting"
+            self.agg_col_name_enc = nn.LSTM(input_size=N_word,
+                    hidden_size=N_h/2, num_layers=N_depth,
+                    batch_first=True, dropout=0.3, bidirectional=True)
+            self.agg_att = nn.Linear(N_h, N_h)
+        else:
+            print "Not using column attention on aggregator predicting"
+            self.agg_att = nn.Linear(N_h, 1)
+        self.agg_out = nn.Sequential(nn.Linear(N_h, N_h), nn.Tanh(), nn.Linear(N_h, 6))
+        self.softmax = nn.Softmax(dim=-1)
+        self.agg_out_K = nn.Linear(N_h, N_h)
+        self.col_out_col = nn.Linear(N_h, N_h)
+
+    def forward(self, x_emb_var, x_len, col_inp_var=None, col_name_len=None,
+            col_len=None, col_num=None, gt_sel=None, gt_sel_num=None):
+        B = len(x_emb_var)
+        max_x_len = max(x_len)
+
+        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len, self.agg_col_name_enc)
+        h_enc, _ = run_lstm(self.agg_lstm, x_emb_var, x_len)
+
+        col_emb = []
+        for b in range(B):
+            cur_col_emb = torch.stack([e_col[b,x] for x in gt_sel[b]] + [e_col[b,0]] * (4-len(gt_sel[b])))
+            col_emb.append(cur_col_emb)
+        col_emb = torch.stack(col_emb)
+
+        att_val = torch.matmul(self.agg_att(h_enc).unsqueeze(1), col_emb.unsqueeze(3)).squeeze() # .transpose(1,2))
+
+        for idx, num in enumerate(x_len):
+            if num < max_x_len:
+                att_val[idx, :, num:] = -100
+        att = self.softmax(att_val.view(B*4, -1)).view(B, 4, -1)
+
+        K_agg = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
+        agg_score = self.agg_out(self.agg_out_K(K_agg) + self.col_out_col(col_emb)).squeeze()
+        return agg_score

+ 47 - 0
sqlnet/model/modules/net_utils.py

@@ -0,0 +1,47 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from torch.autograd import Variable
+
+def run_lstm(lstm, inp, inp_len, hidden=None):
+    # Run the LSTM using packed sequence.
+    # This requires to first sort the input according to its length.
+    sort_perm = np.array(sorted(range(len(inp_len)),
+        key=lambda k:inp_len[k], reverse=True))
+    sort_inp_len = inp_len[sort_perm]
+    sort_perm_inv = np.argsort(sort_perm)
+    if inp.is_cuda:
+        sort_perm = torch.LongTensor(sort_perm).cuda()
+        sort_perm_inv = torch.LongTensor(sort_perm_inv).cuda()
+
+    lstm_inp = nn.utils.rnn.pack_padded_sequence(inp[sort_perm],
+            sort_inp_len, batch_first=True)
+    if hidden is None:
+        lstm_hidden = None
+    else:
+        lstm_hidden = (hidden[0][:, sort_perm], hidden[1][:, sort_perm])
+
+    sort_ret_s, sort_ret_h = lstm(lstm_inp, lstm_hidden)
+    ret_s = nn.utils.rnn.pad_packed_sequence(
+            sort_ret_s, batch_first=True)[0][sort_perm_inv]
+    ret_h = (sort_ret_h[0][:, sort_perm_inv], sort_ret_h[1][:, sort_perm_inv])
+    return ret_s, ret_h
+
+
+def col_name_encode(name_inp_var, name_len, col_len, enc_lstm):
+    #Encode the columns.
+    #The embedding of a column name is the last state of its LSTM output.
+    name_hidden, _ = run_lstm(enc_lstm, name_inp_var, name_len)
+    name_out = name_hidden[tuple(range(len(name_len))), name_len-1]
+    ret = torch.FloatTensor(
+            len(col_len), max(col_len), name_out.size()[1]).zero_()
+    if name_out.is_cuda:
+        ret = ret.cuda()
+
+    st = 0
+    for idx, cur_len in enumerate(col_len):
+        ret[idx, :cur_len] = name_out.data[st:st+cur_len]
+        st += cur_len
+    ret_var = Variable(ret)
+
+    return ret_var, col_len

+ 67 - 0
sqlnet/model/modules/select_number.py

@@ -0,0 +1,67 @@
+import json
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from net_utils import run_lstm, col_name_encode
+
+class SelNumPredictor(nn.Module):
+    def __init__(self, N_word, N_h, N_depth, use_ca):
+        super(SelNumPredictor, self).__init__()
+        self.N_h = N_h
+        self.use_ca = use_ca
+
+        self.sel_num_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
+                                    num_layers=N_depth, batch_first=True,
+                                    dropout=0.3, bidirectional=True)
+        self.sel_num_att = nn.Linear(N_h, 1)
+        self.sel_num_col_att = nn.Linear(N_h, 1)
+        self.sel_num_out = nn.Sequential(nn.Linear(N_h, N_h),
+                                         nn.Tanh(), nn.Linear(N_h,4))
+        self.softmax = nn.Softmax(dim=-1)
+        self.sel_num_col2hid1 = nn.Linear(N_h, 2 * N_h)
+        self.sel_num_col2hid2 = nn.Linear(N_h, 2 * N_h)
+
+
+        if self.use_ca:
+            print "Using column attention on select number predicting"
+
+    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num):
+        B = len(x_len)
+        max_x_len = max(x_len)
+
+        # Predict the number of select part
+        # First use column embeddings to calculate the initial hidden unit
+        # Then run the LSTM and predict select number
+        e_num_col, col_num = col_name_encode(col_inp_var, col_name_len,
+                                             col_len, self.sel_num_lstm)
+        num_col_att_val = self.sel_num_col_att(e_num_col).squeeze()
+        for idx, num in enumerate(col_num):
+            if num < max(col_num):
+                num_col_att_val[idx, num:] = -1000000
+        num_col_att = self.softmax(num_col_att_val)
+        K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
+        sel_num_h1 = self.sel_num_col2hid1(K_num_col).view(B, 4, self.N_h/2).transpose(0,1).contiguous()
+        sel_num_h2 = self.sel_num_col2hid2(K_num_col).view(B, 4, self.N_h/2).transpose(0,1).contiguous()
+
+        h_num_enc, _ = run_lstm(self.sel_num_lstm, x_emb_var, x_len,
+                                hidden=(sel_num_h1, sel_num_h2))
+
+        num_att_val = self.sel_num_att(h_num_enc).squeeze()
+        for idx, num in enumerate(x_len):
+            if num < max_x_len:
+                num_att_val[idx, num:] = -1000000
+        num_att = self.softmax(num_att_val)
+
+        K_sel_num = (h_num_enc * num_att.unsqueeze(2).expand_as(
+            h_num_enc)).sum(1)
+        sel_num_score = self.sel_num_out(K_sel_num)
+        return sel_num_score
+
+
+
+
+
+
+

+ 54 - 0
sqlnet/model/modules/selection_predict.py

@@ -0,0 +1,54 @@
+import json
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from net_utils import run_lstm, col_name_encode
+
+class SelPredictor(nn.Module):
+    def __init__(self, N_word, N_h, N_depth, max_tok_num, use_ca):
+        super(SelPredictor, self).__init__()
+        self.use_ca = use_ca
+        self.max_tok_num = max_tok_num
+        self.sel_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
+                num_layers=N_depth, batch_first=True,
+                dropout=0.3, bidirectional=True)
+        if use_ca:
+            print "Using column attention on selection predicting"
+            self.sel_att = nn.Linear(N_h, N_h)
+        else:
+            print "Not using column attention on selection predicting"
+            self.sel_att = nn.Linear(N_h, 1)
+        self.sel_col_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
+                num_layers=N_depth, batch_first=True,
+                dropout=0.3, bidirectional=True)
+        self.sel_out_K = nn.Linear(N_h, N_h)
+        self.sel_out_col = nn.Linear(N_h, N_h)
+        self.sel_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 1))
+        self.softmax = nn.Softmax(dim=-1)
+
+
+    def forward(self, x_emb_var, x_len, col_inp_var,
+            col_name_len, col_len, col_num):
+        # Based on number of selections to predict select-column
+        B = len(x_emb_var)
+        max_x_len = max(x_len)
+
+        e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len, self.sel_col_name_enc) # [bs, col_num, hid]
+        h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len) # [bs, seq_len, hid]
+
+        att_val = torch.bmm(e_col, self.sel_att(h_enc).transpose(1, 2)) # [bs, col_num, seq_len]
+        for idx, num in enumerate(x_len):
+            if num < max_x_len:
+                att_val[idx, :, num:] = -100
+        att = self.softmax(att_val.view((-1, max_x_len))).view(B, -1, max_x_len)
+        K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)
+
+        sel_score = self.sel_out( self.sel_out_K(K_sel_expand) + self.sel_out_col(e_col) ).squeeze()
+        max_col_num = max(col_num)
+        for idx, num in enumerate(col_num):
+            if num < max_col_num:
+                sel_score[idx, num:] = -100
+
+        return sel_score

+ 122 - 0
sqlnet/model/modules/seq2sql_condition_predict.py

@@ -0,0 +1,122 @@
+import json
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from net_utils import run_lstm
+
+class Seq2SQLCondPredictor(nn.Module):
+    def __init__(self, N_word, N_h, N_depth, max_col_num, max_tok_num, gpu):
+        super(Seq2SQLCondPredictor, self).__init__()
+        print "Seq2SQL where prediction"
+        self.N_h = N_h
+        self.max_tok_num = max_tok_num
+        self.max_col_num = max_col_num
+        self.gpu = gpu
+
+        self.cond_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
+                num_layers=N_depth, batch_first=True,
+                dropout=0.3, bidirectional=True)
+        self.cond_decoder = nn.LSTM(input_size=self.max_tok_num,
+                hidden_size=N_h, num_layers=N_depth,
+                batch_first=True, dropout=0.3)
+
+        self.cond_out_g = nn.Linear(N_h, N_h)
+        self.cond_out_h = nn.Linear(N_h, N_h)
+        self.cond_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 1))
+
+        self.softmax = nn.Softmax()
+
+
+    def gen_gt_batch(self, tok_seq, gen_inp=True):
+        # If gen_inp: generate the input token sequence (removing <END>)
+        # Otherwise: generate the output token sequence (removing <BEG>)
+        B = len(tok_seq)
+        ret_len = np.array([len(one_tok_seq)-1 for one_tok_seq in tok_seq])
+        max_len = max(ret_len)
+        ret_array = np.zeros((B, max_len, self.max_tok_num), dtype=np.float32)
+        for b, one_tok_seq in enumerate(tok_seq):
+            out_one_tok_seq = one_tok_seq[:-1] if gen_inp else one_tok_seq[1:]
+            for t, tok_id in enumerate(out_one_tok_seq):
+                ret_array[b, t, tok_id] = 1
+
+        ret_inp = torch.from_numpy(ret_array)
+        if self.gpu:
+            ret_inp = ret_inp.cuda()
+        ret_inp_var = Variable(ret_inp) #[B, max_len, max_tok_num]
+
+        return ret_inp_var, ret_len
+
+
+    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
+            col_num, gt_where, gt_cond, reinforce):
+        max_x_len = max(x_len)
+        B = len(x_len)
+
+        h_enc, hidden = run_lstm(self.cond_lstm, x_emb_var, x_len)
+        decoder_hidden = tuple(torch.cat((hid[:2], hid[2:]),dim=2) 
+                for hid in hidden)
+        if gt_where is not None:
+            gt_tok_seq, gt_tok_len = self.gen_gt_batch(gt_where, gen_inp=True)
+            g_s, _ = run_lstm(self.cond_decoder,
+                    gt_tok_seq, gt_tok_len, decoder_hidden)
+
+            h_enc_expand = h_enc.unsqueeze(1)
+            g_s_expand = g_s.unsqueeze(2)
+            cond_score = self.cond_out( self.cond_out_h(h_enc_expand) +
+                    self.cond_out_g(g_s_expand) ).squeeze()
+            for idx, num in enumerate(x_len):
+                if num < max_x_len:
+                    cond_score[idx, :, num:] = -100
+        else:
+            h_enc_expand = h_enc.unsqueeze(1)
+            scores = []
+            choices = []
+            done_set = set()
+
+            t = 0
+            init_inp = np.zeros((B, 1, self.max_tok_num), dtype=np.float32)
+            init_inp[:,0,7] = 1  #Set the <BEG> token
+            if self.gpu:
+                cur_inp = Variable(torch.from_numpy(init_inp).cuda())
+            else:
+                cur_inp = Variable(torch.from_numpy(init_inp))
+            cur_h = decoder_hidden
+            while len(done_set) < B and t < 100:
+                g_s, cur_h = self.cond_decoder(cur_inp, cur_h)
+                g_s_expand = g_s.unsqueeze(2)
+
+                cur_cond_score = self.cond_out(self.cond_out_h(h_enc_expand) +
+                        self.cond_out_g(g_s_expand)).squeeze()
+                for b, num in enumerate(x_len):
+                    if num < max_x_len:
+                        cur_cond_score[b, num:] = -100
+                scores.append(cur_cond_score)
+
+                if not reinforce:
+                    _, ans_tok_var = cur_cond_score.view(B, max_x_len).max(1)
+                    ans_tok_var = ans_tok_var.unsqueeze(1)
+                else:
+                    ans_tok_var = self.softmax(cur_cond_score).multinomial()
+                    choices.append(ans_tok_var)
+                ans_tok = ans_tok_var.data.cpu()
+                if self.gpu:  #To one-hot
+                    cur_inp = Variable(torch.zeros(
+                        B, self.max_tok_num).scatter_(1, ans_tok, 1).cuda())
+                else:
+                    cur_inp = Variable(torch.zeros(
+                        B, self.max_tok_num).scatter_(1, ans_tok, 1))
+                cur_inp = cur_inp.unsqueeze(1)
+
+                for idx, tok in enumerate(ans_tok.squeeze()):
+                    if tok == 1:  #Find the <END> token
+                        done_set.add(idx)
+                t += 1
+
+            cond_score = torch.stack(scores, 1)
+
+        if reinforce:
+            return cond_score, choices
+        else:
+            return cond_score

+ 289 - 0
sqlnet/model/modules/sqlnet_condition_predict.py

@@ -0,0 +1,289 @@
+import json
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from net_utils import run_lstm, col_name_encode
+
+class SQLNetCondPredictor(nn.Module):
+    def __init__(self, N_word, N_h, N_depth, max_col_num, max_tok_num, use_ca, gpu):
+        super(SQLNetCondPredictor, self).__init__()
+        self.N_h = N_h
+        self.max_tok_num = max_tok_num
+        self.max_col_num = max_col_num
+        self.gpu = gpu
+        self.use_ca = use_ca
+
+        self.cond_num_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
+                num_layers=N_depth, batch_first=True,
+                dropout=0.3, bidirectional=True)
+        self.cond_num_att = nn.Linear(N_h, 1)
+        self.cond_num_out = nn.Sequential(nn.Linear(N_h, N_h),
+                nn.Tanh(), nn.Linear(N_h, 5))
+        self.cond_num_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
+                num_layers=N_depth, batch_first=True,
+                dropout=0.3, bidirectional=True)
+        self.cond_num_col_att = nn.Linear(N_h, 1)
+        self.cond_num_col2hid1 = nn.Linear(N_h, 2*N_h)
+        self.cond_num_col2hid2 = nn.Linear(N_h, 2*N_h)
+
+        self.cond_col_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
+                num_layers=N_depth, batch_first=True,
+                dropout=0.3, bidirectional=True)
+        if use_ca:
+            print "Using column attention on where predicting"
+            self.cond_col_att = nn.Linear(N_h, N_h)
+        else:
+            print "Not using column attention on where predicting"
+            self.cond_col_att = nn.Linear(N_h, 1)
+        self.cond_col_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
+                num_layers=N_depth, batch_first=True,
+                dropout=0.3, bidirectional=True)
+        self.cond_col_out_K = nn.Linear(N_h, N_h)
+        self.cond_col_out_col = nn.Linear(N_h, N_h)
+        self.cond_col_out = nn.Sequential(nn.ReLU(), nn.Linear(N_h, 1))
+
+        self.cond_op_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
+                num_layers=N_depth, batch_first=True,
+                dropout=0.3, bidirectional=True)
+        if use_ca:
+            self.cond_op_att = nn.Linear(N_h, N_h)
+        else:
+            self.cond_op_att = nn.Linear(N_h, 1)
+        self.cond_op_out_K = nn.Linear(N_h, N_h)
+        self.cond_op_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
+                num_layers=N_depth, batch_first=True,
+                dropout=0.3, bidirectional=True)
+        self.cond_op_out_col = nn.Linear(N_h, N_h)
+        self.cond_op_out = nn.Sequential(nn.Linear(N_h, N_h), nn.Tanh(),
+                nn.Linear(N_h, 4))
+
+        self.cond_str_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
+                num_layers=N_depth, batch_first=True,
+                dropout=0.3, bidirectional=True)
+        self.cond_str_decoder = nn.LSTM(input_size=self.max_tok_num,
+                hidden_size=N_h, num_layers=N_depth,
+                batch_first=True, dropout=0.3)
+        self.cond_str_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
+                num_layers=N_depth, batch_first=True,
+                dropout=0.3, bidirectional=True)
+        self.cond_str_out_g = nn.Linear(N_h, N_h)
+        self.cond_str_out_h = nn.Linear(N_h, N_h)
+        self.cond_str_out_col = nn.Linear(N_h, N_h)
+        self.cond_str_out = nn.Sequential(nn.ReLU(), nn.Linear(N_h, 1))
+
+        self.softmax = nn.Softmax(dim=-1)
+
+
+    def gen_gt_batch(self, split_tok_seq):
+        B = len(split_tok_seq)
+        max_len = max([max([len(tok) for tok in tok_seq]+[0]) for 
+            tok_seq in split_tok_seq]) - 1 # The max seq len in the batch.
+        if max_len < 1:
+            max_len = 1
+        ret_array = np.zeros((
+            B, 4, max_len, self.max_tok_num), dtype=np.float32)
+        ret_len = np.zeros((B, 4))
+        for b, tok_seq in enumerate(split_tok_seq):
+            idx = 0
+            for idx, one_tok_seq in enumerate(tok_seq):
+                out_one_tok_seq = one_tok_seq[:-1]
+                ret_len[b, idx] = len(out_one_tok_seq)
+                for t, tok_id in enumerate(out_one_tok_seq):
+                    ret_array[b, idx, t, tok_id] = 1
+            if idx < 3:
+                ret_array[b, idx+1:, 0, 1] = 1
+                ret_len[b, idx+1:] = 1
+
+        ret_inp = torch.from_numpy(ret_array)
+        if self.gpu:
+            ret_inp = ret_inp.cuda()
+        ret_inp_var = Variable(ret_inp)
+
+        return ret_inp_var, ret_len #[B, IDX, max_len, max_tok_num]
+
+
+    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len,
+            col_len, col_num, gt_where, gt_cond, reinforce):
+        max_x_len = max(x_len)
+        B = len(x_len)
+        if reinforce:
+            raise NotImplementedError('Our model doesn\'t have RL')
+
+        # Predict the number of conditions
+        # First use column embeddings to calculate the initial hidden unit
+        # Then run the LSTM and predict condition number.
+        e_num_col, col_num = col_name_encode(col_inp_var, col_name_len, col_len, self.cond_num_name_enc)
+        num_col_att_val = self.cond_num_col_att(e_num_col).squeeze()
+        for idx, num in enumerate(col_num):
+            if num < max(col_num):
+                num_col_att_val[idx, num:] = -100
+        num_col_att = self.softmax(num_col_att_val)
+        K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
+        cond_num_h1 = self.cond_num_col2hid1(K_num_col).view(
+                B, 4, self.N_h/2).transpose(0, 1).contiguous()
+        cond_num_h2 = self.cond_num_col2hid2(K_num_col).view(
+                B, 4, self.N_h/2).transpose(0, 1).contiguous()
+
+        h_num_enc, _ = run_lstm(self.cond_num_lstm, x_emb_var, x_len,
+                hidden=(cond_num_h1, cond_num_h2))
+
+        num_att_val = self.cond_num_att(h_num_enc).squeeze()
+
+        for idx, num in enumerate(x_len):
+            if num < max_x_len:
+                num_att_val[idx, num:] = -100
+        num_att = self.softmax(num_att_val)
+
+        K_cond_num = (h_num_enc * num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)
+        cond_num_score = self.cond_num_out(K_cond_num)
+
+        #Predict the columns of conditions
+        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len, self.cond_col_name_enc)
+        h_col_enc, _ = run_lstm(self.cond_col_lstm, x_emb_var, x_len)
+
+        if self.use_ca:
+            col_att_val = torch.bmm(e_cond_col,
+                    self.cond_col_att(h_col_enc).transpose(1, 2))
+            for idx, num in enumerate(x_len):
+                if num < max_x_len:
+                    col_att_val[idx, :, num:] = -100
+            col_att = self.softmax(col_att_val.view(
+                (-1, max_x_len))).view(B, -1, max_x_len)
+            K_cond_col = (h_col_enc.unsqueeze(1) * col_att.unsqueeze(3)).sum(2)
+        else:
+            col_att_val = self.cond_col_att(h_col_enc).squeeze()
+            for idx, num in enumerate(x_len):
+                if num < max_x_len:
+                    col_att_val[idx, num:] = -100
+            col_att = self.softmax(col_att_val)
+            K_cond_col = (h_col_enc *
+                    col_att_val.unsqueeze(2)).sum(1).unsqueeze(1)
+
+        cond_col_score = self.cond_col_out(self.cond_col_out_K(K_cond_col) +
+                self.cond_col_out_col(e_cond_col)).squeeze()
+        max_col_num = max(col_num)
+        for b, num in enumerate(col_num):
+            if num < max_col_num:
+                cond_col_score[b, num:] = -100
+
+
+        #Predict the operator of conditions
+        chosen_col_gt = []
+        if gt_cond is None:
+            cond_nums = np.argmax(cond_num_score.data.cpu().numpy(), axis=1)
+            col_scores = cond_col_score.data.cpu().numpy()
+            chosen_col_gt = [list(np.argsort(-col_scores[b])[:cond_nums[b]])
+                    for b in range(len(cond_nums))]
+        else:
+            # print gt_cond
+            chosen_col_gt = [[x[0] for x in one_gt_cond] for one_gt_cond in gt_cond]
+
+        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len,
+                col_len, self.cond_op_name_enc)
+        h_op_enc, _ = run_lstm(self.cond_op_lstm, x_emb_var, x_len)
+        col_emb = []
+        for b in range(B):
+            cur_col_emb = torch.stack([e_cond_col[b, x] 
+                for x in chosen_col_gt[b]] + [e_cond_col[b, 0]] *
+                (4 - len(chosen_col_gt[b])))  # Pad the columns to maximum (4)
+            col_emb.append(cur_col_emb)
+        col_emb = torch.stack(col_emb)
+
+        if self.use_ca:
+            op_att_val = torch.matmul(self.cond_op_att(h_op_enc).unsqueeze(1),
+                    col_emb.unsqueeze(3)).squeeze()
+            for idx, num in enumerate(x_len):
+                if num < max_x_len:
+                    op_att_val[idx, :, num:] = -100
+            op_att = self.softmax(op_att_val.view(B*4, -1)).view(B, 4, -1)
+            K_cond_op = (h_op_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)
+        else:
+            op_att_val = self.cond_op_att(h_op_enc).squeeze()
+            for idx, num in enumerate(x_len):
+                if num < max_x_len:
+                    op_att_val[idx, num:] = -100
+            op_att = self.softmax(op_att_val)
+            K_cond_op = (h_op_enc * op_att.unsqueeze(2)).sum(1).unsqueeze(1)
+
+        cond_op_score = self.cond_op_out(self.cond_op_out_K(K_cond_op) +
+                self.cond_op_out_col(col_emb)).squeeze()
+
+        #Predict the string of conditions
+        h_str_enc, _ = run_lstm(self.cond_str_lstm, x_emb_var, x_len)
+        e_cond_col, _ = col_name_encode(col_inp_var, col_name_len,
+                col_len, self.cond_str_name_enc)
+        col_emb = []
+        for b in range(B):
+            cur_col_emb = torch.stack([e_cond_col[b, x] for x in chosen_col_gt[b]] +
+                                      [e_cond_col[b, 0]] * (4 - len(chosen_col_gt[b])))
+            col_emb.append(cur_col_emb)
+        col_emb = torch.stack(col_emb)
+
+        if gt_where is not None:
+            gt_tok_seq, gt_tok_len = self.gen_gt_batch(gt_where)
+            g_str_s_flat, _ = self.cond_str_decoder(
+                    gt_tok_seq.view(B*4, -1, self.max_tok_num))
+            g_str_s = g_str_s_flat.contiguous().view(B, 4, -1, self.N_h)
+
+            h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)
+            g_ext = g_str_s.unsqueeze(3)
+            col_ext = col_emb.unsqueeze(2).unsqueeze(2)
+
+            cond_str_score = self.cond_str_out(
+                    self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
+                    self.cond_str_out_col(col_ext)).squeeze()
+            for b, num in enumerate(x_len):
+                if num < max_x_len:
+                    cond_str_score[b, :, :, num:] = -100
+        else:
+            h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)
+            col_ext = col_emb.unsqueeze(2).unsqueeze(2)
+            scores = []
+
+            t = 0
+            init_inp = np.zeros((B*4, 1, self.max_tok_num), dtype=np.float32)
+            init_inp[:,0,0] = 1  #Set the <BEG> token
+            if self.gpu:
+                cur_inp = Variable(torch.from_numpy(init_inp).cuda())
+            else:
+                cur_inp = Variable(torch.from_numpy(init_inp))
+            cur_h = None
+            while t < 50:
+                if cur_h:
+                    g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp, cur_h)
+                else:
+                    g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp)
+                g_str_s = g_str_s_flat.view(B, 4, 1, self.N_h)
+                g_ext = g_str_s.unsqueeze(3)
+
+                cur_cond_str_score = self.cond_str_out(
+                        self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext)
+                        + self.cond_str_out_col(col_ext)).squeeze()
+                for b, num in enumerate(x_len):
+                    if num < max_x_len:
+                        cur_cond_str_score[b, :, num:] = -100
+                scores.append(cur_cond_str_score)
+
+                _, ans_tok_var = cur_cond_str_score.view(B*4, max_x_len).max(1)
+                ans_tok = ans_tok_var.data.cpu()
+                data = torch.zeros(B*4, self.max_tok_num).scatter_(
+                        1, ans_tok.unsqueeze(1), 1)
+                if self.gpu:  #To one-hot
+                    cur_inp = Variable(data.cuda())
+                else:
+                    cur_inp = Variable(data)
+                cur_inp = cur_inp.unsqueeze(1)
+
+                t += 1
+
+            cond_str_score = torch.stack(scores, 2)
+            for b, num in enumerate(x_len):
+                if num < max_x_len:
+                    cond_str_score[b, :, :, num:] = -100  #[B, IDX, T, TOK_NUM]
+
+        cond_score = (cond_num_score,
+                cond_col_score, cond_op_score, cond_str_score)
+
+        return cond_score

+ 63 - 0
sqlnet/model/modules/where_relation.py

@@ -0,0 +1,63 @@
+import json
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from net_utils import run_lstm, col_name_encode
+
+class WhereRelationPredictor(nn.Module):
+    def __init__(self, N_word, N_h, N_depth, use_ca):
+        super(WhereRelationPredictor, self).__init__()
+        self.N_h = N_h
+        self.use_ca = use_ca
+
+        self.where_rela_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
+                                    num_layers=N_depth, batch_first=True,
+                                    dropout=0.3, bidirectional=True)
+        self.where_rela_att = nn.Linear(N_h, 1)
+        self.where_rela_col_att = nn.Linear(N_h, 1)
+        self.where_rela_out = nn.Sequential(nn.Linear(N_h, N_h), nn.Tanh(), nn.Linear(N_h,3))
+        self.softmax = nn.Softmax(dim=-1)
+        self.col2hid1 = nn.Linear(N_h, 2 * N_h)
+        self.col2hid2 = nn.Linear(N_h, 2 * N_h)
+
+        if self.use_ca:
+            print "Using column attention on where relation predicting"
+
+    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num):
+        B = len(x_len)
+        max_x_len = max(x_len)
+
+        # Predict the condition relationship part
+        # First use column embeddings to calculate the initial hidden unit
+        # Then run the LSTM and predict select number
+        e_num_col, col_num = col_name_encode(col_inp_var, col_name_len,
+                                             col_len, self.where_rela_lstm)
+        col_att_val = self.where_rela_col_att(e_num_col).squeeze()
+        for idx, num in enumerate(col_num):
+            if num < max(col_num):
+                col_att_val[idx, num:] = -1000000
+        num_col_att = self.softmax(col_att_val)
+        K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
+        h1 = self.col2hid1(K_num_col).view(B, 4, self.N_h/2).transpose(0,1).contiguous()
+        h2 = self.col2hid2(K_num_col).view(B, 4, self.N_h/2).transpose(0,1).contiguous()
+
+        h_enc, _ = run_lstm(self.where_rela_lstm, x_emb_var, x_len, hidden=(h1, h2))
+
+        att_val = self.where_rela_att(h_enc).squeeze()
+        for idx, num in enumerate(x_len):
+            if num < max_x_len:
+                att_val[idx, num:] = -1000000
+        att_val = self.softmax(att_val)
+
+        where_rela_num = (h_enc * att_val.unsqueeze(2).expand_as(h_enc)).sum(1)
+        where_rela_score = self.where_rela_out(where_rela_num)
+        return where_rela_score
+
+
+
+
+
+
+

+ 123 - 0
sqlnet/model/modules/word_embedding.py

@@ -0,0 +1,123 @@
+import json
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+
+class WordEmbedding(nn.Module):
+    def __init__(self, word_emb, N_word, gpu, SQL_TOK, our_model, trainable=False):
+        super(WordEmbedding, self).__init__()
+        self.trainable = trainable
+        self.N_word = N_word
+        self.our_model = our_model
+        self.gpu = gpu
+        self.SQL_TOK = SQL_TOK
+
+        if trainable:
+            print "Using trainable embedding"
+            self.w2i, word_emb_val = word_emb
+            self.embedding = nn.Embedding(len(self.w2i), N_word)
+            self.embedding.weight = nn.Parameter(
+                    torch.from_numpy(word_emb_val.astype(np.float32)))
+        else:
+            self.word_emb = word_emb
+            print "Using fixed embedding"
+
+
+    def gen_x_batch(self, q, col):
+        B = len(q)
+        val_embs = []
+        val_len = np.zeros(B, dtype=np.int64)
+        for i, (one_q, one_col) in enumerate(zip(q, col)):
+            if self.trainable:
+                q_val = map(lambda x:self.w2i.get(x, 0), one_q)
+            else:
+                q_val = map(lambda x:self.word_emb.get(x, np.zeros(self.N_word, dtype=np.float32)), one_q)
+            if self.our_model:
+                if self.trainable:
+                    val_embs.append([1] + q_val + [2])  #<BEG> and <END>
+                else:
+                    val_embs.append([np.zeros(self.N_word, dtype=np.float32)] + q_val + [np.zeros(self.N_word, dtype=np.float32)])  #<BEG> and <END>
+                val_len[i] = 1 + len(q_val) + 1
+            else:
+                one_col_all = [x for toks in one_col for x in toks+[',']]
+                if self.trainable:
+                    col_val = map(lambda x:self.w2i.get(x, 0), one_col_all)
+                    val_embs.append( [0 for _ in self.SQL_TOK] + col_val + [0] + q_val+ [0])
+                else:
+                    col_val = map(lambda x:self.word_emb.get(x, np.zeros(self.N_word, dtype=np.float32)), one_col_all)
+                    val_embs.append( [np.zeros(self.N_word, dtype=np.float32) for _ in self.SQL_TOK] + col_val + [np.zeros(self.N_word, dtype=np.float32)] + q_val+ [np.zeros(self.N_word, dtype=np.float32)])
+                val_len[i] = len(self.SQL_TOK) + len(col_val) + 1 + len(q_val) + 1
+        max_len = max(val_len)
+
+        if self.trainable:
+            val_tok_array = np.zeros((B, max_len), dtype=np.int64)
+            for i in range(B):
+                for t in range(len(val_embs[i])):
+                    val_tok_array[i,t] = val_embs[i][t]
+            val_tok = torch.from_numpy(val_tok_array)
+            if self.gpu:
+                val_tok = val_tok.cuda()
+            val_tok_var = Variable(val_tok)
+            val_inp_var = self.embedding(val_tok_var)
+        else:
+            val_emb_array = np.zeros((B, max_len, self.N_word), dtype=np.float32)
+            for i in range(B):
+                for t in range(len(val_embs[i])):
+                    val_emb_array[i,t,:] = val_embs[i][t]
+            val_inp = torch.from_numpy(val_emb_array)
+            if self.gpu:
+                val_inp = val_inp.cuda()
+            val_inp_var = Variable(val_inp)
+        return val_inp_var, val_len
+
+    def gen_col_batch(self, cols):
+        ret = []
+        col_len = np.zeros(len(cols), dtype=np.int64)
+
+        names = []
+        for b, one_cols in enumerate(cols):
+            names = names + one_cols
+            col_len[b] = len(one_cols)
+
+        name_inp_var, name_len = self.str_list_to_batch(names)
+        return name_inp_var, name_len, col_len
+
+    def str_list_to_batch(self, str_list):
+        B = len(str_list)
+
+        val_embs = []
+        val_len = np.zeros(B, dtype=np.int64)
+        for i, one_str in enumerate(str_list):
+            if self.trainable:
+                val = [self.w2i.get(x, 0) for x in one_str]
+            else:
+                val = [self.word_emb.get(x, np.zeros(
+                    self.N_word, dtype=np.float32)) for x in one_str]
+            val_embs.append(val)
+            val_len[i] = len(val)
+        max_len = max(val_len)
+
+        if self.trainable:
+            val_tok_array = np.zeros((B, max_len), dtype=np.int64)
+            for i in range(B):
+                for t in range(len(val_embs[i])):
+                    val_tok_array[i,t] = val_embs[i][t]
+            val_tok = torch.from_numpy(val_tok_array)
+            if self.gpu:
+                val_tok = val_tok.cuda()
+            val_tok_var = Variable(val_tok)
+            val_inp_var = self.embedding(val_tok_var)
+        else:
+            val_emb_array = np.zeros(
+                    (B, max_len, self.N_word), dtype=np.float32)
+            for i in range(B):
+                for t in range(len(val_embs[i])):
+                    val_emb_array[i,t,:] = val_embs[i][t]
+            val_inp = torch.from_numpy(val_emb_array)
+            if self.gpu:
+                val_inp = val_inp.cuda()
+            val_inp_var = Variable(val_inp)
+
+        return val_inp_var, val_len

+ 416 - 0
sqlnet/model/sqlnet.py

@@ -0,0 +1,416 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from modules.word_embedding import WordEmbedding
+from modules.aggregator_predict import AggPredictor
+from modules.selection_predict import SelPredictor
+from modules.sqlnet_condition_predict import SQLNetCondPredictor
+from modules.select_number import SelNumPredictor
+from modules.where_relation import WhereRelationPredictor
+
+
+class SQLNet(nn.Module):
+    def __init__(self, word_emb, N_word, N_h=100, N_depth=2,
+            gpu=False, use_ca=True, trainable_emb=False):
+        super(SQLNet, self).__init__()
+        self.use_ca = use_ca
+        self.trainable_emb = trainable_emb
+
+        self.gpu = gpu
+        self.N_h = N_h
+        self.N_depth = N_depth
+
+        self.max_col_num = 45
+        self.max_tok_num = 200
+        self.SQL_TOK = ['<UNK>', '<END>', 'WHERE', 'AND', 'OR', '==', '>', '<', '!=', '<BEG>']
+        self.COND_OPS = ['>', '<', '==', '!=']
+
+        # Word embedding
+        self.embed_layer = WordEmbedding(word_emb, N_word, gpu, self.SQL_TOK, our_model=True, trainable=trainable_emb)
+
+        # Predict the number of selected columns
+        self.sel_num = SelNumPredictor(N_word, N_h, N_depth, use_ca=use_ca)
+
+        #Predict which columns are selected
+        self.sel_pred = SelPredictor(N_word, N_h, N_depth, self.max_tok_num, use_ca=use_ca)
+
+        #Predict aggregation functions of corresponding selected columns
+        self.agg_pred = AggPredictor(N_word, N_h, N_depth, use_ca=use_ca)
+
+        #Predict number of conditions, condition columns, condition operations and condition values
+        self.cond_pred = SQLNetCondPredictor(N_word, N_h, N_depth, self.max_col_num, self.max_tok_num, use_ca, gpu)
+
+        # Predict condition relationship, like 'and', 'or'
+        self.where_rela_pred = WhereRelationPredictor(N_word, N_h, N_depth, use_ca=use_ca)
+
+
+        self.CE = nn.CrossEntropyLoss()
+        self.softmax = nn.Softmax(dim=-1)
+        self.log_softmax = nn.LogSoftmax()
+        self.bce_logit = nn.BCEWithLogitsLoss()
+        if gpu:
+            self.cuda()
+
+    def generate_gt_where_seq_test(self, q, gt_cond_seq):
+        ret_seq = []
+        for cur_q, ans in zip(q, gt_cond_seq):
+            temp_q = u"".join(cur_q)
+            cur_q = [u'<BEG>'] + cur_q + [u'<END>']
+            record = []
+            record_cond = []
+            for cond in ans:
+                if cond[2] not in temp_q:
+                    record.append((False, cond[2]))
+                else:
+                    record.append((True, cond[2]))
+            for idx, item in enumerate(record):
+                temp_ret_seq = []
+                if item[0]:
+                    temp_ret_seq.append(0)
+                    temp_ret_seq.extend(list(range(temp_q.index(item[1])+1,temp_q.index(item[1])+len(item[1])+1)))
+                    temp_ret_seq.append(len(cur_q)-1)
+                else:
+                    temp_ret_seq.append([0,len(cur_q)-1])
+                record_cond.append(temp_ret_seq)
+            ret_seq.append(record_cond)
+        return ret_seq
+
+    def forward(self, q, col, col_num, gt_where = None, gt_cond=None, reinforce=False, gt_sel=None, gt_sel_num=None):
+        B = len(q)
+
+        sel_num_score = None
+        agg_score = None
+        sel_score = None
+        cond_score = None
+        #Predict aggregator
+        if self.trainable_emb:
+            x_emb_var, x_len = self.agg_embed_layer.gen_x_batch(q, col)
+            col_inp_var, col_name_len, col_len = self.agg_embed_layer.gen_col_batch(col)
+            max_x_len = max(x_len)
+            agg_score = self.agg_pred(x_emb_var, x_len, col_inp_var,
+                    col_name_len, col_len, col_num, gt_sel=gt_sel)
+
+            x_emb_var, x_len = self.sel_embed_layer.gen_x_batch(q, col)
+            col_inp_var, col_name_len, col_len = self.sel_embed_layer.gen_col_batch(col)
+            max_x_len = max(x_len)
+            sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var,
+                    col_name_len, col_len, col_num)
+
+            x_emb_var, x_len = self.cond_embed_layer.gen_x_batch(q, col)
+            col_inp_var, col_name_len, col_len = self.cond_embed_layer.gen_col_batch(col)
+            max_x_len = max(x_len)
+            cond_score = self.cond_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num, gt_where, gt_cond, reinforce=reinforce)
+            where_rela_score = None
+        else:
+            x_emb_var, x_len = self.embed_layer.gen_x_batch(q, col)
+            col_inp_var, col_name_len, col_len = self.embed_layer.gen_col_batch(col)
+            sel_num_score = self.sel_num(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num)
+            # x_emb_var: embedding of each question
+            # x_len: length of each question
+            # col_inp_var: embedding of each header
+            # col_name_len: length of each header
+            # col_len: number of headers in each table, array type
+            # col_num: number of headers in each table, list type
+            if gt_sel_num:
+                pr_sel_num = gt_sel_num
+            else:
+                pr_sel_num = np.argmax(sel_num_score.data.cpu().numpy(), axis=1)
+            sel_score = self.sel_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num)
+
+            if gt_sel:
+                pr_sel = gt_sel
+            else:
+                num = np.argmax(sel_num_score.data.cpu().numpy(), axis=1)
+                sel = sel_score.data.cpu().numpy()
+                pr_sel = [list(np.argsort(-sel[b])[:num[b]]) for b in range(len(num))]
+            agg_score = self.agg_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num, gt_sel=pr_sel, gt_sel_num=pr_sel_num)
+
+            where_rela_score = self.where_rela_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num)
+
+            cond_score = self.cond_pred(x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num, gt_where, gt_cond, reinforce=reinforce)
+
+        return (sel_num_score, sel_score, agg_score, cond_score, where_rela_score)
+
+    def loss(self, score, truth_num, gt_where):
+        sel_num_score, sel_score, agg_score, cond_score, where_rela_score = score
+
+        B = len(truth_num)
+        loss = 0
+
+        # Evaluate select number
+        sel_num_truth = map(lambda x:x[0], truth_num)
+        sel_num_truth = torch.from_numpy(np.array(sel_num_truth))
+        if self.gpu:
+            sel_num_truth = Variable(sel_num_truth.cuda())
+        else:
+            sel_num_truth = Variable(sel_num_truth)
+        loss += self.CE(sel_num_score, sel_num_truth)
+
+        # Evaluate select column
+        T = len(sel_score[0])
+        truth_prob = np.zeros((B,T), dtype=np.float32)
+        for b in range(B):
+            truth_prob[b][list(truth_num[b][1])] = 1
+        data = torch.from_numpy(truth_prob)
+        if self.gpu:
+            sel_col_truth_var = Variable(data.cuda())
+        else:
+            sel_col_truth_var = Variable(data)
+        sigm = nn.Sigmoid()
+        sel_col_prob = sigm(sel_score)
+        bce_loss = -torch.mean(
+            3*(sel_col_truth_var * torch.log(sel_col_prob+1e-10)) +
+            (1-sel_col_truth_var) * torch.log(1-sel_col_prob+1e-10)
+        )
+        loss += bce_loss
+
+        # Evaluate select aggregation
+        for b in range(len(truth_num)):
+            data = torch.from_numpy(np.array(truth_num[b][2]))
+            if self.gpu:
+                sel_agg_truth_var = Variable(data.cuda())
+            else:
+                sel_agg_truth_var = Variable(data)
+            sel_agg_pred = agg_score[b, :len(truth_num[b][1])]
+            loss += (self.CE(sel_agg_pred, sel_agg_truth_var)) / len(truth_num)
+
+        cond_num_score, cond_col_score, cond_op_score, cond_str_score = cond_score
+
+        # Evaluate the number of conditions
+        cond_num_truth = map(lambda x:x[3], truth_num)
+        data = torch.from_numpy(np.array(cond_num_truth))
+        if self.gpu:
+            try:
+                cond_num_truth_var = Variable(data.cuda())
+            except:
+                print "cond_num_truth_var error"
+                print data
+                exit(0)
+        else:
+            cond_num_truth_var = Variable(data)
+        loss += self.CE(cond_num_score, cond_num_truth_var)
+
+        # Evaluate the columns of conditions
+        T = len(cond_col_score[0])
+        truth_prob = np.zeros((B, T), dtype=np.float32)
+        for b in range(B):
+            if len(truth_num[b][4]) > 0:
+                truth_prob[b][list(truth_num[b][4])] = 1
+        data = torch.from_numpy(truth_prob)
+        if self.gpu:
+            cond_col_truth_var = Variable(data.cuda())
+        else:
+            cond_col_truth_var = Variable(data)
+
+        sigm = nn.Sigmoid()
+        cond_col_prob = sigm(cond_col_score)
+        bce_loss = -torch.mean(
+            3*(cond_col_truth_var * torch.log(cond_col_prob+1e-10)) +
+            (1-cond_col_truth_var) * torch.log(1-cond_col_prob+1e-10) )
+        loss += bce_loss
+
+        # Evaluate the operator of conditions
+        for b in range(len(truth_num)):
+            if len(truth_num[b][5]) == 0:
+                continue
+            data = torch.from_numpy(np.array(truth_num[b][5]))
+            if self.gpu:
+                cond_op_truth_var = Variable(data.cuda())
+            else:
+                cond_op_truth_var = Variable(data)
+            cond_op_pred = cond_op_score[b, :len(truth_num[b][5])]
+            try:
+                loss += (self.CE(cond_op_pred, cond_op_truth_var) / len(truth_num))
+            except:
+                print cond_op_pred
+                print cond_op_truth_var
+                exit(0)
+
+        #Evaluate the strings of conditions
+        for b in range(len(gt_where)):
+            for idx in range(len(gt_where[b])):
+                cond_str_truth = gt_where[b][idx]
+                if len(cond_str_truth) == 1:
+                    continue
+                data = torch.from_numpy(np.array(cond_str_truth[1:]))
+                if self.gpu:
+                    cond_str_truth_var = Variable(data.cuda())
+                else:
+                    cond_str_truth_var = Variable(data)
+                str_end = len(cond_str_truth)-1
+                cond_str_pred = cond_str_score[b, idx, :str_end]
+                loss += (self.CE(cond_str_pred, cond_str_truth_var) \
+                        / (len(gt_where) * len(gt_where[b])))
+
+        # Evaluate condition relationship, and / or
+        where_rela_truth = map(lambda x:x[6], truth_num)
+        data = torch.from_numpy(np.array(where_rela_truth))
+        if self.gpu:
+            try:
+                where_rela_truth = Variable(data.cuda())
+            except:
+                print "where_rela_truth error"
+                print data
+                exit(0)
+        else:
+            where_rela_truth = Variable(data)
+        loss += self.CE(where_rela_score, where_rela_truth)
+        return loss
+
+    def check_acc(self, vis_info, pred_queries, gt_queries):
+        def gen_cond_str(conds, header):
+            if len(conds) == 0:
+                return 'None'
+            cond_str = []
+            for cond in conds:
+                cond_str.append(header[cond[0]] + ' ' +
+                    self.COND_OPS[cond[1]] + ' ' + unicode(cond[2]).lower())
+            return 'WHERE ' + ' AND '.join(cond_str)
+
+        tot_err = sel_num_err = agg_err = sel_err = 0.0
+        cond_num_err = cond_col_err = cond_op_err = cond_val_err = cond_rela_err = 0.0
+        for b, (pred_qry, gt_qry) in enumerate(zip(pred_queries, gt_queries)):
+            good = True
+            sel_pred, agg_pred, where_rela_pred = pred_qry['sel'], pred_qry['agg'], pred_qry['cond_conn_op']
+            sel_gt, agg_gt, where_rela_gt = gt_qry['sel'], gt_qry['agg'], gt_qry['cond_conn_op']
+
+            if where_rela_gt != where_rela_pred:
+                good = False
+                cond_rela_err += 1
+
+            if len(sel_pred) != len(sel_gt):
+                good = False
+                sel_num_err += 1
+
+            pred_sel_dict = {k:v for k,v in zip(list(sel_pred), list(agg_pred))}
+            gt_sel_dict = {k:v for k,v in zip(sel_gt, agg_gt)}
+            if set(sel_pred) != set(sel_gt):
+                good = False
+                sel_err += 1
+            agg_pred = [pred_sel_dict[x] for x in sorted(pred_sel_dict.keys())]
+            agg_gt = [gt_sel_dict[x] for x in sorted(gt_sel_dict.keys())]
+            if agg_pred != agg_gt:
+                good = False
+                agg_err += 1
+
+            cond_pred = pred_qry['conds']
+            cond_gt = gt_qry['conds']
+            if len(cond_pred) != len(cond_gt):
+                good = False
+                cond_num_err += 1
+            else:
+                cond_op_pred, cond_op_gt = {}, {}
+                cond_val_pred, cond_val_gt = {}, {}
+                for p, g in zip(cond_pred, cond_gt):
+                    cond_op_pred[p[0]] = p[1]
+                    cond_val_pred[p[0]] = p[2]
+                    cond_op_gt[g[0]] = g[1]
+                    cond_val_gt[g[0]] = g[2]
+
+                if set(cond_op_pred.keys()) != set(cond_op_gt.keys()):
+                    cond_col_err += 1
+                    good=False
+
+                where_op_pred = [cond_op_pred[x] for x in sorted(cond_op_pred.keys())]
+                where_op_gt = [cond_op_gt[x] for x in sorted(cond_op_gt.keys())]
+                if where_op_pred != where_op_gt:
+                    cond_op_err += 1
+                    good=False
+
+                where_val_pred = [cond_val_pred[x] for x in sorted(cond_val_pred.keys())]
+                where_val_gt = [cond_val_gt[x] for x in sorted(cond_val_gt.keys())]
+                if where_val_pred != where_val_gt:
+                    cond_val_err += 1
+                    good=False
+
+            if not good:
+                tot_err += 1
+
+        return np.array((sel_num_err, sel_err, agg_err, cond_num_err, cond_col_err, cond_op_err, cond_val_err , cond_rela_err)), tot_err
+
+
+    def gen_query(self, score, q, col, raw_q, reinforce=False, verbose=False):
+        """
+        :param score:
+        :param q: token-questions
+        :param col: token-headers
+        :param raw_q: original question sequence
+        :return:
+        """
+        def merge_tokens(tok_list, raw_tok_str):
+            tok_str = raw_tok_str# .lower()
+            alphabet = 'abcdefghijklmnopqrstuvwxyz0123456789$('
+            special = {'-LRB-':'(',
+                    '-RRB-':')',
+                    '-LSB-':'[',
+                    '-RSB-':']',
+                    '``':'"',
+                    '\'\'':'"',
+                    '--':u'\u2013'}
+            ret = ''
+            double_quote_appear = 0
+            for raw_tok in tok_list:
+                if not raw_tok:
+                    continue
+                tok = special.get(raw_tok, raw_tok)
+                if tok == '"':
+                    double_quote_appear = 1 - double_quote_appear
+                if len(ret) == 0:
+                    pass
+                elif len(ret) > 0 and ret + ' ' + tok in tok_str:
+                    ret = ret + ' '
+                elif len(ret) > 0 and ret + tok in tok_str:
+                    pass
+                elif tok == '"':
+                    if double_quote_appear:
+                        ret = ret + ' '
+                # elif tok[0] not in alphabet:
+                #     pass
+                elif (ret[-1] not in ['(', '/', u'\u2013', '#', '$', '&']) \
+                        and (ret[-1] != '"' or not double_quote_appear):
+                    ret = ret + ' '
+                ret = ret + tok
+            return ret.strip()
+
+        sel_num_score, sel_score, agg_score, cond_score, where_rela_score = score
+        # [64,4,6], [64,14], ..., [64,4]
+        sel_num_score = sel_num_score.data.cpu().numpy()
+        sel_score = sel_score.data.cpu().numpy()
+        agg_score = agg_score.data.cpu().numpy()
+        where_rela_score = where_rela_score.data.cpu().numpy()
+        ret_queries = []
+        B = len(agg_score)
+        cond_num_score, cond_col_score, cond_op_score, cond_str_score = \
+            [x.data.cpu().numpy() for x in cond_score]
+        for b in range(B):
+            cur_query = {}
+            cur_query['sel'] = []
+            cur_query['agg'] = []
+            sel_num = np.argmax(sel_num_score[b])
+            max_col_idxes = np.argsort(-sel_score[b])[:sel_num]
+            # find the most-probable columns' indexes
+            max_agg_idxes = np.argsort(-agg_score[b])[:sel_num]
+            cur_query['sel'].extend([int(i) for i in max_col_idxes])
+            cur_query['agg'].extend([i[0] for i in max_agg_idxes])
+            cur_query['cond_conn_op'] = np.argmax(where_rela_score[b])
+            cur_query['conds'] = []
+            cond_num = np.argmax(cond_num_score[b])
+            all_toks = ['<BEG>'] + q[b] + ['<END>']
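+            # condition values are decoded as pointers over the question characters, terminated by <END>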
+            max_idxes = np.argsort(-cond_col_score[b])[:cond_num]
+            for idx in range(cond_num):
+                cur_cond = []
+                cur_cond.append(max_idxes[idx]) # where-col
+                cur_cond.append(np.argmax(cond_op_score[b][idx])) # where-op
+                cur_cond_str_toks = []
+                for str_score in cond_str_score[b][idx]:
+                    str_tok = np.argmax(str_score[:len(all_toks)])
+                    str_val = all_toks[str_tok]
+                    if str_val == '<END>':
+                        break
+                    cur_cond_str_toks.append(str_val)
+                cur_cond.append(merge_tokens(cur_cond_str_toks, raw_q[b]))
+                cur_query['conds'].append(cur_cond)
+            ret_queries.append(cur_query)
+        return ret_queries

+ 217 - 0
sqlnet/utils.py

@@ -0,0 +1,217 @@
+import json
+from lib.dbengine import DBEngine
+import numpy as np
+from tqdm import tqdm
+
+def load_data(sql_paths, table_paths, use_small=False):
+    if not isinstance(sql_paths, (list, tuple)):
+        sql_paths = (sql_paths, )
+    if not isinstance(table_paths, (list, tuple)):
+        table_paths = (table_paths, )
+    sql_data = []
+    table_data = {}
+
+    for SQL_PATH in sql_paths:
+        print "Loading data from %s" % SQL_PATH
+        with open(SQL_PATH) as inf:
+            for idx, line in enumerate(inf):
+                sql = json.loads(line.strip())
+                if use_small and idx >= 1000:
+                    break
+                sql_data.append(sql)
+
+    for TABLE_PATH in table_paths:
+        print "Loading data from %s" % TABLE_PATH
+        with open(TABLE_PATH) as inf:
+            for line in inf:
+                tab = json.loads(line.strip())
+                table_data[tab[u'id']] = tab
+
+    ret_sql_data = []
+    for sql in sql_data:
+        if sql[u'table_id'] in table_data:
+            ret_sql_data.append(sql)
+
+    return ret_sql_data, table_data
+
+def load_dataset(toy=False, use_small=False, mode='train'):
+    print "Loading dataset"
+    dev_sql, dev_table = load_data('data/dev.json', 'data/dev.tables.json', use_small=use_small)
+    dev_db = 'data/dev.db'
+    if mode == 'train':
+        train_sql, train_table = load_data('data/train.json', 'data/train.tables.json', use_small=use_small)
+        train_db = 'data/train.db'
+        return train_sql, train_table, train_db, dev_sql, dev_table, dev_db
+    elif mode == 'test':
+        test_sql, test_table = load_data('data/test.json', 'data/test.tables.json', use_small=use_small)
+        test_db = 'data/test.db'
+        return dev_sql, dev_table, dev_db, test_sql, test_table, test_db
+
+def to_batch_seq(sql_data, table_data, idxes, st, ed, ret_vis_data=False):
+    q_seq = []
+    col_seq = []
+    col_num = []
+    ans_seq = []
+    gt_cond_seq = []
+    vis_seq = []
+    sel_num_seq = []
+    for i in range(st, ed):
+        sql = sql_data[idxes[i]]
+        sel_num = len(sql['sql']['sel'])
+        sel_num_seq.append(sel_num)
+        conds_num = len(sql['sql']['conds'])
+        q_seq.append([char for char in sql['question']])
+        col_seq.append([[char for char in header] for header in table_data[sql['table_id']]['header']])
+        col_num.append(len(table_data[sql['table_id']]['header']))
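+        # ans_seq: (sel_num, sel cols, agg ops, cond_num, cond cols, cond ops, cond_conn_op)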
+        ans_seq.append(
+            (
+            len(sql['sql']['agg']),
+            sql['sql']['sel'],
+            sql['sql']['agg'],
+            conds_num,
+            tuple(x[0] for x in sql['sql']['conds']),
+            tuple(x[1] for x in sql['sql']['conds']),
+            sql['sql']['cond_conn_op'],
+            ))
+        gt_cond_seq.append(sql['sql']['conds'])
+        vis_seq.append((sql['question'], table_data[sql['table_id']]['header']))
+    if ret_vis_data:
+        return q_seq, sel_num_seq, col_seq, col_num, ans_seq, gt_cond_seq, vis_seq
+    else:
+        return q_seq, sel_num_seq, col_seq, col_num, ans_seq, gt_cond_seq
+
+def to_batch_seq_test(sql_data, table_data, idxes, st, ed):
+    q_seq = []
+    col_seq = []
+    col_num = []
+    raw_seq = []
+    table_ids = []
+    for i in range(st, ed):
+        sql = sql_data[idxes[i]]
+        q_seq.append([char for char in sql['question']])
+        col_seq.append([[char for char in header] for header in table_data[sql['table_id']]['header']])
+        col_num.append(len(table_data[sql['table_id']]['header']))
+        raw_seq.append(sql['question'])
+        table_ids.append(sql_data[idxes[i]]['table_id'])
+    return q_seq, col_seq, col_num, raw_seq, table_ids
+
+def to_batch_query(sql_data, idxes, st, ed):
+    query_gt = []
+    table_ids = []
+    for i in range(st, ed):
+        query_gt.append(sql_data[idxes[i]]['sql'])
+        table_ids.append(sql_data[idxes[i]]['table_id'])
+    return query_gt, table_ids
+
+def epoch_train(model, optimizer, batch_size, sql_data, table_data):
+    model.train()
+    perm = np.random.permutation(len(sql_data))
+    cum_loss = 0.0
+    for st in tqdm(range((len(sql_data)+batch_size-1)//batch_size)):
+        ed = (st+1)*batch_size if (st+1)*batch_size < len(perm) else len(perm)
+        st = st * batch_size
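+        # [st, ed) is the slice of the shuffled permutation used for this mini-batch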
+        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq = to_batch_seq(sql_data, table_data, perm, st, ed)
+        # q_seq: char-based question sequences
+        # gt_sel_num: number of selected columns (= number of aggregation functions)
+        # col_seq: char-based column names
+        # col_num: number of headers in each table
+        # ans_seq: (sel_num, sel cols, agg ops, cond_num, cond cols, cond ops, cond_conn_op)
+        # gt_cond_seq: ground truth of conditions
+        gt_where_seq = model.generate_gt_where_seq_test(q_seq, gt_cond_seq)
+        gt_sel_seq = [x[1] for x in ans_seq]
+        score = model.forward(q_seq, col_seq, col_num, gt_where=gt_where_seq, gt_cond=gt_cond_seq, gt_sel=gt_sel_seq, gt_sel_num=gt_sel_num)
+        # sel_num_score, sel_col_score, sel_agg_score, cond_score, cond_rela_score
+
+        # compute loss
+        loss = model.loss(score, ans_seq, gt_where_seq)
+        cum_loss += loss.data.cpu().numpy()*(ed - st)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+    return cum_loss / len(sql_data)
+
+def epoch_exec_acc(model, batch_size, sql_data, table_data, db_path):
+    engine = DBEngine(db_path)
+    model.eval()
+    perm = list(range(len(sql_data)))
+    tot_acc_num = 0.0
+    for st in tqdm(range((len(sql_data)+batch_size-1)//batch_size)):
+        ed = (st+1)*batch_size if (st+1)*batch_size < len(perm) else len(perm)
+        st = st * batch_size
+        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, raw_data = \
+            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
+        raw_q_seq = [x[0] for x in raw_data]
+        raw_col_seq = [x[1] for x in raw_data]
+        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
+        gt_sel_seq = [x[1] for x in ans_seq]  # sel columns are at index 1 of ans_seq
+        score = model.forward(q_seq, col_seq, col_num, gt_sel=gt_sel_seq)
+        pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq)
+
+        for idx, (sql_gt, sql_pred, tid) in enumerate(zip(query_gt, pred_queries, table_ids)):
+            ret_gt = engine.execute(tid, sql_gt['sel'], sql_gt['agg'], sql_gt['conds'])
+            try:
+                ret_pred = engine.execute(tid, sql_pred['sel'], sql_pred['agg'], sql_pred['conds'])
+            except Exception:
+                ret_pred = None
+            tot_acc_num += (ret_gt == ret_pred)
+    return tot_acc_num / len(sql_data)
+
+def predict_test(model, batch_size, sql_data, table_data, db_path, output_path):
+    engine = DBEngine(db_path)
+    model.eval()
+    perm = list(range(len(sql_data)))
+    fw = open(output_path, 'w')
+    for st in tqdm(range((len(sql_data)+batch_size-1)//batch_size)):
+        ed = (st+1)*batch_size if (st+1)*batch_size < len(perm) else len(perm)
+        st = st * batch_size
+        q_seq, col_seq, col_num, raw_q_seq, table_ids = to_batch_seq_test(sql_data, table_data, perm, st, ed)
+        score = model.forward(q_seq, col_seq, col_num)
+        sql_preds = model.gen_query(score, q_seq, col_seq, raw_q_seq)
+        for sql_pred in sql_preds:
+            fw.write(json.dumps(sql_pred, ensure_ascii=False).encode('utf-8') + '\n')
+    fw.close()
+
+def epoch_acc(model, batch_size, sql_data, table_data):
+    model.eval()
+    perm = list(range(len(sql_data)))
+    badcase = 0
+    one_acc_num, tot_acc_num = 0.0, 0.0
+    for st in tqdm(range((len(sql_data)+batch_size-1)//batch_size)):
+        ed = (st+1)*batch_size if (st+1)*batch_size < len(perm) else len(perm)
+        st = st * batch_size
+        q_seq, gt_sel_num, col_seq, col_num, ans_seq, gt_cond_seq, raw_data = \
+            to_batch_seq(sql_data, table_data, perm, st, ed, ret_vis_data=True)
+        # q_seq: char-based question sequences
+        # gt_sel_num: number of selected columns (newly added field)
+        # col_seq: char-based column names
+        # col_num: number of headers in each table
+        # ans_seq: (sel_num, sel cols, agg ops, cond_num, cond cols, cond ops, cond_conn_op)
+        # gt_cond_seq: ground truth of conditions
+        # raw_data: original question and headers
+        query_gt, table_ids = to_batch_query(sql_data, perm, st, ed)
+        # query_gt: ground truth sql, i.e. data['sql'] with sel, agg, cond_conn_op and conds as [col, op, value]
+        raw_q_seq = [x[0] for x in raw_data] # original question
+        try:
+            score = model.forward(q_seq, col_seq, col_num)
+            pred_queries = model.gen_query(score, q_seq, col_seq, raw_q_seq)
+            # generate predicted format
+            one_err, tot_err = model.check_acc(raw_data, pred_queries, query_gt)
+        except Exception:
+            badcase += 1
+            print 'badcase', badcase
+            continue
+        one_acc_num += (ed-st-one_err)
+        tot_acc_num += (ed-st-tot_err)
+    return one_acc_num / len(sql_data), tot_acc_num / len(sql_data)
+
+
+def load_word_emb(file_name):
+    print ('Loading word embedding from %s'%file_name)
+    ret = {}
+    with open(file_name) as inf:
+        for idx, line in enumerate(inf):
+            info = line.strip().split(' ')
+            key = info[0].decode('utf-8')
+            # keep only the first occurrence of each token
+            if key not in ret:
+                ret[key] = np.array([float(x) for x in info[1:]])
+    return ret

+ 2 - 0
start_test.sh

@@ -0,0 +1,2 @@
+#!/usr/bin/env bash
+CUDA_VISIBLE_DEVICES=$1 python test.py --ca --gpu --output_dir $2

+ 2 - 0
start_train.sh

@@ -0,0 +1,2 @@
+#!/usr/bin/env bash
+CUDA_VISIBLE_DEVICES=$1 python train.py --ca --gpu --bs $2

+ 37 - 0
test.py

@@ -0,0 +1,37 @@
+import torch
+from sqlnet.utils import *
+from sqlnet.model.sqlnet import SQLNet
+import argparse
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--gpu', action='store_true', help='Whether to use GPU')
+    parser.add_argument('--toy', action='store_true', help='If set, use a small dataset and batch size for fast debugging.')
+    parser.add_argument('--ca', action='store_true', help='Whether to use column attention.')
+    parser.add_argument('--train_emb', action='store_true', help='Make the word embeddings trainable in SQLNet.')
+    parser.add_argument('--output_dir', type=str, default='', help='Output path of the prediction result')
+    args = parser.parse_args()
+
+    n_word=300
+    if args.toy:
+        use_small=True
+        gpu=args.gpu
+        batch_size=16
+    else:
+        use_small=False
+        gpu=args.gpu
+        batch_size=64
+
+    dev_sql, dev_table, dev_db, test_sql, test_table, test_db = load_dataset(use_small=use_small, mode='test')
+
+    word_emb = load_word_emb('data/char_embedding')
+    model = SQLNet(word_emb, N_word=n_word, use_ca=args.ca, gpu=gpu, trainable_emb=args.train_emb)
+
+    model_path = 'saved_model/best_model'
+    print "Loading from %s" % model_path
+    model.load_state_dict(torch.load(model_path))
+    print "Loaded model from %s" % model_path
+
+    print "Start to predict test set"
+    predict_test(model, batch_size, test_sql, test_table, test_db, args.output_dir)
+    print "Output path of prediction result is %s" % args.output_dir

+ 95 - 0
train.py

@@ -0,0 +1,95 @@
+import torch
+from sqlnet.utils import *
+from sqlnet.model.sqlnet import SQLNet
+
+import argparse
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--bs', type=int, default=16, help='Batch size')
+    parser.add_argument('--epoch', type=int, default=100, help='Number of training epochs')
+    parser.add_argument('--gpu', action='store_true', help='Whether to use GPU for training')
+    parser.add_argument('--toy', action='store_true', help='If set, use a small dataset for fast debugging')
+    parser.add_argument('--ca', action='store_true', help='Whether to use column attention')
+    parser.add_argument('--train_emb', action='store_true', help='Make the word embeddings trainable in SQLNet')
+    parser.add_argument('--restore', action='store_true', help='Whether to restore a previously trained model')
+    parser.add_argument('--logdir', type=str, default='', help='Path to save the experiment log')
+    args = parser.parse_args()
+
+    n_word=300
+    if args.toy:
+        use_small=True
+        gpu=args.gpu
+        batch_size=16
+    else:
+        use_small=False
+        gpu=args.gpu
+        batch_size=args.bs
+    learning_rate = 1e-3
+
+    # load dataset
+    train_sql, train_table, train_db, dev_sql, dev_table, dev_db = load_dataset(use_small=use_small)
+
+    word_emb = load_word_emb('data/char_embedding')
+    model = SQLNet(word_emb, N_word=n_word, use_ca=args.ca, gpu=gpu, trainable_emb=args.train_emb)
+    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)
+
+    if args.restore:
+        model_path= 'saved_model/best_model'
+        print "Loading trained model from %s" % model_path
+        model.load_state_dict(torch.load(model_path))
+
+    # used to record best score of each sub-task
+    best_sn, best_sc, best_sa, best_wn, best_wc, best_wo, best_wv, best_wr = 0, 0, 0, 0, 0, 0, 0, 0
+    best_sn_idx, best_sc_idx, best_sa_idx, best_wn_idx, best_wc_idx, best_wo_idx, best_wv_idx, best_wr_idx = 0, 0, 0, 0, 0, 0, 0, 0
+    best_lf, best_lf_idx = 0.0, 0
+
+    print "#"*20+"  Star to Train  " + "#"*20
+    for i in range(args.epoch):
+        print 'Epoch %d'%(i+1)
+        # train on the train dataset
+        train_loss = epoch_train(model, optimizer, batch_size, train_sql, train_table)
+        # evaluate on the dev dataset
+        dev_acc = epoch_acc(model, batch_size, dev_sql, dev_table)
+        # accuracy of each sub-task
+        print 'Sel-Num: %.3f, Sel-Col: %.3f, Sel-Agg: %.3f, W-Num: %.3f, W-Col: %.3f, W-Op: %.3f, W-Val: %.3f, W-Rel: %.3f'%(
+            dev_acc[0][0], dev_acc[0][1], dev_acc[0][2], dev_acc[0][3], dev_acc[0][4], dev_acc[0][5], dev_acc[0][6], dev_acc[0][7])
+        # save the best model
+        if dev_acc[1] > best_lf:
+            best_lf = dev_acc[1]
+            best_lf_idx = i + 1
+            torch.save(model.state_dict(), 'saved_model/best_model')
+
+        # record the best score of each sub-task
+        if dev_acc[0][0] > best_sn:
+            best_sn = dev_acc[0][0]
+            best_sn_idx = i+1
+        if dev_acc[0][1] > best_sc:
+            best_sc = dev_acc[0][1]
+            best_sc_idx = i+1
+        if dev_acc[0][2] > best_sa:
+            best_sa = dev_acc[0][2]
+            best_sa_idx = i+1
+        if dev_acc[0][3] > best_wn:
+            best_wn = dev_acc[0][3]
+            best_wn_idx = i+1
+        if dev_acc[0][4] > best_wc:
+            best_wc = dev_acc[0][4]
+            best_wc_idx = i+1
+        if dev_acc[0][5] > best_wo:
+            best_wo = dev_acc[0][5]
+            best_wo_idx = i+1
+        if dev_acc[0][6] > best_wv:
+            best_wv = dev_acc[0][6]
+            best_wv_idx = i+1
+        if dev_acc[0][7] > best_wr:
+            best_wr = dev_acc[0][7]
+            best_wr_idx = i+1
+        print 'Train loss = %.3f' % train_loss
+        print 'Dev Logic Form: %.3f' % dev_acc[1]
+        print 'Best Logic Form: %.3f at epoch %d' % (best_lf, best_lf_idx)
+        if (i+1) % 10 == 0:
+            print 'Best val acc: %s\nAchieved at epochs: %s' % (
+                    (best_sn, best_sc, best_sa, best_wn, best_wc, best_wo, best_wv, best_wr),
+                    (best_sn_idx, best_sc_idx, best_sa_idx, best_wn_idx, best_wc_idx, best_wo_idx, best_wv_idx, best_wr_idx))