|
@@ -1,3 +1,6 @@
|
|
|
+#!/usr/bin/env python
|
|
|
+# -*- encoding: utf-8 -*-
|
|
|
+
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
@@ -10,7 +13,7 @@ from sqlnet.model.modules.sqlnet_condition_predict import SQLNetCondPredictor
|
|
|
from sqlnet.model.modules.select_number import SelNumPredictor
|
|
|
from sqlnet.model.modules.where_relation import WhereRelationPredictor
|
|
|
|
|
|
-# 定义SQLNet模型
|
|
|
+# 定义 SQLNet 模型
|
|
|
class SQLNet(nn.Module):
|
|
|
def __init__(self, word_emb, N_word, N_h=100, N_depth=2,
|
|
|
gpu=False, use_ca=True, trainable_emb=False):
|