|
@@ -30,13 +30,13 @@ if __name__ == '__main__':
|
|
# load dataset
|
|
# load dataset
|
|
train_sql, train_table, train_db, dev_sql, dev_table, dev_db = load_dataset(use_small=use_small)
|
|
train_sql, train_table, train_db, dev_sql, dev_table, dev_db = load_dataset(use_small=use_small)
|
|
|
|
|
|
- word_emb = load_word_emb('data/char_embedding')
|
|
|
|
|
|
+ word_emb = load_word_emb('data/char_embedding.json')
|
|
model = SQLNet(word_emb, N_word=n_word, use_ca=args.ca, gpu=gpu, trainable_emb=args.train_emb)
|
|
model = SQLNet(word_emb, N_word=n_word, use_ca=args.ca, gpu=gpu, trainable_emb=args.train_emb)
|
|
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)
|
|
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)
|
|
|
|
|
|
if args.restore:
|
|
if args.restore:
|
|
model_path= 'saved_model/best_model'
|
|
model_path= 'saved_model/best_model'
|
|
- print "Loading trained model from %s" % model_path
|
|
|
|
|
|
+ print ("Loading trained model from %s" % model_path)
|
|
model.load_state_dict(torch.load(model_path))
|
|
model.load_state_dict(torch.load(model_path))
|
|
|
|
|
|
# used to record best score of each sub-task
|
|
# used to record best score of each sub-task
|
|
@@ -45,16 +45,16 @@ if __name__ == '__main__':
|
|
best_lf, best_lf_idx = 0.0, 0
|
|
best_lf, best_lf_idx = 0.0, 0
|
|
best_ex, best_ex_idx = 0.0, 0
|
|
best_ex, best_ex_idx = 0.0, 0
|
|
|
|
|
|
- print "#"*20+" Star to Train " + "#"*20
|
|
|
|
|
|
+ print ("#"*20+" Star to Train " + "#"*20)
|
|
for i in range(args.epoch):
|
|
for i in range(args.epoch):
|
|
- print 'Epoch %d'%(i+1)
|
|
|
|
|
|
+ print ('Epoch %d'%(i+1))
|
|
# train on the train dataset
|
|
# train on the train dataset
|
|
train_loss = epoch_train(model, optimizer, batch_size, train_sql, train_table)
|
|
train_loss = epoch_train(model, optimizer, batch_size, train_sql, train_table)
|
|
# evaluate on the dev dataset
|
|
# evaluate on the dev dataset
|
|
dev_acc = epoch_acc(model, batch_size, dev_sql, dev_table, dev_db)
|
|
dev_acc = epoch_acc(model, batch_size, dev_sql, dev_table, dev_db)
|
|
# accuracy of each sub-task
|
|
# accuracy of each sub-task
|
|
- print 'Sel-Num: %.3f, Sel-Col: %.3f, Sel-Agg: %.3f, W-Num: %.3f, W-Col: %.3f, W-Op: %.3f, W-Val: %.3f, W-Rel: %.3f'%(
|
|
|
|
- dev_acc[0][0], dev_acc[0][1], dev_acc[0][2], dev_acc[0][3], dev_acc[0][4], dev_acc[0][5], dev_acc[0][6], dev_acc[0][7])
|
|
|
|
|
|
+ print ('Sel-Num: %.3f, Sel-Col: %.3f, Sel-Agg: %.3f, W-Num: %.3f, W-Col: %.3f, W-Op: %.3f, W-Val: %.3f, W-Rel: %.3f'%(
|
|
|
|
+ dev_acc[0][0], dev_acc[0][1], dev_acc[0][2], dev_acc[0][3], dev_acc[0][4], dev_acc[0][5], dev_acc[0][6], dev_acc[0][7]))
|
|
# save the best model
|
|
# save the best model
|
|
if dev_acc[1] > best_lf:
|
|
if dev_acc[1] > best_lf:
|
|
best_lf = dev_acc[1]
|
|
best_lf = dev_acc[1]
|
|
@@ -90,11 +90,11 @@ if __name__ == '__main__':
|
|
if dev_acc[0][7] > best_wr:
|
|
if dev_acc[0][7] > best_wr:
|
|
best_wr = dev_acc[0][7]
|
|
best_wr = dev_acc[0][7]
|
|
best_wr_idx = i+1
|
|
best_wr_idx = i+1
|
|
- print 'Train loss = %.3f' % train_loss
|
|
|
|
- print 'Dev Logic Form Accuracy: %.3f, Execution Accuracy: %.3f' % (dev_acc[1], dev_acc[2])
|
|
|
|
- print 'Best Logic Form: %.3f at epoch %d' % (best_lf, best_lf_idx)
|
|
|
|
- print 'Best Execution: %.3f at epoch %d' % (best_ex, best_ex_idx)
|
|
|
|
|
|
+ print ('Train loss = %.3f' % train_loss)
|
|
|
|
+ print ('Dev Logic Form Accuracy: %.3f, Execution Accuracy: %.3f' % (dev_acc[1], dev_acc[2]))
|
|
|
|
+ print ('Best Logic Form: %.3f at epoch %d' % (best_lf, best_lf_idx))
|
|
|
|
+ print ('Best Execution: %.3f at epoch %d' % (best_ex, best_ex_idx))
|
|
if (i+1) % 10 == 0:
|
|
if (i+1) % 10 == 0:
|
|
- print 'Best val acc: %s\nOn epoch individually %s'%(
|
|
|
|
|
|
+ print ('Best val acc: %s\nOn epoch individually %s'%(
|
|
(best_sn, best_sc, best_sa, best_wn, best_wc, best_wo, best_wv),
|
|
(best_sn, best_sc, best_sa, best_wn, best_wc, best_wo, best_wv),
|
|
- (best_sn_idx, best_sc_idx, best_sa_idx, best_wn_idx, best_wc_idx, best_wo_idx, best_wv_idx))
|
|
|
|
|
|
+ (best_sn_idx, best_sc_idx, best_sa_idx, best_wn_idx, best_wc_idx, best_wo_idx, best_wv_idx)))
|