# train.py
  1. import torch
  2. from sqlnet.utils import *
  3. from sqlnet.model.sqlnet import SQLNet
  4. import argparse
  5. if __name__ == '__main__':
  6. parser = argparse.ArgumentParser()
  7. parser.add_argument('--bs', type=int, default=16, help='Batch size')
  8. parser.add_argument('--epoch', type=int, default=100, help='Epoch number')
  9. parser.add_argument('--gpu', action='store_true', help='Whether use gpu to train')
  10. parser.add_argument('--toy', action='store_true', help='If set, use small data for fast debugging')
  11. parser.add_argument('--ca', action='store_true', help='Whether use column attention')
  12. parser.add_argument('--train_emb', action='store_true', help='Train word embedding for SQLNet')
  13. parser.add_argument('--restore', action='store_true', help='Whether restore trained model')
  14. parser.add_argument('--logdir', type=str, default='', help='Path of save experiment log')
  15. args = parser.parse_args()
  16. n_word=300
  17. if args.toy:
  18. use_small=True
  19. gpu=args.gpu
  20. batch_size=16
  21. else:
  22. use_small=False
  23. gpu=args.gpu
  24. batch_size=args.bs
  25. learning_rate = 1e-3
  26. # load dataset
  27. train_sql, train_table, train_db, dev_sql, dev_table, dev_db = load_dataset(use_small=use_small)
  28. word_emb = load_word_emb('data/char_embedding')
  29. model = SQLNet(word_emb, N_word=n_word, use_ca=args.ca, gpu=gpu, trainable_emb=args.train_emb)
  30. optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)
  31. if args.restore:
  32. model_path= 'saved_model/best_model'
  33. print "Loading trained model from %s" % model_path
  34. model.load_state_dict(torch.load(model_path))
  35. # used to record best score of each sub-task
  36. best_sn, best_sc, best_sa, best_wn, best_wc, best_wo, best_wv, best_wr = 0, 0, 0, 0, 0, 0, 0, 0
  37. best_sn_idx, best_sc_idx, best_sa_idx, best_wn_idx, best_wc_idx, best_wo_idx, best_wv_idx, best_wr_idx = 0, 0, 0, 0, 0, 0, 0, 0
  38. best_lf, best_lf_idx = 0.0, 0
  39. print "#"*20+" Star to Train " + "#"*20
  40. for i in range(args.epoch):
  41. print 'Epoch %d'%(i+1)
  42. # train on the train dataset
  43. train_loss = epoch_train(model, optimizer, batch_size, train_sql, train_table)
  44. # evaluate on the dev dataset
  45. dev_acc = epoch_acc(model, batch_size, dev_sql, dev_table)
  46. # accuracy of each sub-task
  47. print 'Sel-Num: %.3f, Sel-Col: %.3f, Sel-Agg: %.3f, W-Num: %.3f, W-Col: %.3f, W-Op: %.3f, W-Val: %.3f, W-Rel: %.3f'%(
  48. dev_acc[0][0], dev_acc[0][1], dev_acc[0][2], dev_acc[0][3], dev_acc[0][4], dev_acc[0][5], dev_acc[0][6], dev_acc[0][7])
  49. # save the best model
  50. if dev_acc[1] > best_lf:
  51. best_lf = dev_acc[1]
  52. best_lf_idx = i + 1
  53. torch.save(model.state_dict(), 'saved_model/best_model')
  54. # record the best score of each sub-task
  55. if True:
  56. if dev_acc[0][0] > best_sn:
  57. best_sn = dev_acc[0][0]
  58. best_sn_idx = i+1
  59. if dev_acc[0][1] > best_sc:
  60. best_sc = dev_acc[0][1]
  61. best_sc_idx = i+1
  62. if dev_acc[0][2] > best_sa:
  63. best_sa = dev_acc[0][2]
  64. best_sa_idx = i+1
  65. if dev_acc[0][3] > best_wn:
  66. best_wn = dev_acc[0][3]
  67. best_wn_idx = i+1
  68. if dev_acc[0][4] > best_wc:
  69. best_wc = dev_acc[0][4]
  70. best_wc_idx = i+1
  71. if dev_acc[0][5] > best_wo:
  72. best_wo = dev_acc[0][5]
  73. best_wo_idx = i+1
  74. if dev_acc[0][6] > best_wv:
  75. best_wv = dev_acc[0][6]
  76. best_wv_idx = i+1
  77. if dev_acc[0][7] > best_wr:
  78. best_wr = dev_acc[0][7]
  79. best_wr_idx = i+1
  80. print 'Train loss = %.3f' % train_loss
  81. print 'Dev Logic Form: %.3f' % dev_acc[1]
  82. print 'Best Logic Form: %.3f at epoch %d' % (best_lf, best_lf_idx)
  83. if (i+1) % 10 == 0:
  84. print 'Best val acc: %s\nOn epoch individually %s'%(
  85. (best_sn, best_sc, best_sa, best_wn, best_wc, best_wo, best_wv),
  86. (best_sn_idx, best_sc_idx, best_sa_idx, best_wn_idx, best_wc_idx, best_wo_idx, best_wv_idx))