Skip to content

Commit

Permalink
Rename layers
Browse files (browse the repository at this point in the history)
  • Loading branch information
luozhouyang committed Nov 23, 2021
1 parent 16b0d43 commit 5910762
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions transformers_keras/question_answering/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def __init__(

sequence_output, _, _, _ = bert_model(input_ids, segment_ids, attention_mask)
logits = tf.keras.layers.Dense(num_labels, name="dense")(sequence_output)
head_logits = tf.keras.layers.Lambda(lambda x: x[:, :, 0], name="head")(logits)
tail_logits = tf.keras.layers.Lambda(lambda x: x[:, :, 1], name="tail")(logits)
super().__init__(inputs=[input_ids, segment_ids, attention_mask], outputs=[head_logits, tail_logits], **kwargs)
start_logits = tf.keras.layers.Lambda(lambda x: x[:, :, 0], name="start")(logits)
end_logits = tf.keras.layers.Lambda(lambda x: x[:, :, 1], name="end")(logits)
super().__init__(inputs=[input_ids, segment_ids, attention_mask], outputs=[start_logits, end_logits], **kwargs)


class AlbertForQuestionAnswering(AlbertPretrainedModel):
Expand Down Expand Up @@ -98,9 +98,9 @@ def __init__(

sequence_output, _, _, _ = albert_model(input_ids, segment_ids, attention_mask)
logits = tf.keras.layers.Dense(num_labels, name="dense")(sequence_output)
head_logits = tf.keras.layers.Lambda(lambda x: x[:, :, 0], name="head")(logits)
tail_logits = tf.keras.layers.Lambda(lambda x: x[:, :, 1], name="tail")(logits)
super().__init__(inputs=[input_ids, segment_ids, attention_mask], outputs=[head_logits, tail_logits], **kwargs)
start_logits = tf.keras.layers.Lambda(lambda x: x[:, :, 0], name="start")(logits)
end_logits = tf.keras.layers.Lambda(lambda x: x[:, :, 1], name="end")(logits)
super().__init__(inputs=[input_ids, segment_ids, attention_mask], outputs=[start_logits, end_logits], **kwargs)


class BertForQuestionAnsweringX(BertPretrainedModel):
Expand Down Expand Up @@ -146,12 +146,12 @@ def __init__(

sequence_output, pooled_output, _, _ = bert_model(input_ids, segment_ids, attention_mask)
logits = tf.keras.layers.Dense(num_labels, name="dense")(sequence_output)
head_logits = tf.keras.layers.Lambda(lambda x: x[:, :, 0], name="head")(logits)
tail_logits = tf.keras.layers.Lambda(lambda x: x[:, :, 1], name="tail")(logits)
start_logits = tf.keras.layers.Lambda(lambda x: x[:, :, 0], name="start")(logits)
end_logits = tf.keras.layers.Lambda(lambda x: x[:, :, 1], name="end")(logits)
class_logits = tf.keras.layers.Dense(num_classes, name="class")(pooled_output)

super().__init__(
inputs=[input_ids, segment_ids, attention_mask], outputs=[head_logits, tail_logits, class_logits], **kwargs
inputs=[input_ids, segment_ids, attention_mask], outputs=[start_logits, end_logits, class_logits], **kwargs
)


Expand Down Expand Up @@ -204,9 +204,9 @@ def __init__(

sequence_output, pooled_output, _, _ = albert_model(input_ids, segment_ids, attention_mask)
logits = tf.keras.layers.Dense(num_labels, name="dense")(sequence_output)
head_logits = tf.keras.layers.Lambda(lambda x: x[:, :, 0], name="head")(logits)
tail_logits = tf.keras.layers.Lambda(lambda x: x[:, :, 1], name="tail")(logits)
start_logits = tf.keras.layers.Lambda(lambda x: x[:, :, 0], name="start")(logits)
end_logits = tf.keras.layers.Lambda(lambda x: x[:, :, 1], name="end")(logits)
class_logits = tf.keras.layers.Dense(num_classes, name="class")(pooled_output)
super().__init__(
inputs=[input_ids, segment_ids, attention_mask], outputs=[head_logits, tail_logits, class_logits], **kwargs
inputs=[input_ids, segment_ids, attention_mask], outputs=[start_logits, end_logits, class_logits], **kwargs
)

0 comments on commit 5910762

Please sign in to comment.