seq2sql_condition_predict.py
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from net_utils import run_lstm

class Seq2SQLCondPredictor(nn.Module):
    def __init__(self, N_word, N_h, N_depth, max_col_num, max_tok_num, gpu):
        super(Seq2SQLCondPredictor, self).__init__()
        print("Seq2SQL where prediction")
        self.N_h = N_h
        self.max_tok_num = max_tok_num
        self.max_col_num = max_col_num
        self.gpu = gpu

        # Bidirectional question encoder; the two directions of size
        # N_h // 2 concatenate to an N_h-dimensional state.
        self.cond_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h // 2,
                                 num_layers=N_depth, batch_first=True,
                                 dropout=0.3, bidirectional=True)
        # Decoder consumes one-hot token vectors of size max_tok_num.
        self.cond_decoder = nn.LSTM(input_size=self.max_tok_num,
                                    hidden_size=N_h, num_layers=N_depth,
                                    batch_first=True, dropout=0.3)
        self.cond_out_g = nn.Linear(N_h, N_h)
        self.cond_out_h = nn.Linear(N_h, N_h)
        self.cond_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 1))
        self.softmax = nn.Softmax()
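
    # The three cond_out_* layers above form a pointer-style scoring head:
    # for decoder state g_t and encoder state h_i,
    #     score(t, i) = w^T tanh(W_h h_i + W_g g_t),
    # computed in forward() as cond_out(cond_out_h(h_enc) + cond_out_g(g_s)).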

    def gen_gt_batch(self, tok_seq, gen_inp=True):
        # If gen_inp: generate the input token sequence (removing <END>).
        # Otherwise: generate the output token sequence (removing <BEG>).
        B = len(tok_seq)
        ret_len = np.array([len(one_tok_seq) - 1 for one_tok_seq in tok_seq])
        max_len = max(ret_len)
        ret_array = np.zeros((B, max_len, self.max_tok_num), dtype=np.float32)
        for b, one_tok_seq in enumerate(tok_seq):
            out_one_tok_seq = one_tok_seq[:-1] if gen_inp else one_tok_seq[1:]
            for t, tok_id in enumerate(out_one_tok_seq):
                ret_array[b, t, tok_id] = 1
        ret_inp = torch.from_numpy(ret_array)
        if self.gpu:
            ret_inp = ret_inp.cuda()
        ret_inp_var = Variable(ret_inp)  # [B, max_len, max_tok_num]
        return ret_inp_var, ret_len
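
    # Worked example (illustrative, not from the original source): with
    # <BEG> id 7 and <END> id 1, tok_seq = [[7, 23, 5, 1]] and gen_inp=True
    # one-hot encodes [7, 23, 5] (the decoder input) into a
    # (1, 3, max_tok_num) array, while gen_inp=False encodes [23, 5, 1]
    # (the decoder target); ret_len is [3] in both cases.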

    def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len,
                col_num, gt_where, gt_cond, reinforce):
        max_x_len = max(x_len)
        B = len(x_len)

        h_enc, hidden = run_lstm(self.cond_lstm, x_emb_var, x_len)
        # Merge the bidirectional encoder state into the shape the
        # unidirectional decoder expects (hard-coded for N_depth == 2).
        decoder_hidden = tuple(torch.cat((hid[:2], hid[2:]), dim=2)
                               for hid in hidden)
        if gt_where is not None:
            # Training path: teacher-force the ground-truth WHERE tokens and
            # score every decoder step against every question position.
            gt_tok_seq, gt_tok_len = self.gen_gt_batch(gt_where, gen_inp=True)
            g_s, _ = run_lstm(self.cond_decoder,
                              gt_tok_seq, gt_tok_len, decoder_hidden)

            h_enc_expand = h_enc.unsqueeze(1)
            g_s_expand = g_s.unsqueeze(2)
            cond_score = self.cond_out(self.cond_out_h(h_enc_expand) +
                                       self.cond_out_g(g_s_expand)).squeeze()
            # Mask out scores for padding positions beyond each question's
            # true length.
            for idx, num in enumerate(x_len):
                if num < max_x_len:
                    cond_score[idx, :, num:] = -100
        else:
            # Inference path: decode step by step, feeding each predicted
            # token back in as a one-hot vector.
            h_enc_expand = h_enc.unsqueeze(1)
            scores = []
            choices = []
            done_set = set()

            t = 0
            init_inp = np.zeros((B, 1, self.max_tok_num), dtype=np.float32)
            init_inp[:, 0, 7] = 1  # Set the <BEG> token
            if self.gpu:
                cur_inp = Variable(torch.from_numpy(init_inp).cuda())
            else:
                cur_inp = Variable(torch.from_numpy(init_inp))
            cur_h = decoder_hidden
            # Decode greedily (or by sampling, under REINFORCE) until every
            # sequence in the batch has emitted <END>, for at most 100 steps.
            while len(done_set) < B and t < 100:
                g_s, cur_h = self.cond_decoder(cur_inp, cur_h)
                g_s_expand = g_s.unsqueeze(2)
                cur_cond_score = self.cond_out(
                    self.cond_out_h(h_enc_expand) +
                    self.cond_out_g(g_s_expand)).squeeze()
                for b, num in enumerate(x_len):
                    if num < max_x_len:
                        cur_cond_score[b, num:] = -100
                scores.append(cur_cond_score)

                if not reinforce:
                    _, ans_tok_var = cur_cond_score.view(B, max_x_len).max(1)
                    ans_tok_var = ans_tok_var.unsqueeze(1)
                else:
                    ans_tok_var = self.softmax(cur_cond_score).multinomial()
                    choices.append(ans_tok_var)
                ans_tok = ans_tok_var.data.cpu()
                if self.gpu:  # To one-hot
                    cur_inp = Variable(torch.zeros(
                        B, self.max_tok_num).scatter_(1, ans_tok, 1).cuda())
                else:
                    cur_inp = Variable(torch.zeros(
                        B, self.max_tok_num).scatter_(1, ans_tok, 1))
                cur_inp = cur_inp.unsqueeze(1)

                for idx, tok in enumerate(ans_tok.squeeze()):
                    if tok == 1:  # Found the <END> token
                        done_set.add(idx)
                t += 1
            cond_score = torch.stack(scores, 1)

        if reinforce:
            return cond_score, choices
        else:
            return cond_score
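

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original file). It assumes
# net_utils.run_lstm from this repository is importable; the tiny
# hyperparameters and token ids below are made up for illustration (7 and 1
# match the <BEG>/<END> ids hard-coded in forward()).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    B, N_word, N_h, N_depth, max_tok_num = 2, 16, 8, 2, 40
    model = Seq2SQLCondPredictor(N_word, N_h, N_depth, max_col_num=10,
                                 max_tok_num=max_tok_num, gpu=False)
    x_len = np.array([5, 4])                      # question lengths
    x_emb_var = Variable(torch.randn(B, int(max(x_len)), N_word))
    gt_where = [[7, 2, 3, 1], [7, 5, 1]]          # <BEG> ... <END> sequences
    # Teacher-forced scoring path (gt_where provided, so no sampling).
    cond_score = model(x_emb_var, x_len, None, None, None, None,
                       gt_where, None, reinforce=False)
    print(cond_score.size())  # expect (B, max_gt_len - 1, max_x_len)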