diff --git a/sqlnet/model/modules/sqlnet_condition_predict.py b/sqlnet/model/modules/sqlnet_condition_predict.py
index 1eb5500..52d5dcd 100644
--- a/sqlnet/model/modules/sqlnet_condition_predict.py
+++ b/sqlnet/model/modules/sqlnet_condition_predict.py
@@ -40,7 +40,7 @@ def __init__(self, N_word, N_h, N_depth, max_col_num, max_tok_num, use_ca, gpu):
         self.cond_col_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
                 num_layers=N_depth, batch_first=True,
                 dropout=0.3, bidirectional=True)
-        self.cond_col_out_K = nn.Linear(N_h, N_h)
+        self.cond_col_out_K = nn.Linear(N_h*2, N_h)
         self.cond_col_out_col = nn.Linear(N_h, N_h)
         self.cond_col_out = nn.Sequential(nn.ReLU(), nn.Linear(N_h, 1))
 
@@ -51,7 +51,7 @@ def __init__(self, N_word, N_h, N_depth, max_col_num, max_tok_num, use_ca, gpu):
             self.cond_op_att = nn.Linear(N_h, N_h)
         else:
             self.cond_op_att = nn.Linear(N_h, 1)
-        self.cond_op_out_K = nn.Linear(N_h, N_h)
+        self.cond_op_out_K = nn.Linear(N_h*2, N_h)
         self.cond_op_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
                 num_layers=N_depth, batch_first=True,
                 dropout=0.3, bidirectional=True)
@@ -155,6 +155,12 @@ def forward(self, x_emb_var, x_len, col_inp_var, col_name_len,
             col_att = self.softmax(col_att_val.view(
                 (-1, max_x_len))).view(B, -1, max_x_len)
             K_cond_col = (h_col_enc.unsqueeze(1) * col_att.unsqueeze(3)).sum(2)
+
+            # bi-attention
+            temp, _ = torch.max(col_att_val, dim=1)
+            temp_probs = self.softmax(temp).unsqueeze(2)
+            temp2 = (temp_probs*h_col_enc).sum(1)
+            temp2 = temp2.unsqueeze(1).expand([e_cond_col.size()[0],e_cond_col.size()[1],e_cond_col.size()[2]])
         else:
             col_att_val = self.cond_col_att(h_col_enc).squeeze()
             for idx, num in enumerate(x_len):
@@ -164,7 +170,8 @@ def forward(self, x_emb_var, x_len, col_inp_var, col_name_len,
             K_cond_col = (h_col_enc *
                     col_att_val.unsqueeze(2)).sum(1).unsqueeze(1)
 
-        cond_col_score = self.cond_col_out(self.cond_col_out_K(K_cond_col) +
+        cond_col_score = self.cond_col_out(self.cond_col_out_K(
+                torch.cat([K_cond_col,temp2*e_cond_col],dim=-1)) +
                 self.cond_col_out_col(e_cond_col)).squeeze()
         max_col_num = max(col_num)
         for b, num in enumerate(col_num):
@@ -201,6 +208,13 @@ def forward(self, x_emb_var, x_len, col_inp_var, col_name_len,
                 op_att_val[idx, :, num:] = -100
             op_att = self.softmax(op_att_val.view(B*4, -1)).view(B, 4, -1)
             K_cond_op = (h_op_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)
+
+            # bi-attention
+            temp, _ = torch.max(op_att_val, dim=1)
+            temp_probs = self.softmax(temp).unsqueeze(2)
+            temp2 = (temp_probs * h_op_enc).sum(1)
+            temp2 = temp2.unsqueeze(1).expand([col_emb.size()[0], col_emb.size()[1], col_emb.size()[2]])
+
         else:
             op_att_val = self.cond_op_att(h_op_enc).squeeze()
             for idx, num in enumerate(x_len):
@@ -209,8 +223,9 @@ def forward(self, x_emb_var, x_len, col_inp_var, col_name_len,
             op_att = self.softmax(op_att_val)
             K_cond_op = (h_op_enc * op_att.unsqueeze(2)).sum(1).unsqueeze(1)
 
-        cond_op_score = self.cond_op_out(self.cond_op_out_K(K_cond_op) +
-                self.cond_op_out_col(col_emb)).squeeze()
+        cond_op_score = self.cond_op_out(self.cond_op_out_K(
+            torch.cat([K_cond_op,temp2*col_emb],dim=-1)
+            ) + self.cond_op_out_col(col_emb)).squeeze()
 
         #Predict the string of conditions
         h_str_enc, _ = run_lstm(self.cond_str_lstm, x_emb_var, x_len)