# train.py — SQLNet training entry point
  1. #!/usr/bin/env python
  2. # -*- encoding: utf-8 -*-
  3. '''
  4. @File : train.py
  5. @Time : 2019/06/25 04:00:53
  6. @Author : Liuyuqi
  7. @Version : 1.0
  8. @Contact : liuyuqi.gov@msn.cn
  9. @License : (C)Copyright 2019
  10. @Desc :
  11. '''
  12. import torch
  13. from sqlnet.utils import *
  14. from sqlnet.model.sqlnet import SQLNet
  15. import argparse
  16. if __name__ == '__main__':
  17. parser = argparse.ArgumentParser()
  18. parser.add_argument('--bs', type=int, default=16, help='Batch size')#
  19. parser.add_argument('--epoch', type=int, default=100, help='Epoch number')
  20. parser.add_argument('--gpu', action='store_true', help='Whether use gpu to train')#
  21. parser.add_argument('--toy', action='store_true', help='If set, use small data for fast debugging')
  22. parser.add_argument('--ca', action='store_true', help='Whether use column attention')#
  23. parser.add_argument('--train_emb', action='store_true', help='Train word embedding for SQLNet')
  24. parser.add_argument('--restore', action='store_true', help='Whether restore trained model')
  25. parser.add_argument('--logdir', type=str, default='', help='Path of save experiment log')
  26. args = parser.parse_args()
  27. n_word=300
  28. if args.toy:
  29. use_small=True
  30. gpu=args.gpu
  31. batch_size=16
  32. else:
  33. use_small=False
  34. gpu=args.gpu
  35. batch_size=args.bs
  36. learning_rate = 1e-3
  37. # load dataset 加载训练数据和测试数据
  38. train_sql, train_table, train_db, dev_sql, dev_table, dev_db = load_dataset(use_small=use_small)
  39. # word_emb 字典类型。
  40. word_emb = load_word_emb('data/char_embedding.json')
  41. model = SQLNet(word_emb, N_word=n_word, use_ca=args.ca, gpu=gpu, trainable_emb=args.train_emb)
  42. optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)
  43. if args.restore:
  44. model_path= 'saved_model/best_model'
  45. print ("Loading trained model from %s" % model_path)
  46. model.load_state_dict(torch.load(model_path))
  47. # used to record best score of each sub-task
  48. best_sn, best_sc, best_sa, best_wn, best_wc, best_wo, best_wv, best_wr = 0, 0, 0, 0, 0, 0, 0, 0
  49. best_sn_idx, best_sc_idx, best_sa_idx, best_wn_idx, best_wc_idx, best_wo_idx, best_wv_idx, best_wr_idx = 0, 0, 0, 0, 0, 0, 0, 0
  50. best_lf, best_lf_idx = 0.0, 0
  51. best_ex, best_ex_idx = 0.0, 0
  52. print ("#"*20+" Star to Train " + "#"*20)
  53. for i in range(args.epoch):# range(100)
  54. print ('Epoch %d'%(i+1))
  55. # train on the train dataset
  56. train_loss = epoch_train(model, optimizer, batch_size, train_sql, train_table)
  57. # evaluate on the dev dataset
  58. dev_acc = epoch_acc(model, batch_size, dev_sql, dev_table, dev_db)
  59. # accuracy of each sub-task
  60. print ('Sel-Num: %.3f, Sel-Col: %.3f, Sel-Agg: %.3f, W-Num: %.3f, W-Col: %.3f, W-Op: %.3f, W-Val: %.3f, W-Rel: %.3f'%(
  61. dev_acc[0][0], dev_acc[0][1], dev_acc[0][2], dev_acc[0][3], dev_acc[0][4], dev_acc[0][5], dev_acc[0][6], dev_acc[0][7]))
  62. # save the best model
  63. if dev_acc[1] > best_lf:
  64. best_lf = dev_acc[1]
  65. best_lf_idx = i + 1
  66. torch.save(model.state_dict(), 'saved_model/best_model')
  67. if dev_acc[2] > best_ex:
  68. best_ex = dev_acc[2]
  69. best_ex_idx = i + 1
  70. # record the best score of each sub-task
  71. if True:
  72. if dev_acc[0][0] > best_sn:
  73. best_sn = dev_acc[0][0]
  74. best_sn_idx = i+1
  75. if dev_acc[0][1] > best_sc:
  76. best_sc = dev_acc[0][1]
  77. best_sc_idx = i+1
  78. if dev_acc[0][2] > best_sa:
  79. best_sa = dev_acc[0][2]
  80. best_sa_idx = i+1
  81. if dev_acc[0][3] > best_wn:
  82. best_wn = dev_acc[0][3]
  83. best_wn_idx = i+1
  84. if dev_acc[0][4] > best_wc:
  85. best_wc = dev_acc[0][4]
  86. best_wc_idx = i+1
  87. if dev_acc[0][5] > best_wo:
  88. best_wo = dev_acc[0][5]
  89. best_wo_idx = i+1
  90. if dev_acc[0][6] > best_wv:
  91. best_wv = dev_acc[0][6]
  92. best_wv_idx = i+1
  93. if dev_acc[0][7] > best_wr:
  94. best_wr = dev_acc[0][7]
  95. best_wr_idx = i+1
  96. print ('Train loss = %.3f' % train_loss)
  97. print ('Dev Logic Form Accuracy: %.3f, Execution Accuracy: %.3f' % (dev_acc[1], dev_acc[2]))
  98. print ('Best Logic Form: %.3f at epoch %d' % (best_lf, best_lf_idx))
  99. print ('Best Execution: %.3f at epoch %d' % (best_ex, best_ex_idx))
  100. if (i+1) % 10 == 0:
  101. print ('Best val acc: %s\nOn epoch individually %s'%(
  102. (best_sn, best_sc, best_sa, best_wn, best_wc, best_wo, best_wv),
  103. (best_sn_idx, best_sc_idx, best_sa_idx, best_wn_idx, best_wc_idx, best_wo_idx, best_wv_idx)))