model.py

import tensorflow as tf


def rnn_model(model, input_data, output_data, vocab_size, rnn_size=128, num_layers=2,
              batch_size=64, learning_rate=0.01):
    """
    Construct an RNN sequence model for character-level generation.

    :param model: cell type, one of 'rnn', 'gru' or 'lstm'
    :param input_data: input data placeholder, shape [batch_size, time_steps]
    :param output_data: target data placeholder (same shape as input_data), or None for inference
    :param vocab_size: size of the vocabulary
    :param rnn_size: number of hidden units per RNN cell
    :param num_layers: number of stacked RNN layers
    :param batch_size: batch size used during training
    :param learning_rate: learning rate for the Adam optimizer
    :return: dict of graph endpoints (initial_state, last_state, train_op, loss, ...)
    """
    end_points = {}

    if model == 'rnn':
        cell_fun = tf.contrib.rnn.BasicRNNCell
    elif model == 'gru':
        cell_fun = tf.contrib.rnn.GRUCell
    elif model == 'lstm':
        cell_fun = tf.contrib.rnn.BasicLSTMCell
    else:
        raise ValueError("model must be one of 'rnn', 'gru' or 'lstm', got %r" % model)

    def make_cell():
        # only BasicLSTMCell accepts the state_is_tuple argument
        if model == 'lstm':
            return cell_fun(rnn_size, state_is_tuple=True)
        return cell_fun(rnn_size)

    # build one cell object per layer; reusing a single cell instance across layers
    # is rejected by MultiRNNCell in recent TF 1.x releases
    cell = tf.contrib.rnn.MultiRNNCell([make_cell() for _ in range(num_layers)],
                                       state_is_tuple=True)

    # training uses the full batch; generation feeds one sequence at a time
    if output_data is not None:
        initial_state = cell.zero_state(batch_size, tf.float32)
    else:
        initial_state = cell.zero_state(1, tf.float32)

    # keep the embedding lookup on the CPU
    with tf.device("/cpu:0"):
        embedding = tf.get_variable('embedding', initializer=tf.random_uniform(
            [vocab_size + 1, rnn_size], -1.0, 1.0))
        inputs = tf.nn.embedding_lookup(embedding, input_data)

    # outputs: [batch_size, time_steps, rnn_size] = [64, ?, 128]
    outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state)
    output = tf.reshape(outputs, [-1, rnn_size])

    # project RNN outputs to vocabulary logits: [?, vocab_size + 1]
    weights = tf.Variable(tf.truncated_normal([rnn_size, vocab_size + 1]))
    bias = tf.Variable(tf.zeros(shape=[vocab_size + 1]))
    logits = tf.nn.bias_add(tf.matmul(output, weights), bias=bias)

    if output_data is not None:
        # targets are converted to one-hot vectors: [?, vocab_size + 1]
        labels = tf.one_hot(tf.reshape(output_data, [-1]), depth=vocab_size + 1)
        # per-timestep cross-entropy; loss shape is [?]
        loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
        total_loss = tf.reduce_mean(loss)
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)

        end_points['initial_state'] = initial_state
        end_points['output'] = output
        end_points['train_op'] = train_op
        end_points['total_loss'] = total_loss
        end_points['loss'] = loss
        end_points['last_state'] = last_state
    else:
        prediction = tf.nn.softmax(logits)

        end_points['initial_state'] = initial_state
        end_points['last_state'] = last_state
        end_points['prediction'] = prediction

    return end_points
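

# Hypothetical usage sketch, not part of the original file: a minimal example of how
# rnn_model might be wired into a TF 1.x graph and run for one training step. The
# vocab_size, sequence length and random feed data below are assumptions for
# illustration only, not values from the project.
if __name__ == '__main__':
    import numpy as np

    batch_size, time_steps, vocab_size = 64, 20, 6000
    input_data = tf.placeholder(tf.int32, [batch_size, None])
    output_data = tf.placeholder(tf.int32, [batch_size, None])
    end_points = rnn_model('lstm', input_data, output_data, vocab_size,
                           rnn_size=128, num_layers=2, batch_size=batch_size)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # one step on random token ids, just to show the feed/fetch pattern
        x = np.random.randint(0, vocab_size, (batch_size, time_steps), dtype=np.int32)
        y = np.random.randint(0, vocab_size, (batch_size, time_steps), dtype=np.int32)
        _, loss_val = sess.run([end_points['train_op'], end_points['total_loss']],
                               feed_dict={input_data: x, output_data: y})
        print('training loss after one step: %.4f' % loss_val)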