compose_poem.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import tensorflow as tf
  2. from poem.model import rnn_model
  3. from poem.poem import process_poem
  4. import numpy as np
  5. start_token = 'B'
  6. end_token = 'E'
  7. model_dir = './model/'
  8. corpus_file = './data/poem.txt'
  9. lr = 0.0002
  10. def to_word(predict, vocabs):
  11. t = np.cumsum(predict)
  12. s = np.sum(predict)
  13. sample = int(np.searchsorted(t, np.random.rand(1) * s))
  14. if sample > len(vocabs):
  15. sample = len(vocabs) - 1
  16. return vocabs[sample]
  17. def gen_poem(begin_word):
  18. batch_size = 1
  19. print('## loading corpus from %s' % model_dir)
  20. poem_vector, word_int_map, vocabularies = process_poem(corpus_file)
  21. input_data = tf.placeholder(tf.int32, [batch_size, None])
  22. end_points = rnn_model(model='lstm', input_data=input_data, output_data=None, vocab_size=len(
  23. vocabularies), rnn_size=128, num_layers=2, batch_size=64, learning_rate=lr)
  24. saver = tf.train.Saver(tf.global_variables())
  25. init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
  26. with tf.Session() as sess:
  27. sess.run(init_op)
  28. checkpoint = tf.train.latest_checkpoint(model_dir)
  29. saver.restore(sess, checkpoint)
  30. x = np.array([list(map(word_int_map.get, start_token))])
  31. [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
  32. feed_dict={input_data: x})
  33. if begin_word:
  34. word = begin_word
  35. else:
  36. word = to_word(predict, vocabularies)
  37. poem_ = ''
  38. i = 0
  39. while word != end_token:
  40. poem_ += word
  41. i += 1
  42. if i >= 24:
  43. break
  44. x = np.zeros((1, 1))
  45. x[0, 0] = word_int_map[word]
  46. [predict, last_state] = sess.run([end_points['prediction'], end_points['last_state']],
  47. feed_dict={input_data: x, end_points['initial_state']: last_state})
  48. word = to_word(predict, vocabularies)
  49. return poem_
  50. def pretty_print_poem(poem_):
  51. poem_sentences = poem_.split('。')
  52. for s in poem_sentences:
  53. if s != '' and len(s) > 10:
  54. print(s + '。')
  55. if __name__ == '__main__':
  56. begin_char = input('## please input the first character:')
  57. poem = gen_poem(begin_char)
  58. pretty_print_poem(poem_=poem)