Browse Source

添加备注

liuyuqi-dellpc 7 years ago
parent
commit
b2bc569ea8
2 changed files with 5 additions and 5 deletions
  1. 4 4
      src/poem/poem.py
  2. 1 1
      src/train.py

+ 4 - 4
src/poem/poem.py

@@ -9,7 +9,7 @@ end_token = 'E'
 
 def process_poem(file_name):
     # 诗集
-    poem = []
+    poems = []
     with open(file_name, "r", encoding='utf-8', ) as f:
         for line in f.readlines():
             try:
@@ -21,13 +21,13 @@ def process_poem(file_name):
                 if len(content) < 5 or len(content) > 79:
                     continue
                 content = start_token + content + end_token
-                poem.append(content)
+                poems.append(content)
             except ValueError as e:
                 pass
-    poem = sorted(poem, key=lambda l: len(line))
+    poems = sorted(poems, key=lambda l: len(line))
 
     all_words = []
-    for poem in poem:
+    for poem in poems:
         all_words += [word for word in poem]
     counter = collections.Counter(all_words)
     count_pairs = sorted(counter.items(), key=lambda x: -x[1])

+ 1 - 1
src/train.py

@@ -1,9 +1,9 @@
 import os
-import numpy as np
 import tensorflow as tf
 from poem.model import rnn_model
 from poem.poem import process_poem, generate_batch
 
+# FLAGS参数设置
 tf.app.flags.DEFINE_integer('batch_size', 64, 'batch size.')
 tf.app.flags.DEFINE_float('learning_rate', 0.01, 'learning rate.')
 tf.app.flags.DEFINE_string('model_dir', os.path.abspath('./model'), 'model save path.')