# sqlnet_condition_predict.py (12 KB)

import json

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from net_utils import run_lstm, col_name_encode

  8. class SQLNetCondPredictor(nn.Module):
  9. def __init__(self, N_word, N_h, N_depth, max_col_num, max_tok_num, use_ca, gpu):
  10. super(SQLNetCondPredictor, self).__init__()
  11. self.N_h = N_h
  12. self.max_tok_num = max_tok_num
  13. self.max_col_num = max_col_num
  14. self.gpu = gpu
  15. self.use_ca = use_ca
  16. self.cond_num_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
  17. num_layers=N_depth, batch_first=True,
  18. dropout=0.3, bidirectional=True)
  19. self.cond_num_att = nn.Linear(N_h, 1)
  20. self.cond_num_out = nn.Sequential(nn.Linear(N_h, N_h),
  21. nn.Tanh(), nn.Linear(N_h, 5))
  22. self.cond_num_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
  23. num_layers=N_depth, batch_first=True,
  24. dropout=0.3, bidirectional=True)
  25. self.cond_num_col_att = nn.Linear(N_h, 1)
  26. self.cond_num_col2hid1 = nn.Linear(N_h, 2*N_h)
  27. self.cond_num_col2hid2 = nn.Linear(N_h, 2*N_h)
  28. self.cond_col_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
  29. num_layers=N_depth, batch_first=True,
  30. dropout=0.3, bidirectional=True)
  31. if use_ca:
  32. print "Using column attention on where predicting"
  33. self.cond_col_att = nn.Linear(N_h, N_h)
  34. else:
  35. print "Not using column attention on where predicting"
  36. self.cond_col_att = nn.Linear(N_h, 1)
  37. self.cond_col_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
  38. num_layers=N_depth, batch_first=True,
  39. dropout=0.3, bidirectional=True)
  40. self.cond_col_out_K = nn.Linear(N_h, N_h)
  41. self.cond_col_out_col = nn.Linear(N_h, N_h)
  42. self.cond_col_out = nn.Sequential(nn.ReLU(), nn.Linear(N_h, 1))
  43. self.cond_op_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
  44. num_layers=N_depth, batch_first=True,
  45. dropout=0.3, bidirectional=True)
  46. if use_ca:
  47. self.cond_op_att = nn.Linear(N_h, N_h)
  48. else:
  49. self.cond_op_att = nn.Linear(N_h, 1)
  50. self.cond_op_out_K = nn.Linear(N_h, N_h)
  51. self.cond_op_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
  52. num_layers=N_depth, batch_first=True,
  53. dropout=0.3, bidirectional=True)
  54. self.cond_op_out_col = nn.Linear(N_h, N_h)
  55. self.cond_op_out = nn.Sequential(nn.Linear(N_h, N_h), nn.Tanh(),
  56. nn.Linear(N_h, 4))
  57. self.cond_str_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
  58. num_layers=N_depth, batch_first=True,
  59. dropout=0.3, bidirectional=True)
  60. self.cond_str_decoder = nn.LSTM(input_size=self.max_tok_num,
  61. hidden_size=N_h, num_layers=N_depth,
  62. batch_first=True, dropout=0.3)
  63. self.cond_str_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
  64. num_layers=N_depth, batch_first=True,
  65. dropout=0.3, bidirectional=True)
  66. self.cond_str_out_g = nn.Linear(N_h, N_h)
  67. self.cond_str_out_h = nn.Linear(N_h, N_h)
  68. self.cond_str_out_col = nn.Linear(N_h, N_h)
  69. self.cond_str_out = nn.Sequential(nn.ReLU(), nn.Linear(N_h, 1))
  70. self.softmax = nn.Softmax(dim=-1)
  71. def gen_gt_batch(self, split_tok_seq):
  72. B = len(split_tok_seq)
  73. max_len = max([max([len(tok) for tok in tok_seq]+[0]) for
  74. tok_seq in split_tok_seq]) - 1 # The max seq len in the batch.
  75. if max_len < 1:
  76. max_len = 1
  77. ret_array = np.zeros((
  78. B, 4, max_len, self.max_tok_num), dtype=np.float32)
  79. ret_len = np.zeros((B, 4))
  80. for b, tok_seq in enumerate(split_tok_seq):
  81. idx = 0
  82. for idx, one_tok_seq in enumerate(tok_seq):
  83. out_one_tok_seq = one_tok_seq[:-1]
  84. ret_len[b, idx] = len(out_one_tok_seq)
  85. for t, tok_id in enumerate(out_one_tok_seq):
  86. ret_array[b, idx, t, tok_id] = 1
  87. if idx < 3:
  88. ret_array[b, idx+1:, 0, 1] = 1
  89. ret_len[b, idx+1:] = 1
  90. ret_inp = torch.from_numpy(ret_array)
  91. if self.gpu:
  92. ret_inp = ret_inp.cuda()
  93. ret_inp_var = Variable(ret_inp)
  94. return ret_inp_var, ret_len #[B, IDX, max_len, max_tok_num]
  95. def forward(self, x_emb_var, x_len, col_inp_var, col_name_len,
  96. col_len, col_num, gt_where, gt_cond, reinforce):
  97. max_x_len = max(x_len)
  98. B = len(x_len)
  99. if reinforce:
  100. raise NotImplementedError('Our model doesn\'t have RL')
  101. # Predict the number of conditions
  102. # First use column embeddings to calculate the initial hidden unit
  103. # Then run the LSTM and predict condition number.
  104. e_num_col, col_num = col_name_encode(col_inp_var, col_name_len, col_len, self.cond_num_name_enc)
  105. num_col_att_val = self.cond_num_col_att(e_num_col).squeeze()
  106. for idx, num in enumerate(col_num):
  107. if num < max(col_num):
  108. num_col_att_val[idx, num:] = -100
  109. num_col_att = self.softmax(num_col_att_val)
  110. K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
  111. cond_num_h1 = self.cond_num_col2hid1(K_num_col).view(
  112. B, 4, self.N_h/2).transpose(0, 1).contiguous()
  113. cond_num_h2 = self.cond_num_col2hid2(K_num_col).view(
  114. B, 4, self.N_h/2).transpose(0, 1).contiguous()
  115. h_num_enc, _ = run_lstm(self.cond_num_lstm, x_emb_var, x_len,
  116. hidden=(cond_num_h1, cond_num_h2))
  117. num_att_val = self.cond_num_att(h_num_enc).squeeze()
  118. for idx, num in enumerate(x_len):
  119. if num < max_x_len:
  120. num_att_val[idx, num:] = -100
  121. num_att = self.softmax(num_att_val)
  122. K_cond_num = (h_num_enc * num_att.unsqueeze(2).expand_as(h_num_enc)).sum(1)
  123. cond_num_score = self.cond_num_out(K_cond_num)
  124. #Predict the columns of conditions
  125. e_cond_col, _ = col_name_encode(col_inp_var, col_name_len, col_len, self.cond_col_name_enc)
  126. h_col_enc, _ = run_lstm(self.cond_col_lstm, x_emb_var, x_len)
  127. if self.use_ca:
  128. col_att_val = torch.bmm(e_cond_col,
  129. self.cond_col_att(h_col_enc).transpose(1, 2))
  130. for idx, num in enumerate(x_len):
  131. if num < max_x_len:
  132. col_att_val[idx, :, num:] = -100
  133. col_att = self.softmax(col_att_val.view(
  134. (-1, max_x_len))).view(B, -1, max_x_len)
  135. K_cond_col = (h_col_enc.unsqueeze(1) * col_att.unsqueeze(3)).sum(2)
  136. else:
  137. col_att_val = self.cond_col_att(h_col_enc).squeeze()
  138. for idx, num in enumerate(x_len):
  139. if num < max_x_len:
  140. col_att_val[idx, num:] = -100
  141. col_att = self.softmax(col_att_val)
  142. K_cond_col = (h_col_enc *
  143. col_att_val.unsqueeze(2)).sum(1).unsqueeze(1)
  144. cond_col_score = self.cond_col_out(self.cond_col_out_K(K_cond_col) +
  145. self.cond_col_out_col(e_cond_col)).squeeze()
  146. max_col_num = max(col_num)
  147. for b, num in enumerate(col_num):
  148. if num < max_col_num:
  149. cond_col_score[b, num:] = -100
  150. #Predict the operator of conditions
  151. chosen_col_gt = []
  152. if gt_cond is None:
  153. cond_nums = np.argmax(cond_num_score.data.cpu().numpy(), axis=1)
  154. col_scores = cond_col_score.data.cpu().numpy()
  155. chosen_col_gt = [list(np.argsort(-col_scores[b])[:cond_nums[b]])
  156. for b in range(len(cond_nums))]
  157. else:
  158. # print gt_cond
  159. chosen_col_gt = [[x[0] for x in one_gt_cond] for one_gt_cond in gt_cond]
  160. e_cond_col, _ = col_name_encode(col_inp_var, col_name_len,
  161. col_len, self.cond_op_name_enc)
  162. h_op_enc, _ = run_lstm(self.cond_op_lstm, x_emb_var, x_len)
  163. col_emb = []
  164. for b in range(B):
  165. cur_col_emb = torch.stack([e_cond_col[b, x]
  166. for x in chosen_col_gt[b]] + [e_cond_col[b, 0]] *
  167. (4 - len(chosen_col_gt[b]))) # Pad the columns to maximum (4)
  168. col_emb.append(cur_col_emb)
  169. col_emb = torch.stack(col_emb)
  170. if self.use_ca:
  171. op_att_val = torch.matmul(self.cond_op_att(h_op_enc).unsqueeze(1),
  172. col_emb.unsqueeze(3)).squeeze()
  173. for idx, num in enumerate(x_len):
  174. if num < max_x_len:
  175. op_att_val[idx, :, num:] = -100
  176. op_att = self.softmax(op_att_val.view(B*4, -1)).view(B, 4, -1)
  177. K_cond_op = (h_op_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)
  178. else:
  179. op_att_val = self.cond_op_att(h_op_enc).squeeze()
  180. for idx, num in enumerate(x_len):
  181. if num < max_x_len:
  182. op_att_val[idx, num:] = -100
  183. op_att = self.softmax(op_att_val)
  184. K_cond_op = (h_op_enc * op_att.unsqueeze(2)).sum(1).unsqueeze(1)
  185. cond_op_score = self.cond_op_out(self.cond_op_out_K(K_cond_op) +
  186. self.cond_op_out_col(col_emb)).squeeze()
  187. #Predict the string of conditions
  188. h_str_enc, _ = run_lstm(self.cond_str_lstm, x_emb_var, x_len)
  189. e_cond_col, _ = col_name_encode(col_inp_var, col_name_len,
  190. col_len, self.cond_str_name_enc)
  191. col_emb = []
  192. for b in range(B):
  193. cur_col_emb = torch.stack([e_cond_col[b, x] for x in chosen_col_gt[b]] +
  194. [e_cond_col[b, 0]] * (4 - len(chosen_col_gt[b])))
  195. col_emb.append(cur_col_emb)
  196. col_emb = torch.stack(col_emb)
  197. if gt_where is not None:
  198. gt_tok_seq, gt_tok_len = self.gen_gt_batch(gt_where)
  199. g_str_s_flat, _ = self.cond_str_decoder(
  200. gt_tok_seq.view(B*4, -1, self.max_tok_num))
  201. g_str_s = g_str_s_flat.contiguous().view(B, 4, -1, self.N_h)
  202. h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)
  203. g_ext = g_str_s.unsqueeze(3)
  204. col_ext = col_emb.unsqueeze(2).unsqueeze(2)
  205. cond_str_score = self.cond_str_out(
  206. self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext) +
  207. self.cond_str_out_col(col_ext)).squeeze()
  208. for b, num in enumerate(x_len):
  209. if num < max_x_len:
  210. cond_str_score[b, :, :, num:] = -100
  211. else:
  212. h_ext = h_str_enc.unsqueeze(1).unsqueeze(1)
  213. col_ext = col_emb.unsqueeze(2).unsqueeze(2)
  214. scores = []
  215. t = 0
  216. init_inp = np.zeros((B*4, 1, self.max_tok_num), dtype=np.float32)
  217. init_inp[:,0,0] = 1 #Set the <BEG> token
  218. if self.gpu:
  219. cur_inp = Variable(torch.from_numpy(init_inp).cuda())
  220. else:
  221. cur_inp = Variable(torch.from_numpy(init_inp))
  222. cur_h = None
  223. while t < 50:
  224. if cur_h:
  225. g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp, cur_h)
  226. else:
  227. g_str_s_flat, cur_h = self.cond_str_decoder(cur_inp)
  228. g_str_s = g_str_s_flat.view(B, 4, 1, self.N_h)
  229. g_ext = g_str_s.unsqueeze(3)
  230. cur_cond_str_score = self.cond_str_out(
  231. self.cond_str_out_h(h_ext) + self.cond_str_out_g(g_ext)
  232. + self.cond_str_out_col(col_ext)).squeeze()
  233. for b, num in enumerate(x_len):
  234. if num < max_x_len:
  235. cur_cond_str_score[b, :, num:] = -100
  236. scores.append(cur_cond_str_score)
  237. _, ans_tok_var = cur_cond_str_score.view(B*4, max_x_len).max(1)
  238. ans_tok = ans_tok_var.data.cpu()
  239. data = torch.zeros(B*4, self.max_tok_num).scatter_(
  240. 1, ans_tok.unsqueeze(1), 1)
  241. if self.gpu: #To one-hot
  242. cur_inp = Variable(data.cuda())
  243. else:
  244. cur_inp = Variable(data)
  245. cur_inp = cur_inp.unsqueeze(1)
  246. t += 1
  247. cond_str_score = torch.stack(scores, 2)
  248. for b, num in enumerate(x_len):
  249. if num < max_x_len:
  250. cond_str_score[b, :, :, num:] = -100 #[B, IDX, T, TOK_NUM]
  251. cond_score = (cond_num_score,
  252. cond_col_score, cond_op_score, cond_str_score)
  253. return cond_score