|
@@ -174,8 +174,8 @@ class SQLNet(nn.Module):
|
|
sel_agg_truth_var = Variable(data.cuda())
|
|
sel_agg_truth_var = Variable(data.cuda())
|
|
else:
|
|
else:
|
|
sel_agg_truth_var = Variable(data)
|
|
sel_agg_truth_var = Variable(data)
|
|
- sel_agg_pred = agg_score[b, :len(truth_num[b][1])]
|
|
|
|
- loss += (self.CE(sel_agg_pred, sel_agg_truth_var)) / len(truth_num)
|
|
|
|
|
|
+ sel_agg_pred = agg_score[b, :len(truth_num[b][1])]
|
|
|
|
+ loss += (self.CE(sel_agg_pred, sel_agg_truth_var)) / len(truth_num)
|
|
|
|
|
|
cond_num_score, cond_col_score, cond_op_score, cond_str_score = cond_score
|
|
cond_num_score, cond_col_score, cond_op_score, cond_str_score = cond_score
|
|
|
|
|