# word_embedding.py
  1. import json
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from torch.autograd import Variable
  6. import numpy as np
  7. class WordEmbedding(nn.Module):
  8. def __init__(self, word_emb, N_word, gpu, SQL_TOK, our_model, trainable=False):
  9. super(WordEmbedding, self).__init__()
  10. self.trainable = trainable
  11. self.N_word = N_word
  12. self.our_model = our_model
  13. self.gpu = gpu
  14. self.SQL_TOK = SQL_TOK
  15. if trainable:
  16. print "Using trainable embedding"
  17. self.w2i, word_emb_val = word_emb
  18. self.embedding = nn.Embedding(len(self.w2i), N_word)
  19. self.embedding.weight = nn.Parameter(
  20. torch.from_numpy(word_emb_val.astype(np.float32)))
  21. else:
  22. self.word_emb = word_emb
  23. print "Using fixed embedding"
  24. def gen_x_batch(self, q, col):
  25. B = len(q)
  26. val_embs = []
  27. val_len = np.zeros(B, dtype=np.int64)
  28. for i, (one_q, one_col) in enumerate(zip(q, col)):
  29. if self.trainable:
  30. q_val = map(lambda x:self.w2i.get(x, 0), one_q)
  31. else:
  32. q_val = map(lambda x:self.word_emb.get(x, np.zeros(self.N_word, dtype=np.float32)), one_q)
  33. if self.our_model:
  34. if self.trainable:
  35. val_embs.append([1] + q_val + [2]) #<BEG> and <END>
  36. else:
  37. val_embs.append([np.zeros(self.N_word, dtype=np.float32)] + q_val + [np.zeros(self.N_word, dtype=np.float32)]) #<BEG> and <END>
  38. val_len[i] = 1 + len(q_val) + 1
  39. else:
  40. one_col_all = [x for toks in one_col for x in toks+[',']]
  41. if self.trainable:
  42. col_val = map(lambda x:self.w2i.get(x, 0), one_col_all)
  43. val_embs.append( [0 for _ in self.SQL_TOK] + col_val + [0] + q_val+ [0])
  44. else:
  45. col_val = map(lambda x:self.word_emb.get(x, np.zeros(self.N_word, dtype=np.float32)), one_col_all)
  46. val_embs.append( [np.zeros(self.N_word, dtype=np.float32) for _ in self.SQL_TOK] + col_val + [np.zeros(self.N_word, dtype=np.float32)] + q_val+ [np.zeros(self.N_word, dtype=np.float32)])
  47. val_len[i] = len(self.SQL_TOK) + len(col_val) + 1 + len(q_val) + 1
  48. max_len = max(val_len)
  49. if self.trainable:
  50. val_tok_array = np.zeros((B, max_len), dtype=np.int64)
  51. for i in range(B):
  52. for t in range(len(val_embs[i])):
  53. val_tok_array[i,t] = val_embs[i][t]
  54. val_tok = torch.from_numpy(val_tok_array)
  55. if self.gpu:
  56. val_tok = val_tok.cuda()
  57. val_tok_var = Variable(val_tok)
  58. val_inp_var = self.embedding(val_tok_var)
  59. else:
  60. val_emb_array = np.zeros((B, max_len, self.N_word), dtype=np.float32)
  61. for i in range(B):
  62. for t in range(len(val_embs[i])):
  63. val_emb_array[i,t,:] = val_embs[i][t]
  64. val_inp = torch.from_numpy(val_emb_array)
  65. if self.gpu:
  66. val_inp = val_inp.cuda()
  67. val_inp_var = Variable(val_inp)
  68. return val_inp_var, val_len
  69. def gen_col_batch(self, cols):
  70. ret = []
  71. col_len = np.zeros(len(cols), dtype=np.int64)
  72. names = []
  73. for b, one_cols in enumerate(cols):
  74. names = names + one_cols
  75. col_len[b] = len(one_cols)
  76. name_inp_var, name_len = self.str_list_to_batch(names)
  77. return name_inp_var, name_len, col_len
  78. def str_list_to_batch(self, str_list):
  79. B = len(str_list)
  80. val_embs = []
  81. val_len = np.zeros(B, dtype=np.int64)
  82. for i, one_str in enumerate(str_list):
  83. if self.trainable:
  84. val = [self.w2i.get(x, 0) for x in one_str]
  85. else:
  86. val = [self.word_emb.get(x, np.zeros(
  87. self.N_word, dtype=np.float32)) for x in one_str]
  88. val_embs.append(val)
  89. val_len[i] = len(val)
  90. max_len = max(val_len)
  91. if self.trainable:
  92. val_tok_array = np.zeros((B, max_len), dtype=np.int64)
  93. for i in range(B):
  94. for t in range(len(val_embs[i])):
  95. val_tok_array[i,t] = val_embs[i][t]
  96. val_tok = torch.from_numpy(val_tok_array)
  97. if self.gpu:
  98. val_tok = val_tok.cuda()
  99. val_tok_var = Variable(val_tok)
  100. val_inp_var = self.embedding(val_tok_var)
  101. else:
  102. val_emb_array = np.zeros(
  103. (B, max_len, self.N_word), dtype=np.float32)
  104. for i in range(B):
  105. for t in range(len(val_embs[i])):
  106. val_emb_array[i,t,:] = val_embs[i][t]
  107. val_inp = torch.from_numpy(val_emb_array)
  108. if self.gpu:
  109. val_inp = val_inp.cuda()
  110. val_inp_var = Variable(val_inp)
  111. return val_inp_var, val_len