Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into sft
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmeda14960 committed Oct 28, 2024
2 parents a7459e0 + 331c0aa commit dde75ac
Showing 1 changed file with 28 additions and 14 deletions.
42 changes: 28 additions & 14 deletions src/levanter/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class LlamaConfig(HFCompatConfig):
activation_function: str = "silu"
initializer_range: float = 0.02
layer_norm_epsilon: float = 1e-5
tie_word_embeddings: bool = False

# Attention-related config
upcast_attn: bool = False
Expand Down Expand Up @@ -120,6 +121,7 @@ def from_hf_config(cls, hf_config: HfConfig):
activation_function=hf_config.hidden_act,
initializer_range=hf_config.initializer_range,
layer_norm_epsilon=hf_config.rms_norm_eps,
tie_word_embeddings=hf_config.tie_word_embeddings,
rope=rope_config,
)

Expand Down Expand Up @@ -148,6 +150,7 @@ def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None)
hidden_act=self.activation_function,
initializer_range=self.initializer_range,
rms_norm_eps=self.layer_norm_epsilon,
tie_word_embeddings=self.tie_word_embeddings,
# rope_scaling=self.rope_scaling,
vocab_size=vocab_size,
rope_theta=rope_theta,
Expand Down Expand Up @@ -504,7 +507,7 @@ def resize_embeddings(self, new_size: int, key: Optional[PRNGKeyArray] = None):
class LlamaLMHeadModel(eqx.Module, LmHeadModel[LlamaConfig], StateDictSerializationMixin):
transformer: LlamaTransformer
embeddings: LlamaEmbedding
lm_head: hnn.Linear
lm_head: Optional[hnn.Linear]

@property
def config(self):
Expand All @@ -523,7 +526,11 @@ def init(cls, Vocab: Axis, config: LlamaConfig, *, key) -> "LlamaLMHeadModel":
k_t, k_emb = jrandom.split(key, 2)
transformer = LlamaTransformer.init(config, key=k_t)
embeddings = LlamaEmbedding.init(Vocab, config, key=k_emb)
lm_head = hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_emb, use_bias=False, out_first=True)
if config.tie_word_embeddings:
lm_head = None
else:
lm_head = hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_emb, use_bias=False, out_first=True)

return LlamaLMHeadModel(transformer, embeddings, lm_head)

def __call__(
Expand All @@ -544,38 +551,45 @@ def __call__(
k_t, k_head = maybe_rng_split(key, 2)
x = self.embeddings.embed(input_ids)
x = self.transformer(x, attn_mask=attn_mask, key=k_t)
lm_logits = self.lm_head(x, key=k_head)
if self.lm_head:
lm_logits = self.lm_head(x, key=k_head)
else:
lm_logits = self.embeddings.unembed(x)
return lm_logits

def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[LlamaConfig]":
new_Vocab = self.Vocab.resize(new_size)
k1, k2 = maybe_rng_split(key, 2)
new_embeddings = self.embeddings.resize_embeddings(new_size, key=k1)
new_lm_matrix = hax.tree_util.resize_axis(self.lm_head.weight, self.Vocab, new_size, key=k2)
new_lm_head = dataclasses.replace(self.lm_head, Out=new_Vocab, weight=new_lm_matrix)

return dataclasses.replace(self, embeddings=new_embeddings, lm_head=new_lm_head)
if self.lm_head is not None:
new_lm_matrix = hax.tree_util.resize_axis(self.lm_head.weight, self.Vocab, new_size, key=k2)
new_lm_head = dataclasses.replace(self.lm_head, Out=new_Vocab, weight=new_lm_matrix)
return dataclasses.replace(self, embeddings=new_embeddings, lm_head=new_lm_head)
else:
return dataclasses.replace(self, embeddings=new_embeddings)

def _state_dict_key_map(self) -> Dict[str, Optional[str]]:
return {"transformer": "model", "embeddings": None}

def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None):
# unflatten the linear layers of HF state_dict to match the shape of LlamaMlp
d = state_dict.copy()
d.update(
unflatten_linear_layers(
apply_prefix(prefix, "lm_head"), state_dict, self.lm_head, out_dims_first_in_dict=True
if self.lm_head is not None:
d.update(
unflatten_linear_layers(
apply_prefix(prefix, "lm_head"), state_dict, self.lm_head, out_dims_first_in_dict=True
)
)
)
return super().from_state_dict(d, prefix)

def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict:
my_dict: StateDict = {}
super().update_state_dict(my_dict, prefix=prefix)

my_dict.update(
flatten_linear_layers(apply_prefix(prefix, "lm_head"), self.lm_head, out_dims_first_in_dict=True)
)
if self.lm_head is not None:
my_dict.update(
flatten_linear_layers(apply_prefix(prefix, "lm_head"), self.lm_head, out_dims_first_in_dict=True)
)

state_dict.update(my_dict)
return state_dict

0 comments on commit dde75ac

Please sign in to comment.