Browse Source

fix agg loss bug

waynesun 4 years ago
parent
commit
2e267cf7a1
1 changed files with 2 additions and 2 deletions
  1. 2 2
      sqlnet/model/sqlnet.py

+ 2 - 2
sqlnet/model/sqlnet.py

@@ -173,8 +173,8 @@ class SQLNet(nn.Module):
                 sel_agg_truth_var = Variable(data.cuda())
             else:
                 sel_agg_truth_var = Variable(data)
-        sel_agg_pred = agg_score[b, :len(truth_num[b][1])]
-        loss += (self.CE(sel_agg_pred, sel_agg_truth_var)) / len(truth_num)
+            sel_agg_pred = agg_score[b, :len(truth_num[b][1])]
+            loss += (self.CE(sel_agg_pred, sel_agg_truth_var)) / len(truth_num)
 
         cond_num_score, cond_col_score, cond_op_score, cond_str_score = cond_score