waynesun 4 years ago
parent
commit
bed8c7ecc8
2 changed files with 17 additions and 17 deletions
  1. 5 5
      test.py
  2. 12 12
      train.py

+ 5 - 5
test.py

@@ -28,13 +28,13 @@ if __name__ == '__main__':
     model = SQLNet(word_emb, N_word=n_word, use_ca=args.ca, gpu=gpu, trainable_emb=args.train_emb)
 
     model_path = 'saved_model/best_model'
-    print "Loading from %s" % model_path
+    print ("Loading from %s" % model_path)
     model.load_state_dict(torch.load(model_path))
-    print "Loaded model from %s" % model_path
+    print ("Loaded model from %s" % model_path)
 
     dev_acc = epoch_acc(model, batch_size, dev_sql, dev_table, dev_db)
-    print 'Dev Logic Form Accuracy: %.3f, Execution Accuracy: %.3f' % (dev_acc[1], dev_acc[2])
+    print ('Dev Logic Form Accuracy: %.3f, Execution Accuracy: %.3f' % (dev_acc[1], dev_acc[2]))
 
-    print "Start to predict test set"
+    print ("Start to predict test set")
     predict_test(model, batch_size, test_sql, test_table, args.output_dir)
-    print "Output path of prediction result is %s" % args.output_dir
+    print ("Output path of prediction result is %s" % args.output_dir)

+ 12 - 12
train.py

@@ -30,13 +30,13 @@ if __name__ == '__main__':
     # load dataset
     train_sql, train_table, train_db, dev_sql, dev_table, dev_db = load_dataset(use_small=use_small)
 
-    word_emb = load_word_emb('data/char_embedding')
+    word_emb = load_word_emb('data/char_embedding.json')
     model = SQLNet(word_emb, N_word=n_word, use_ca=args.ca, gpu=gpu, trainable_emb=args.train_emb)
     optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)
 
     if args.restore:
         model_path= 'saved_model/best_model'
-        print "Loading trained model from %s" % model_path
+        print ("Loading trained model from %s" % model_path)
         model.load_state_dict(torch.load(model_path))
 
     # used to record best score of each sub-task
@@ -45,16 +45,16 @@ if __name__ == '__main__':
     best_lf, best_lf_idx = 0.0, 0
     best_ex, best_ex_idx = 0.0, 0
 
-    print "#"*20+"  Star to Train  " + "#"*20
+    print ("#"*20+"  Star to Train  " + "#"*20)
     for i in range(args.epoch):
-        print 'Epoch %d'%(i+1)
+        print ('Epoch %d'%(i+1))
         # train on the train dataset
         train_loss = epoch_train(model, optimizer, batch_size, train_sql, train_table)
         # evaluate on the dev dataset
         dev_acc = epoch_acc(model, batch_size, dev_sql, dev_table, dev_db)
         # accuracy of each sub-task
-        print 'Sel-Num: %.3f, Sel-Col: %.3f, Sel-Agg: %.3f, W-Num: %.3f, W-Col: %.3f, W-Op: %.3f, W-Val: %.3f, W-Rel: %.3f'%(
-            dev_acc[0][0], dev_acc[0][1], dev_acc[0][2], dev_acc[0][3], dev_acc[0][4], dev_acc[0][5], dev_acc[0][6], dev_acc[0][7])
+        print ('Sel-Num: %.3f, Sel-Col: %.3f, Sel-Agg: %.3f, W-Num: %.3f, W-Col: %.3f, W-Op: %.3f, W-Val: %.3f, W-Rel: %.3f'%(
+            dev_acc[0][0], dev_acc[0][1], dev_acc[0][2], dev_acc[0][3], dev_acc[0][4], dev_acc[0][5], dev_acc[0][6], dev_acc[0][7]))
         # save the best model
         if dev_acc[1] > best_lf:
             best_lf = dev_acc[1]
@@ -90,11 +90,11 @@ if __name__ == '__main__':
             if dev_acc[0][7] > best_wr:
                 best_wr = dev_acc[0][7]
                 best_wr_idx = i+1
-        print 'Train loss = %.3f' % train_loss
-        print 'Dev Logic Form Accuracy: %.3f, Execution Accuracy: %.3f' % (dev_acc[1], dev_acc[2])
-        print 'Best Logic Form: %.3f at epoch %d' % (best_lf, best_lf_idx)
-        print 'Best Execution: %.3f at epoch %d' % (best_ex, best_ex_idx)
+        print ('Train loss = %.3f' % train_loss)
+        print ('Dev Logic Form Accuracy: %.3f, Execution Accuracy: %.3f' % (dev_acc[1], dev_acc[2]))
+        print ('Best Logic Form: %.3f at epoch %d' % (best_lf, best_lf_idx))
+        print ('Best Execution: %.3f at epoch %d' % (best_ex, best_ex_idx))
         if (i+1) % 10 == 0:
-            print 'Best val acc: %s\nOn epoch individually %s'%(
+            print ('Best val acc: %s\nOn epoch individually %s'%(
                     (best_sn, best_sc, best_sa, best_wn, best_wc, best_wo, best_wv),
-                    (best_sn_idx, best_sc_idx, best_sa_idx, best_wn_idx, best_wc_idx, best_wo_idx, best_wv_idx))
+                    (best_sn_idx, best_sc_idx, best_sa_idx, best_wn_idx, best_wc_idx, best_wo_idx, best_wv_idx)))