@@ -4,7 +4,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.autograd import Variable
 import numpy as np
-from net_utils import run_lstm, col_name_encode
+from sqlnet.model.modules.net_utils import run_lstm, col_name_encode
 
 class SQLNetCondPredictor(nn.Module):
     def __init__(self, N_word, N_h, N_depth, max_col_num, max_tok_num, use_ca, gpu):
@@ -15,59 +15,41 @@ class SQLNetCondPredictor(nn.Module):
         self.gpu = gpu
         self.use_ca = use_ca
 
-        self.cond_num_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
-                num_layers=N_depth, batch_first=True,
-                dropout=0.3, bidirectional=True)
+        self.cond_num_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), num_layers=N_depth, batch_first=True, dropout=0.3, bidirectional=True)
         self.cond_num_att = nn.Linear(N_h, 1)
         self.cond_num_out = nn.Sequential(nn.Linear(N_h, N_h),
                 nn.Tanh(), nn.Linear(N_h, 5))
-        self.cond_num_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
-                num_layers=N_depth, batch_first=True,
-                dropout=0.3, bidirectional=True)
+        self.cond_num_name_enc = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), num_layers=N_depth, batch_first=True, dropout=0.3, bidirectional=True)
         self.cond_num_col_att = nn.Linear(N_h, 1)
         self.cond_num_col2hid1 = nn.Linear(N_h, 2*N_h)
         self.cond_num_col2hid2 = nn.Linear(N_h, 2*N_h)
 
-        self.cond_col_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
-                num_layers=N_depth, batch_first=True,
-                dropout=0.3, bidirectional=True)
+        self.cond_col_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), num_layers=N_depth, batch_first=True, dropout=0.3, bidirectional=True)
         if use_ca:
-            print "Using column attention on where predicting"
+            print ("Using column attention on where predicting")
             self.cond_col_att = nn.Linear(N_h, N_h)
         else:
-            print "Not using column attention on where predicting"
+            print ("Not using column attention on where predicting")
             self.cond_col_att = nn.Linear(N_h, 1)
-        self.cond_col_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
-                num_layers=N_depth, batch_first=True,
-                dropout=0.3, bidirectional=True)
+        self.cond_col_name_enc = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), num_layers=N_depth, batch_first=True, dropout=0.3, bidirectional=True)
         self.cond_col_out_K = nn.Linear(N_h, N_h)
         self.cond_col_out_col = nn.Linear(N_h, N_h)
         self.cond_col_out = nn.Sequential(nn.ReLU(), nn.Linear(N_h, 1))
 
-        self.cond_op_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
-                num_layers=N_depth, batch_first=True,
-                dropout=0.3, bidirectional=True)
+        self.cond_op_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), num_layers=N_depth, batch_first=True, dropout=0.3, bidirectional=True)
         if use_ca:
             self.cond_op_att = nn.Linear(N_h, N_h)
         else:
             self.cond_op_att = nn.Linear(N_h, 1)
         self.cond_op_out_K = nn.Linear(N_h, N_h)
-        self.cond_op_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
-                num_layers=N_depth, batch_first=True,
-                dropout=0.3, bidirectional=True)
+        self.cond_op_name_enc = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), num_layers=N_depth, batch_first=True, dropout=0.3, bidirectional=True)
         self.cond_op_out_col = nn.Linear(N_h, N_h)
         self.cond_op_out = nn.Sequential(nn.Linear(N_h, N_h), nn.Tanh(),
                 nn.Linear(N_h, 4))
 
-        self.cond_str_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
-                num_layers=N_depth, batch_first=True,
-                dropout=0.3, bidirectional=True)
-        self.cond_str_decoder = nn.LSTM(input_size=self.max_tok_num,
-                hidden_size=N_h, num_layers=N_depth,
-                batch_first=True, dropout=0.3)
-        self.cond_str_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
-                num_layers=N_depth, batch_first=True,
-                dropout=0.3, bidirectional=True)
+        self.cond_str_lstm = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), num_layers=N_depth, batch_first=True, dropout=0.3, bidirectional=True)
+        self.cond_str_decoder = nn.LSTM(input_size=self.max_tok_num, hidden_size=N_h, num_layers=N_depth, batch_first=True, dropout=0.3)
+        self.cond_str_name_enc = nn.LSTM(input_size=N_word, hidden_size=int(N_h/2), num_layers=N_depth, batch_first=True, dropout=0.3, bidirectional=True)
         self.cond_str_out_g = nn.Linear(N_h, N_h)
         self.cond_str_out_h = nn.Linear(N_h, N_h)
         self.cond_str_out_col = nn.Linear(N_h, N_h)
@@ -78,7 +60,7 @@ class SQLNetCondPredictor(nn.Module):
 
     def gen_gt_batch(self, split_tok_seq):
         B = len(split_tok_seq)
-        max_len = max([max([len(tok) for tok in tok_seq]+[0]) for
+        max_len = max([max([len(tok) for tok in tok_seq]+[0]) for
                 tok_seq in split_tok_seq]) - 1 # The max seq len in the batch.
         if max_len < 1:
             max_len = 1
@@ -121,10 +103,8 @@ class SQLNetCondPredictor(nn.Module):
                 num_col_att_val[idx, num:] = -100
         num_col_att = self.softmax(num_col_att_val)
         K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
-        cond_num_h1 = self.cond_num_col2hid1(K_num_col).view(
-                B, 4, self.N_h/2).transpose(0, 1).contiguous()
-        cond_num_h2 = self.cond_num_col2hid2(K_num_col).view(
-                B, 4, self.N_h/2).transpose(0, 1).contiguous()
+        cond_num_h1 = self.cond_num_col2hid1(K_num_col).view(B, 4, self.N_h//2).transpose(0, 1).contiguous()
+        cond_num_h2 = self.cond_num_col2hid2(K_num_col).view(B, 4, self.N_h//2).transpose(0, 1).contiguous()
 
         h_num_enc, _ = run_lstm(self.cond_num_lstm, x_emb_var, x_len,
                 hidden=(cond_num_h1, cond_num_h2))
@@ -185,7 +165,7 @@ class SQLNetCondPredictor(nn.Module):
         h_op_enc, _ = run_lstm(self.cond_op_lstm, x_emb_var, x_len)
         col_emb = []
         for b in range(B):
-            cur_col_emb = torch.stack([e_cond_col[b, x]
+            cur_col_emb = torch.stack([e_cond_col[b, x]
                 for x in chosen_col_gt[b]] + [e_cond_col[b, 0]] *
                 (4 - len(chosen_col_gt[b]))) # Pad the columns to maximum (4)
             col_emb.append(cur_col_emb)
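
The hunks above amount to a Python 2 to 3 port: the print statements become function calls, the implicit relative import of net_utils becomes an absolute one, and every N_h/2 is wrapped in int() or rewritten as N_h//2, because Python 3's true division returns a float. A minimal sketch of why the cast matters (the concrete sizes below are illustrative, not taken from the patch):

import torch.nn as nn

N_h = 100  # illustrative hidden width

# Python 2: 100 / 2 == 50 (an int). Python 3: 100 / 2 == 50.0 (a float),
# which nn.LSTM rejects, since hidden_size must be an int.
# bad = nn.LSTM(input_size=300, hidden_size=N_h / 2)  # TypeError on Python 3

# Either form used in the patch restores an integer:
lstm = nn.LSTM(input_size=300, hidden_size=N_h // 2,  # floor division
               num_layers=2, batch_first=True, dropout=0.3,
               bidirectional=True)  # the two directions recombine to N_h features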