selection_predict.py (2.2 KB)

  1. import json
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from torch.autograd import Variable
  6. import numpy as np
  7. from net_utils import run_lstm, col_name_encode
  class SelPredictor(nn.Module):
      """Score every table column as the SELECT target of a question.

      The question tokens and the column names are each encoded with a
      bidirectional LSTM. Each column then gets an attention-weighted summary
      of the question encoding, which is combined with the column's own
      encoding to produce one logit per column.
      """

      def __init__(self, N_word, N_h, N_depth, max_tok_num, use_ca):
          """Build the encoders and scoring layers.

          Args:
              N_word: size of each input word embedding.
              N_h: encoder output width. Each LSTM direction gets N_h // 2
                  units, so the bidirectional output is N_h wide (assumes
                  N_h is even -- TODO confirm with callers).
              N_depth: number of stacked LSTM layers.
              max_tok_num: stored on the instance; not used in this class.
              use_ca: if True, every column computes its own attention over
                  the question ("column attention"); otherwise a single
                  attention distribution over the question is shared by all
                  columns.
          """
          super(SelPredictor, self).__init__()
          self.use_ca = use_ca
          self.max_tok_num = max_tok_num
          # Integer division: nn.LSTM needs an int hidden size, and N_h / 2 is
          # a float under Python 3 (identical result under Python 2).
          self.sel_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h // 2,
                                  num_layers=N_depth, batch_first=True,
                                  dropout=0.3, bidirectional=True)
          if use_ca:
              print("Using column attention on selection predicting")
              # Projects question states so each column can dot-product them.
              self.sel_att = nn.Linear(N_h, N_h)
          else:
              print("Not using column attention on selection predicting")
              # One attention score per question token, shared by all columns.
              self.sel_att = nn.Linear(N_h, 1)
          self.sel_col_name_enc = nn.LSTM(input_size=N_word,
                                          hidden_size=N_h // 2,
                                          num_layers=N_depth, batch_first=True,
                                          dropout=0.3, bidirectional=True)
          self.sel_out_K = nn.Linear(N_h, N_h)
          self.sel_out_col = nn.Linear(N_h, N_h)
          self.sel_out = nn.Sequential(nn.Tanh(), nn.Linear(N_h, 1))
          self.softmax = nn.Softmax(dim=-1)

      def forward(self, x_emb_var, x_len, col_inp_var,
                  col_name_len, col_len, col_num):
          """Return one SELECT logit per column.

          Args:
              x_emb_var: embedded question tokens, passed to ``run_lstm``
                  (presumably [B, max_x_len, N_word] -- confirm with caller).
              x_len: per-example question lengths. Attention logits at
                  positions >= an example's length are set to -100 before
                  the softmax.
              col_inp_var, col_name_len, col_len: column-name inputs passed
                  straight to ``col_name_encode``.
              col_num: per-example number of real columns. Scores at
                  positions >= that number are set to -100.

          Returns:
              Tensor of column logits, [B, max_col_num], assuming
              ``col_name_encode`` yields [B, max_col_num, N_h].
          """
          # Based on number of selections to predict select-column
          B = len(x_emb_var)
          max_x_len = max(x_len)
          e_col, _ = col_name_encode(col_inp_var, col_name_len, col_len,
                                     self.sel_col_name_enc)  # [bs, col_num, hid]
          h_enc, _ = run_lstm(self.sel_lstm, x_emb_var, x_len)  # [bs, seq_len, hid]
          tok_att = self.sel_att(h_enc).transpose(1, 2)
          if self.use_ca:
              # Column attention: every column scores every question token.
              att_val = torch.bmm(e_col, tok_att)  # [bs, col_num, seq_len]
          else:
              # tok_att is [bs, 1, seq_len] here, so bmm with e_col would be a
              # shape error. Share the one distribution across all columns;
              # repeat() materializes a copy, so the in-place mask is safe.
              att_val = tok_att.repeat(1, e_col.size(1), 1)  # [bs, col_num, seq_len]
          for idx, num in enumerate(x_len):
              if num < max_x_len:
                  att_val[idx, :, num:] = -100
          att = self.softmax(att_val.view((-1, max_x_len))).view(B, -1, max_x_len)
          # Attention-weighted sum of question states for each column.
          K_sel_expand = (h_enc.unsqueeze(1) * att.unsqueeze(3)).sum(2)  # [bs, col_num, hid]
          # squeeze(-1), not squeeze(): a bare squeeze also drops the batch
          # axis when B == 1 and the column axis when there is one column.
          sel_score = self.sel_out(self.sel_out_K(K_sel_expand)
                                   + self.sel_out_col(e_col)).squeeze(-1)
          max_col_num = max(col_num)
          for idx, num in enumerate(col_num):
              if num < max_col_num:
                  sel_score[idx, num:] = -100
          return sel_score