test.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. #!/usr/bin/env python
  2. # -*- encoding: utf-8 -*-
  3. '''
  4. @File : test.py
  5. @Time : 2019/07/07 23:41:48
  6. @Author : Liuyuqi
  7. @Version : 1.0
  8. @Contact : liuyuqi.gov@msn.cn
  9. @License : (C)Copyright 2019
  10. @Desc : Evaluate the saved SQLNet model on the dev set and write test-set predictions
  11. '''
  12. import torch
  13. from sqlnet.utils import *
  14. from sqlnet.model.sqlnet import SQLNet
  15. import argparse
  16. if __name__ == '__main__':
  17. parser = argparse.ArgumentParser()
  18. parser.add_argument('--gpu', action='store_true', help='Whether use gpu')
  19. parser.add_argument('--toy', action='store_true', help='Small batchsize for fast debugging.')
  20. parser.add_argument('--ca', action='store_true', help='Whether use column attention.')
  21. parser.add_argument('--train_emb', action='store_true', help='Use trained word embedding for SQLNet.')
  22. parser.add_argument('--output_dir', type=str, default='', help='Output path of prediction result')
  23. args = parser.parse_args()
  24. n_word=300
  25. if args.toy:
  26. use_small=True
  27. gpu=args.gpu
  28. batch_size=16
  29. else:
  30. use_small=False
  31. gpu=args.gpu
  32. batch_size=64
  33. dev_sql, dev_table, dev_db, test_sql, test_table, test_db = load_dataset(use_small=use_small, mode='test')
  34. word_emb = load_word_emb('data/char_embedding.json')
  35. model = SQLNet(word_emb, N_word=n_word, use_ca=args.ca, gpu=gpu, trainable_emb=args.train_emb)
  36. model_path = 'saved_model/best_model'
  37. print ("Loading from %s" % model_path)
  38. model.load_state_dict(torch.load(model_path))
  39. print ("Loaded model from %s" % model_path)
  40. dev_acc = epoch_acc(model, batch_size, dev_sql, dev_table, dev_db)
  41. print ('Dev Logic Form Accuracy: %.3f, Execution Accuracy: %.3f' % (dev_acc[1], dev_acc[2]))
  42. print ("Start to predict test set")
  43. predict_test(model, batch_size, test_sql, test_table, args.output_dir)
  44. print ("Output path of prediction result is %s" % args.output_dir)