From 62d92dd0076b8f874e5c1d71a6867466a2c2e539 Mon Sep 17 00:00:00 2001
From: Helw150
Date: Thu, 24 Oct 2024 20:15:38 -0400
Subject: [PATCH 1/2] Support Tied Weights in Llama Models

---
 src/levanter/models/llama.py | 38 +++++++++++++++++++++++++++-----------
 1 file changed, 27 insertions(+), 11 deletions(-)

diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py
index e777b7636..8bc25c7dd 100644
--- a/src/levanter/models/llama.py
+++ b/src/levanter/models/llama.py
@@ -64,6 +64,7 @@ class LlamaConfig(HFCompatConfig):
     activation_function: str = "silu"
     initializer_range: float = 0.02
     layer_norm_epsilon: float = 1e-5
+    tie_word_embeddings: bool = False
 
     # Attention-related config
     upcast_attn: bool = False
@@ -120,6 +121,7 @@ def from_hf_config(cls, hf_config: HfConfig):
             activation_function=hf_config.hidden_act,
             initializer_range=hf_config.initializer_range,
             layer_norm_epsilon=hf_config.rms_norm_eps,
+            tie_word_embeddings=hf_config.tie_word_embeddings,
             rope=rope_config,
         )
 
@@ -148,6 +150,7 @@ def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None)
             hidden_act=self.activation_function,
             initializer_range=self.initializer_range,
             rms_norm_eps=self.layer_norm_epsilon,
+            tie_word_embeddings=self.tie_word_embeddings,
             # rope_scaling=self.rope_scaling,
             vocab_size=vocab_size,
             rope_theta=rope_theta,
@@ -504,7 +507,7 @@ def resize_embeddings(self, new_size: int, key: Optional[PRNGKeyArray] = None):
 class LlamaLMHeadModel(eqx.Module, LmHeadModel[LlamaConfig], StateDictSerializationMixin):
     transformer: LlamaTransformer
     embeddings: LlamaEmbedding
-    lm_head: hnn.Linear
+    lm_head: Optional[hnn.Linear]
 
     @property
     def config(self):
@@ -523,7 +526,11 @@ def init(cls, Vocab: Axis, config: LlamaConfig, *, key) -> "LlamaLMHeadModel":
         k_t, k_emb = jrandom.split(key, 2)
         transformer = LlamaTransformer.init(config, key=k_t)
         embeddings = LlamaEmbedding.init(Vocab, config, key=k_emb)
-        lm_head = hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_emb, use_bias=False, out_first=True)
+        if config.tie_word_embeddings:
+            lm_head = None
+        else:
+            lm_head = hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_emb, use_bias=False, out_first=True)
+
         return LlamaLMHeadModel(transformer, embeddings, lm_head)
 
     def __call__(
@@ -544,7 +551,10 @@ def __call__(
         k_t, k_head = maybe_rng_split(key, 2)
         x = self.embeddings.embed(input_ids)
         x = self.transformer(x, attn_mask=attn_mask, key=k_t)
-        lm_logits = self.lm_head(x, key=k_head)
+        if tie_word_embeddings:
+            lm_logits = self.embeddings.unembed(x)
+        else:
+            lm_logits = self.lm_head(x, key=k_head)
         return lm_logits
 
     def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[LlamaConfig]":
@@ -552,7 +562,11 @@ def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[LlamaConfig]":
         k1, k2 = maybe_rng_split(key, 2)
         new_embeddings = self.embeddings.resize_embeddings(new_size, key=k1)
         new_lm_matrix = hax.tree_util.resize_axis(self.lm_head.weight, self.Vocab, new_size, key=k2)
-        new_lm_head = dataclasses.replace(self.lm_head, Out=new_Vocab, weight=new_lm_matrix)
+        new_lm_head = (
+            dataclasses.replace(self.lm_head, Out=new_Vocab, weight=new_lm_matrix)
+            if not config.tie_word_embeddings
+            else None
+        )
 
         return dataclasses.replace(self, embeddings=new_embeddings, lm_head=new_lm_head)
 
@@ -562,20 +576,22 @@ def _state_dict_key_map(self) -> Dict[str, Optional[str]]:
     def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None):
         # unflatten the linear layers of HF state_dict to match the shape of LlamaMlp
         d = state_dict.copy()
-        d.update(
-            unflatten_linear_layers(
-                apply_prefix(prefix, "lm_head"), state_dict, self.lm_head, out_dims_first_in_dict=True
+        if not config.tie_word_embeddings:
+            d.update(
+                unflatten_linear_layers(
+                    apply_prefix(prefix, "lm_head"), state_dict, self.lm_head, out_dims_first_in_dict=True
+                )
             )
-        )
         return super().from_state_dict(d, prefix)
 
     def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict:
         my_dict: StateDict = {}
         super().update_state_dict(my_dict, prefix=prefix)
 
-        my_dict.update(
-            flatten_linear_layers(apply_prefix(prefix, "lm_head"), self.lm_head, out_dims_first_in_dict=True)
-        )
+        if not config.tie_word_embeddings:
+            my_dict.update(
+                flatten_linear_layers(apply_prefix(prefix, "lm_head"), self.lm_head, out_dims_first_in_dict=True)
+            )
 
         state_dict.update(my_dict)
         return state_dict

From d88f1f2b4bf42511cd7bfc7aab06793a04a681f3 Mon Sep 17 00:00:00 2001
From: Helw150
Date: Thu, 24 Oct 2024 20:24:11 -0400
Subject: [PATCH 2/2] Fix Pre-Commit

---
 src/levanter/models/llama.py | 24 +++++++++++-------------
 1 file changed, 11 insertions(+), 13 deletions(-)

diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py
index 8bc25c7dd..1e09ffbc5 100644
--- a/src/levanter/models/llama.py
+++ b/src/levanter/models/llama.py
@@ -551,24 +551,22 @@ def __call__(
         k_t, k_head = maybe_rng_split(key, 2)
         x = self.embeddings.embed(input_ids)
         x = self.transformer(x, attn_mask=attn_mask, key=k_t)
-        if tie_word_embeddings:
-            lm_logits = self.embeddings.unembed(x)
-        else:
+        if self.lm_head:
             lm_logits = self.lm_head(x, key=k_head)
+        else:
+            lm_logits = self.embeddings.unembed(x)
         return lm_logits
 
     def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[LlamaConfig]":
         new_Vocab = self.Vocab.resize(new_size)
         k1, k2 = maybe_rng_split(key, 2)
         new_embeddings = self.embeddings.resize_embeddings(new_size, key=k1)
-        new_lm_matrix = hax.tree_util.resize_axis(self.lm_head.weight, self.Vocab, new_size, key=k2)
-        new_lm_head = (
-            dataclasses.replace(self.lm_head, Out=new_Vocab, weight=new_lm_matrix)
-            if not config.tie_word_embeddings
-            else None
-        )
-
-        return dataclasses.replace(self, embeddings=new_embeddings, lm_head=new_lm_head)
+        if self.lm_head is not None:
+            new_lm_matrix = hax.tree_util.resize_axis(self.lm_head.weight, self.Vocab, new_size, key=k2)
+            new_lm_head = dataclasses.replace(self.lm_head, Out=new_Vocab, weight=new_lm_matrix)
+            return dataclasses.replace(self, embeddings=new_embeddings, lm_head=new_lm_head)
+        else:
+            return dataclasses.replace(self, embeddings=new_embeddings)
 
     def _state_dict_key_map(self) -> Dict[str, Optional[str]]:
         return {"transformer": "model", "embeddings": None}
@@ -576,7 +574,7 @@ def _state_dict_key_map(self) -> Dict[str, Optional[str]]:
     def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None):
         # unflatten the linear layers of HF state_dict to match the shape of LlamaMlp
         d = state_dict.copy()
-        if not config.tie_word_embeddings:
+        if self.lm_head is not None:
             d.update(
                 unflatten_linear_layers(
                     apply_prefix(prefix, "lm_head"), state_dict, self.lm_head, out_dims_first_in_dict=True
@@ -588,7 +586,7 @@ def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None)
         my_dict: StateDict = {}
         super().update_state_dict(my_dict, prefix=prefix)
 
-        if not config.tie_word_embeddings:
+        if self.lm_head is not None:
             my_dict.update(
                 flatten_linear_layers(apply_prefix(prefix, "lm_head"), self.lm_head, out_dims_first_in_dict=True)
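
Note on the change (commentary, not part of the patches): when `tie_word_embeddings` is set, the model drops the separate `lm_head` projection and reuses the token-embedding matrix to produce logits, and the state-dict hooks skip the `lm_head` tensor so tied HF checkpoints round-trip without a duplicate weight. The sketch below illustrates that tied-versus-untied head logic in plain JAX; it is not the Levanter/Haliax API, and the function name, shapes, and sizes are illustrative assumptions.

```python
import jax


def lm_logits(x, embedding_matrix, lm_head_weight=None, tie_word_embeddings=False):
    """Hypothetical helper: compute logits with a tied or untied output head.

    x:                [seq, embed] final hidden states
    embedding_matrix: [vocab, embed] token embeddings
    lm_head_weight:   [vocab, embed] separate output projection (untied case only)
    """
    if tie_word_embeddings:
        # Tied: unembed with the transposed embedding matrix,
        # analogous to `self.embeddings.unembed(x)` in the patch.
        return x @ embedding_matrix.T
    # Untied: independent projection, analogous to `self.lm_head(x, key=...)`.
    return x @ lm_head_weight.T


k_emb, k_x = jax.random.split(jax.random.PRNGKey(0))
E = jax.random.normal(k_emb, (1000, 64))  # toy vocab=1000, embed=64
x = jax.random.normal(k_x, (8, 64))       # toy activations for 8 positions
tied_logits = lm_logits(x, E, tie_word_embeddings=True)    # shape (8, 1000)
untied_logits = lm_logits(x, E, lm_head_weight=E * 0.5)    # shape (8, 1000)
```

Because the tied case carries no `lm_head` parameters at all, checking `self.lm_head is not None` (as the second commit does) is a more direct guard than re-reading the config inside methods that do not receive it.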