train.py 4.4 KB

import torch
from sqlnet.utils import *
from sqlnet.model.sqlnet import SQLNet
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--bs', type=int, default=16, help='Batch size')
    parser.add_argument('--epoch', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--gpu', action='store_true', help='Whether to use the GPU for training')
    parser.add_argument('--toy', action='store_true', help='If set, use small data for fast debugging')
    parser.add_argument('--ca', action='store_true', help='Whether to use column attention')
    parser.add_argument('--train_emb', action='store_true', help='Train word embeddings for SQLNet')
    parser.add_argument('--restore', action='store_true', help='Whether to restore a trained model')
    parser.add_argument('--logdir', type=str, default='', help='Path to save the experiment log')
    args = parser.parse_args()
    n_word = 300
    if args.toy:
        use_small = True
        gpu = args.gpu
        batch_size = 16
    else:
        use_small = False
        gpu = args.gpu
        batch_size = args.bs
    learning_rate = 1e-3

    # load dataset
    train_sql, train_table, train_db, dev_sql, dev_table, dev_db = load_dataset(use_small=use_small)

    word_emb = load_word_emb('data/char_embedding.json')
    model = SQLNet(word_emb, N_word=n_word, use_ca=args.ca, gpu=gpu, trainable_emb=args.train_emb)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)
    if args.restore:
        model_path = 'saved_model/best_model'
        print("Loading trained model from %s" % model_path)
        model.load_state_dict(torch.load(model_path))

    # used to record the best score of each sub-task
    best_sn, best_sc, best_sa, best_wn, best_wc, best_wo, best_wv, best_wr = 0, 0, 0, 0, 0, 0, 0, 0
    best_sn_idx, best_sc_idx, best_sa_idx, best_wn_idx, best_wc_idx, best_wo_idx, best_wv_idx, best_wr_idx = 0, 0, 0, 0, 0, 0, 0, 0
    best_lf, best_lf_idx = 0.0, 0
    best_ex, best_ex_idx = 0.0, 0

    print("#" * 20 + " Start to Train " + "#" * 20)
    for i in range(args.epoch):
        print('Epoch %d' % (i + 1))
        # train on the train dataset
        train_loss = epoch_train(model, optimizer, batch_size, train_sql, train_table)
        # evaluate on the dev dataset; dev_acc[0] holds the per-sub-task accuracies,
        # dev_acc[1] the logic-form accuracy, and dev_acc[2] the execution accuracy
        dev_acc = epoch_acc(model, batch_size, dev_sql, dev_table, dev_db)
        # accuracy of each sub-task
        print('Sel-Num: %.3f, Sel-Col: %.3f, Sel-Agg: %.3f, W-Num: %.3f, W-Col: %.3f, W-Op: %.3f, W-Val: %.3f, W-Rel: %.3f' % (
            dev_acc[0][0], dev_acc[0][1], dev_acc[0][2], dev_acc[0][3], dev_acc[0][4], dev_acc[0][5], dev_acc[0][6], dev_acc[0][7]))
        # save the model with the best logic-form accuracy
        if dev_acc[1] > best_lf:
            best_lf = dev_acc[1]
            best_lf_idx = i + 1
            torch.save(model.state_dict(), 'saved_model/best_model')
        if dev_acc[2] > best_ex:
            best_ex = dev_acc[2]
            best_ex_idx = i + 1
        # record the best score of each sub-task
        if dev_acc[0][0] > best_sn:
            best_sn = dev_acc[0][0]
            best_sn_idx = i + 1
        if dev_acc[0][1] > best_sc:
            best_sc = dev_acc[0][1]
            best_sc_idx = i + 1
        if dev_acc[0][2] > best_sa:
            best_sa = dev_acc[0][2]
            best_sa_idx = i + 1
        if dev_acc[0][3] > best_wn:
            best_wn = dev_acc[0][3]
            best_wn_idx = i + 1
        if dev_acc[0][4] > best_wc:
            best_wc = dev_acc[0][4]
            best_wc_idx = i + 1
        if dev_acc[0][5] > best_wo:
            best_wo = dev_acc[0][5]
            best_wo_idx = i + 1
        if dev_acc[0][6] > best_wv:
            best_wv = dev_acc[0][6]
            best_wv_idx = i + 1
        if dev_acc[0][7] > best_wr:
            best_wr = dev_acc[0][7]
            best_wr_idx = i + 1
        print('Train loss = %.3f' % train_loss)
        print('Dev Logic Form Accuracy: %.3f, Execution Accuracy: %.3f' % (dev_acc[1], dev_acc[2]))
        print('Best Logic Form: %.3f at epoch %d' % (best_lf, best_lf_idx))
        print('Best Execution: %.3f at epoch %d' % (best_ex, best_ex_idx))
        if (i + 1) % 10 == 0:
            print('Best val acc: %s\nOn epoch individually %s' % (
                (best_sn, best_sc, best_sa, best_wn, best_wc, best_wo, best_wv, best_wr),
                (best_sn_idx, best_sc_idx, best_sa_idx, best_wn_idx, best_wc_idx, best_wo_idx, best_wv_idx, best_wr_idx)))