From cdf8400c952e326ef3581ecb7c2e7f343b253b27 Mon Sep 17 00:00:00 2001
From: Tong Guo <779222056@qq.com>
Date: Wed, 2 May 2018 10:44:53 +0800
Subject: [PATCH 1/2] Update sqlnet_condition_predict.py

---
 .../model/modules/sqlnet_condition_predict.py | 35 +++++++++++++------
 1 file changed, 25 insertions(+), 10 deletions(-)

diff --git a/sqlnet/model/modules/sqlnet_condition_predict.py b/sqlnet/model/modules/sqlnet_condition_predict.py
index 1eb5500..1a01138 100644
--- a/sqlnet/model/modules/sqlnet_condition_predict.py
+++ b/sqlnet/model/modules/sqlnet_condition_predict.py
@@ -40,7 +40,7 @@ def __init__(self, N_word, N_h, N_depth, max_col_num, max_tok_num, use_ca, gpu):
         self.cond_col_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
                 num_layers=N_depth, batch_first=True,
                 dropout=0.3, bidirectional=True)
-        self.cond_col_out_K = nn.Linear(N_h, N_h)
+        self.cond_col_out_K = nn.Linear(N_h*2, N_h)
         self.cond_col_out_col = nn.Linear(N_h, N_h)
         self.cond_col_out = nn.Sequential(nn.ReLU(), nn.Linear(N_h, 1))
 
@@ -51,7 +51,7 @@ def __init__(self, N_word, N_h, N_depth, max_col_num, max_tok_num, use_ca, gpu):
             self.cond_op_att = nn.Linear(N_h, N_h)
         else:
             self.cond_op_att = nn.Linear(N_h, 1)
-        self.cond_op_out_K = nn.Linear(N_h, N_h)
+        self.cond_op_out_K = nn.Linear(N_h*2, N_h)
         self.cond_op_name_enc = nn.LSTM(input_size=N_word, hidden_size=N_h/2,
                 num_layers=N_depth, batch_first=True,
                 dropout=0.3, bidirectional=True)
@@ -155,6 +155,12 @@ def forward(self, x_emb_var, x_len, col_inp_var, col_name_len,
             col_att = self.softmax(col_att_val.view(
                 (-1, max_x_len))).view(B, -1, max_x_len)
             K_cond_col = (h_col_enc.unsqueeze(1) * col_att.unsqueeze(3)).sum(2)
+
+            # bi-attention
+            temp, _ = torch.max(col_att_val, dim=1)
+            temp_probs = self.softmax(temp).unsqueeze(2)
+            temp2 = (temp_probs*h_col_enc).sum(1)
+            temp2 = temp2.unsqueeze(1).expand([e_cond_col.size()[0],e_cond_col.size()[1],e_cond_col.size()[2]])
         else:
             col_att_val = self.cond_col_att(h_col_enc).squeeze()
             for idx, num in enumerate(x_len):
@@ -164,8 +170,9 @@ def forward(self, x_emb_var, x_len, col_inp_var, col_name_len,
             K_cond_col = (h_col_enc *
                     col_att_val.unsqueeze(2)).sum(1).unsqueeze(1)
 
-        cond_col_score = self.cond_col_out(self.cond_col_out_K(K_cond_col) +
-                self.cond_col_out_col(e_cond_col)).squeeze()
+        cond_col_score = self.cond_col_out(self.cond_col_out_K(
+            torch.cat([K_cond_col,temp2*e_cond_col],dim=-1)) +
+                self.cond_col_out_col(e_cond_col)).squeeze() # B,6
         max_col_num = max(col_num)
         for b, num in enumerate(col_num):
             if num < max_col_num:
@@ -174,10 +181,10 @@ def forward(self, x_emb_var, x_len, col_inp_var, col_name_len,
         #Predict the operator of conditions
         chosen_col_gt = []
         if gt_cond is None:
-            cond_nums = np.argmax(cond_num_score.data.cpu().numpy(), axis=1)
-            col_scores = cond_col_score.data.cpu().numpy()
+            cond_nums = np.argmax(cond_num_score.data.cpu().numpy(), axis=1) # B
+            col_scores = cond_col_score.data.cpu().numpy() # B,6
             chosen_col_gt = [list(np.argsort(-col_scores[b])[:cond_nums[b]])
-                    for b in range(len(cond_nums))]
+                    for b in range(len(cond_nums))] # B []
         else:
             chosen_col_gt = [ [x[0] for x in one_gt_cond] for
                     one_gt_cond in gt_cond]
@@ -190,7 +197,7 @@ def forward(self, x_emb_var, x_len, col_inp_var, col_name_len,
                 for x in chosen_col_gt[b]] + [e_cond_col[b, 0]] *
                 (4 - len(chosen_col_gt[b])))  # Pad the columns to maximum (4)
             col_emb.append(cur_col_emb)
-        col_emb = torch.stack(col_emb)
+        col_emb = torch.stack(col_emb) # B,4,100
 
         h_op_enc, _ = run_lstm(self.cond_op_lstm, x_emb_var, x_len)
         if self.use_ca:
@@ -201,6 +208,13 @@ def forward(self, x_emb_var, x_len, col_inp_var, col_name_len,
                     op_att_val[idx, :, num:] = -100
             op_att = self.softmax(op_att_val.view(B*4, -1)).view(B, 4, -1)
             K_cond_op = (h_op_enc.unsqueeze(1) * op_att.unsqueeze(3)).sum(2)
+
+            # bi-attention
+            temp, _ = torch.max(op_att_val, dim=1)
+            temp_probs = self.softmax(temp).unsqueeze(2)
+            temp2 = (temp_probs * h_op_enc).sum(1)
+            temp2 = temp2.unsqueeze(1).expand([col_emb.size()[0], col_emb.size()[1], col_emb.size()[2]])
+
         else:
             op_att_val = self.cond_op_att(h_op_enc).squeeze()
             for idx, num in enumerate(x_len):
@@ -209,8 +223,9 @@ def forward(self, x_emb_var, x_len, col_inp_var, col_name_len,
                 op_att = self.softmax(op_att_val)
             K_cond_op = (h_op_enc *
                     op_att.unsqueeze(2)).sum(1).unsqueeze(1)
-        cond_op_score = self.cond_op_out(self.cond_op_out_K(K_cond_op) +
-                self.cond_op_out_col(col_emb)).squeeze()
+        cond_op_score = self.cond_op_out(self.cond_op_out_K(
+            torch.cat([K_cond_op,temp2*col_emb],dim=-1)
+        ) + self.cond_op_out_col(col_emb)).squeeze()
 
         #Predict the string of conditions
         h_str_enc, _ = run_lstm(self.cond_str_lstm, x_emb_var, x_len)

From fbe2222d2555bfce44e45d460ebb267bda7c458f Mon Sep 17 00:00:00 2001
From: Tong Guo <779222056@qq.com>
Date: Wed, 2 May 2018 10:45:31 +0800
Subject: [PATCH 2/2] Update sqlnet_condition_predict.py

---
 sqlnet/model/modules/sqlnet_condition_predict.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/sqlnet/model/modules/sqlnet_condition_predict.py b/sqlnet/model/modules/sqlnet_condition_predict.py
index 1a01138..52d5dcd 100644
--- a/sqlnet/model/modules/sqlnet_condition_predict.py
+++ b/sqlnet/model/modules/sqlnet_condition_predict.py
@@ -172,7 +172,7 @@ def forward(self, x_emb_var, x_len, col_inp_var, col_name_len,
 
         cond_col_score = self.cond_col_out(self.cond_col_out_K(
             torch.cat([K_cond_col,temp2*e_cond_col],dim=-1)) +
-                self.cond_col_out_col(e_cond_col)).squeeze() # B,6
+                self.cond_col_out_col(e_cond_col)).squeeze()
         max_col_num = max(col_num)
         for b, num in enumerate(col_num):
             if num < max_col_num:
@@ -181,10 +181,10 @@ def forward(self, x_emb_var, x_len, col_inp_var, col_name_len,
         #Predict the operator of conditions
         chosen_col_gt = []
         if gt_cond is None:
-            cond_nums = np.argmax(cond_num_score.data.cpu().numpy(), axis=1) # B
-            col_scores = cond_col_score.data.cpu().numpy() # B,6
+            cond_nums = np.argmax(cond_num_score.data.cpu().numpy(), axis=1)
+            col_scores = cond_col_score.data.cpu().numpy()
             chosen_col_gt = [list(np.argsort(-col_scores[b])[:cond_nums[b]])
-                    for b in range(len(cond_nums))] # B []
+                    for b in range(len(cond_nums))]
         else:
             chosen_col_gt = [ [x[0] for x in one_gt_cond] for
                     one_gt_cond in gt_cond]
@@ -197,7 +197,7 @@ def forward(self, x_emb_var, x_len, col_inp_var, col_name_len,
                 for x in chosen_col_gt[b]] + [e_cond_col[b, 0]] *
                 (4 - len(chosen_col_gt[b])))  # Pad the columns to maximum (4)
             col_emb.append(cur_col_emb)
-        col_emb = torch.stack(col_emb) # B,4,100
+        col_emb = torch.stack(col_emb)
 
         h_op_enc, _ = run_lstm(self.cond_op_lstm, x_emb_var, x_len)
         if self.use_ca:
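
Note (not part of the patches): the snippet below is a minimal, standalone sketch of the bi-attention step that the hunks above add to the condition-column predictor. The tensor sizes (B=2, 6 columns, question length 7, N_h=100) are made up for illustration, the -100 padding mask on col_att_val is omitted, and F.softmax stands in for the module's self.softmax; only the shapes and the torch.cat fusion mirror the patch, which is also why cond_col_out_K and cond_op_out_K change from nn.Linear(N_h, N_h) to nn.Linear(N_h*2, N_h).

    import torch
    import torch.nn.functional as F

    # Hypothetical sizes for illustration only.
    B, n_cols, q_len, N_h = 2, 6, 7, 100

    col_att_val = torch.randn(B, n_cols, q_len)  # unnormalized column-to-question attention scores
    h_col_enc = torch.randn(B, q_len, N_h)       # question tokens encoded by cond_col_lstm
    e_cond_col = torch.randn(B, n_cols, N_h)     # column names encoded by cond_col_name_enc

    # Direction already in SQLNet: per column, attend over the question tokens.
    col_att = F.softmax(col_att_val, dim=-1)                             # B, n_cols, q_len
    K_cond_col = (h_col_enc.unsqueeze(1) * col_att.unsqueeze(3)).sum(2)  # B, n_cols, N_h

    # Direction added by the patch: max over columns, then one attention pass over the question.
    temp, _ = torch.max(col_att_val, dim=1)            # B, q_len
    temp_probs = F.softmax(temp, dim=-1).unsqueeze(2)  # B, q_len, 1
    temp2 = (temp_probs * h_col_enc).sum(1)            # B, N_h  (pooled question summary)
    temp2 = temp2.unsqueeze(1).expand_as(e_cond_col)   # B, n_cols, N_h

    # The fused feature fed to cond_col_out_K has 2*N_h channels, hence nn.Linear(N_h*2, N_h).
    fused = torch.cat([K_cond_col, temp2 * e_cond_col], dim=-1)  # B, n_cols, 2*N_h
    print(fused.shape)  # torch.Size([2, 6, 200])

The operator branch in the second half of PATCH 1/2 applies the same two steps to op_att_val and h_op_enc, expanding the pooled question vector to col_emb's shape before concatenating it with K_cond_op.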