where_relation.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import json
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from torch.autograd import Variable
  6. import numpy as np
  7. from net_utils import run_lstm, col_name_encode
  8. class WhereRelationPredictor(nn.Module):
  9. def __init__(self, N_word, N_h, N_depth, use_ca):
  10. super(WhereRelationPredictor, self).__init__()
  11. self.N_h = N_h
  12. self.use_ca = use_ca
  13. self.where_rela_lstm = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
  14. num_layers=N_depth, batch_first=True,
  15. dropout=0.3, bidirectional=True)
  16. self.where_rela_att = nn.Linear(N_h, 1)
  17. self.where_rela_col_att = nn.Linear(N_h, 1)
  18. self.where_rela_out = nn.Sequential(nn.Linear(N_h, N_h), nn.Tanh(), nn.Linear(N_h,3))
  19. self.softmax = nn.Softmax(dim=-1)
  20. self.col2hid1 = nn.Linear(N_h, 2 * N_h)
  21. self.col2hid2 = nn.Linear(N_h, 2 * N_h)
  22. if self.use_ca:
  23. print "Using column attention on where relation predicting"
  24. def forward(self, x_emb_var, x_len, col_inp_var, col_name_len, col_len, col_num):
  25. B = len(x_len)
  26. max_x_len = max(x_len)
  27. # Predict the condition relationship part
  28. # First use column embeddings to calculate the initial hidden unit
  29. # Then run the LSTM and predict select number
  30. e_num_col, col_num = col_name_encode(col_inp_var, col_name_len,
  31. col_len, self.where_rela_lstm)
  32. col_att_val = self.where_rela_col_att(e_num_col).squeeze()
  33. for idx, num in enumerate(col_num):
  34. if num < max(col_num):
  35. col_att_val[idx, num:] = -1000000
  36. num_col_att = self.softmax(col_att_val)
  37. K_num_col = (e_num_col * num_col_att.unsqueeze(2)).sum(1)
  38. h1 = self.col2hid1(K_num_col).view(B, 4, self.N_h/2).transpose(0,1).contiguous()
  39. h2 = self.col2hid2(K_num_col).view(B, 4, self.N_h/2).transpose(0,1).contiguous()
  40. h_enc, _ = run_lstm(self.where_rela_lstm, x_emb_var, x_len, hidden=(h1, h2))
  41. att_val = self.where_rela_att(h_enc).squeeze()
  42. for idx, num in enumerate(x_len):
  43. if num < max_x_len:
  44. att_val[idx, num:] = -1000000
  45. att_val = self.softmax(att_val)
  46. where_rela_num = (h_enc * att_val.unsqueeze(2).expand_as(h_enc)).sum(1)
  47. where_rela_score = self.where_rela_out(where_rela_num)
  48. return where_rela_score