net_utils.py

import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Variable


def run_lstm(lstm, inp, inp_len, hidden=None):
    # Run the LSTM on a padded batch using packed sequences.
    # pack_padded_sequence requires the batch to be sorted by decreasing length,
    # so sort first and undo the permutation on the outputs afterwards.
    sort_perm = np.array(sorted(range(len(inp_len)),
                                key=lambda k: inp_len[k], reverse=True))
    sort_inp_len = inp_len[sort_perm]
    sort_perm_inv = np.argsort(sort_perm)
    if inp.is_cuda:
        sort_perm = torch.LongTensor(sort_perm).cuda()
        sort_perm_inv = torch.LongTensor(sort_perm_inv).cuda()

    lstm_inp = nn.utils.rnn.pack_padded_sequence(inp[sort_perm],
                                                 sort_inp_len, batch_first=True)
    if hidden is None:
        lstm_hidden = None
    else:
        # Reorder the initial hidden/cell states to match the sorted batch.
        lstm_hidden = (hidden[0][:, sort_perm], hidden[1][:, sort_perm])

    sort_ret_s, sort_ret_h = lstm(lstm_inp, lstm_hidden)
    # Unpack the outputs and restore the original batch order.
    ret_s = nn.utils.rnn.pad_packed_sequence(
        sort_ret_s, batch_first=True)[0][sort_perm_inv]
    ret_h = (sort_ret_h[0][:, sort_perm_inv], sort_ret_h[1][:, sort_perm_inv])
    return ret_s, ret_h
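

# A minimal usage sketch for run_lstm. The layer sizes, batch shape, and
# sequence lengths below are illustrative assumptions, not part of the
# original code:
#
#   lstm = nn.LSTM(input_size=300, hidden_size=100, num_layers=1,
#                  batch_first=True, bidirectional=True)
#   inp = Variable(torch.randn(4, 12, 300))   # (batch, max_len, emb_dim)
#   inp_len = np.array([12, 7, 9, 3])         # true length of each sequence
#   out, hidden = run_lstm(lstm, inp, inp_len)
#   # out: (4, 12, 200), rows restored to the original batch order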


def col_name_encode(name_inp_var, name_len, col_len, enc_lstm):
    # Encode the column names.
    # The embedding of a column name is the hidden state at the last
    # (non-padded) time step of its LSTM output.
    name_hidden, _ = run_lstm(enc_lstm, name_inp_var, name_len)
    name_out = name_hidden[tuple(range(len(name_len))), name_len - 1]
    ret = torch.FloatTensor(
        len(col_len), max(col_len), name_out.size()[1]).zero_()
    if name_out.is_cuda:
        ret = ret.cuda()

    # Regroup the flat list of column-name embeddings into one padded
    # (batch, max_num_cols, hidden) tensor, one row per example.
    st = 0
    for idx, cur_len in enumerate(col_len):
        ret[idx, :cur_len] = name_out.data[st:st + cur_len]
        st += cur_len
    ret_var = Variable(ret)
    return ret_var, col_len
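

# A minimal end-to-end sketch of col_name_encode. The sizes below are
# assumptions for illustration only: name_inp_var stacks every column name
# of the batch along dim 0, name_len gives each name's token count, and
# col_len records how many columns each example has.
if __name__ == '__main__':
    enc_lstm = nn.LSTM(input_size=300, hidden_size=100, num_layers=1,
                       batch_first=True, bidirectional=True)
    # 5 column names in total, padded to 4 tokens each, 300-dim embeddings.
    name_inp_var = Variable(torch.randn(5, 4, 300))
    name_len = np.array([2, 4, 1, 3, 2])   # token count of each column name
    col_len = np.array([3, 2])             # example 0 has 3 columns, example 1 has 2
    col_emb, col_len = col_name_encode(name_inp_var, name_len, col_len, enc_lstm)
    # col_emb: (2, 3, 200) -- one row per example, padded to max(col_len) columns.
    print(col_emb.size())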