Skip to content

Commit

Permalink
Create embed_tokens only on the pipeline head and norm only on the pipeline tail; remove redundant logits.float() cast
Browse files Browse the repository at this point in the history
  • Loading branch information
hzjane committed Jun 6, 2024
1 parent 09c6780 commit c12942f
Showing 1 changed file with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ def __init__(self, config: LlamaConfig):

self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if self.pp_config.is_head:
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
if self.pp_config.is_tail:
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)


def get_input_embeddings(self):
Expand Down Expand Up @@ -259,7 +261,6 @@ def forward(
if self.pp_config.is_tail:
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()

loss = None
if labels is not None:
Expand Down

0 comments on commit c12942f

Please sign in to comment.