train.py 3.0 KB

import os

import numpy as np
import tensorflow as tf

from poem.model import rnn_model
from poem.poem import process_poem, generate_batch

tf.app.flags.DEFINE_integer('batch_size', 64, 'batch size.')
tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate.')
tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./model'), 'model save path.')
tf.app.flags.DEFINE_string('file_path', os.path.abspath('./data/poem.txt'), 'file name of poem.')
tf.app.flags.DEFINE_string('model_prefix', 'poem', 'model save prefix.')
tf.app.flags.DEFINE_integer('epochs', 50, 'train how many epochs.')

FLAGS = tf.app.flags.FLAGS
def run_training():
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)

    poem_vector, word_to_int, vocabularies = process_poem(FLAGS.file_path)
    batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poem_vector, word_to_int)

    # Both tensors have shape [batch_size, time]; the time dimension is left
    # dynamic (None) because poem length varies between batches.
    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])

    # Pass FLAGS.batch_size instead of a hardcoded 64, so the model stays in
    # sync with the placeholders when the flag is overridden.
    end_points = rnn_model(model='lstm', input_data=input_data, output_data=output_targets,
                           vocab_size=len(vocabularies), rnn_size=128, num_layers=2,
                           batch_size=FLAGS.batch_size, learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    with tf.Session() as sess:
        # sess = tf_debug.LocalCLIDebugWrapperSession(sess=sess)
        # sess.add_tensor_filter("has_inf_or_nan", tf_debug.has_inf_or_nan)
        sess.run(init_op)

        start_epoch = 0
        # Resume from the latest checkpoint if one exists; the epoch number
        # is recovered from the global_step suffix of the checkpoint name.
        checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("## restore from the checkpoint {0}".format(checkpoint))
            start_epoch += int(checkpoint.split('-')[-1])
        print('## start training...')
        try:
            for epoch in range(start_epoch, FLAGS.epochs):
                n = 0
                n_chunk = len(poem_vector) // FLAGS.batch_size
                for batch in range(n_chunk):
                    loss, _, _ = sess.run([
                        end_points['total_loss'],
                        end_points['last_state'],
                        end_points['train_op']
                    ], feed_dict={input_data: batches_inputs[n], output_targets: batches_outputs[n]})
                    n += 1
                    print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
                # Checkpoint every 6 epochs.
                if epoch % 6 == 0:
                    saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
        except KeyboardInterrupt:
            print('## Interrupted manually; trying to save a checkpoint now...')
            saver.save(sess, os.path.join(FLAGS.model_dir, FLAGS.model_prefix), global_step=epoch)
            print('## Last epoch was saved; the next run will start from epoch {}.'.format(epoch))
def main(_):
    run_training()


if __name__ == '__main__':
    tf.app.run()
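
For reference, a minimal sketch of what generate_batch is assumed to do, inferred from its call site in run_training (the real implementation lives in poem/poem.py; the padding token and exact batch layout are assumptions):

def generate_batch(batch_size, poem_vec, word_to_int):
    # Drop the partial trailing batch, matching n_chunk in the training loop.
    n_chunk = len(poem_vec) // batch_size
    x_batches, y_batches = [], []
    for i in range(n_chunk):
        batch = poem_vec[i * batch_size:(i + 1) * batch_size]
        max_len = max(len(p) for p in batch)
        # Pad every poem in the batch to a common length (assumed pad token: space).
        x = np.full((batch_size, max_len), word_to_int[' '], dtype=np.int32)
        for row, poem in enumerate(batch):
            x[row, :len(poem)] = poem
        # Targets are the inputs shifted left by one step, so the model learns
        # to predict the next character at each position.
        y = np.copy(x)
        y[:, :-1] = x[:, 1:]
        x_batches.append(x)
        y_batches.append(y)
    return x_batches, y_batches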