You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
I am trying to implement column attention and here is my code.. it would be great if you could review it and tell me if its not correct. Thanks
I am implementing this to find column for select :
n_h = 128 # number of hidden units
question_input = Input(shape=(max_len,),name='Question_input')
column_input = Input(shape=(max_len,),name='Column_input')
HI, Has anyone tried sqlnet keras implementation?
I am trying to implement column attention and here is my code.. it would be great if you could review it and tell me if its not correct. Thanks
I am implementing this to find column for select :
n_h = 128 # number of hidden units
question_input = Input(shape=(max_len,),name='Question_input')
column_input = Input(shape=(max_len,),name='Column_input')
embedding= Embedding(max_token_index, n_h, input_length=max_len,name='embedding')
Q_embedding= embedding(question_input)
C_embedding= embedding(column_input)
encoder_question = Bidirectional(LSTM(n_h, return_state=True, return_sequences=True))
Q_enc , Q_state_h1, Q_state_h2 = encoder_question(Q_embedding)
encoder_column = Bidirectional(LSTM(n_h, return_state=True, return_sequences=True))
C_enc , C_state_h1, C_state_h2 = encoder_column(C_embedding)
########## Column Attention Code ########
Q_num_att = Dense(max_len,activation='relu')(Q_enc)
Q_self = Dense(max_len,activation='relu')(Q_num_att)
att_val_qc_num = Concatenate()([Q_self,C_enc])
att_prob_qc_num = Dense(maxlen,activation='softmax')(att_val_qc_num)
q_weighted_num = (Q_enc * att_prob_qc_num).sum(axis=0, keepdims=True)
########## Column Attention Code ############
col_num_out_q = Dense(max_len,activation='relu')(q_weighted_num)
col_num_out = Dense(max_len,activation='tanh')(col_num_out_q)
#con=Concatenate()([Q_state_h1,Q_state_h2,C_state_h1,C_state_h2])
final=Dense(6,activation='softmax')(col_num_out)
model = Model([question_input, column_input], final)
model.summary()
Please correct me if i am wrong.
The text was updated successfully, but these errors were encountered: