diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py
index e777b7636..1e09ffbc5 100644
--- a/src/levanter/models/llama.py
+++ b/src/levanter/models/llama.py
@@ -64,6 +64,7 @@ class LlamaConfig(HFCompatConfig):
     activation_function: str = "silu"
     initializer_range: float = 0.02
     layer_norm_epsilon: float = 1e-5
+    tie_word_embeddings: bool = False
 
     # Attention-related config
     upcast_attn: bool = False
@@ -120,6 +121,7 @@ def from_hf_config(cls, hf_config: HfConfig):
             activation_function=hf_config.hidden_act,
             initializer_range=hf_config.initializer_range,
             layer_norm_epsilon=hf_config.rms_norm_eps,
+            tie_word_embeddings=hf_config.tie_word_embeddings,
             rope=rope_config,
         )
 
@@ -148,6 +150,7 @@ def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None)
             hidden_act=self.activation_function,
             initializer_range=self.initializer_range,
             rms_norm_eps=self.layer_norm_epsilon,
+            tie_word_embeddings=self.tie_word_embeddings,
             # rope_scaling=self.rope_scaling,
             vocab_size=vocab_size,
             rope_theta=rope_theta,
@@ -504,7 +507,7 @@ def resize_embeddings(self, new_size: int, key: Optional[PRNGKeyArray] = None):
 class LlamaLMHeadModel(eqx.Module, LmHeadModel[LlamaConfig], StateDictSerializationMixin):
     transformer: LlamaTransformer
     embeddings: LlamaEmbedding
-    lm_head: hnn.Linear
+    lm_head: Optional[hnn.Linear]
 
     @property
     def config(self):
@@ -523,7 +526,11 @@ def init(cls, Vocab: Axis, config: LlamaConfig, *, key) -> "LlamaLMHeadModel":
         k_t, k_emb = jrandom.split(key, 2)
         transformer = LlamaTransformer.init(config, key=k_t)
         embeddings = LlamaEmbedding.init(Vocab, config, key=k_emb)
-        lm_head = hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_emb, use_bias=False, out_first=True)
+        if config.tie_word_embeddings:
+            lm_head = None
+        else:
+            lm_head = hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_emb, use_bias=False, out_first=True)
+
         return LlamaLMHeadModel(transformer, embeddings, lm_head)
 
     def __call__(
@@ -544,17 +551,22 @@ def __call__(
         k_t, k_head = maybe_rng_split(key, 2)
         x = self.embeddings.embed(input_ids)
         x = self.transformer(x, attn_mask=attn_mask, key=k_t)
-        lm_logits = self.lm_head(x, key=k_head)
+        if self.lm_head:
+            lm_logits = self.lm_head(x, key=k_head)
+        else:
+            lm_logits = self.embeddings.unembed(x)
         return lm_logits
 
     def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[LlamaConfig]":
         new_Vocab = self.Vocab.resize(new_size)
         k1, k2 = maybe_rng_split(key, 2)
         new_embeddings = self.embeddings.resize_embeddings(new_size, key=k1)
-        new_lm_matrix = hax.tree_util.resize_axis(self.lm_head.weight, self.Vocab, new_size, key=k2)
-        new_lm_head = dataclasses.replace(self.lm_head, Out=new_Vocab, weight=new_lm_matrix)
-
-        return dataclasses.replace(self, embeddings=new_embeddings, lm_head=new_lm_head)
+        if self.lm_head is not None:
+            new_lm_matrix = hax.tree_util.resize_axis(self.lm_head.weight, self.Vocab, new_size, key=k2)
+            new_lm_head = dataclasses.replace(self.lm_head, Out=new_Vocab, weight=new_lm_matrix)
+            return dataclasses.replace(self, embeddings=new_embeddings, lm_head=new_lm_head)
+        else:
+            return dataclasses.replace(self, embeddings=new_embeddings)
 
     def _state_dict_key_map(self) -> Dict[str, Optional[str]]:
         return {"transformer": "model", "embeddings": None}
@@ -562,20 +574,22 @@ def _state_dict_key_map(self) -> Dict[str, Optional[str]]:
     def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None):
         # unflatten the linear layers of HF state_dict to match the shape of LlamaMlp
         d = state_dict.copy()
-        d.update(
-            unflatten_linear_layers(
-                apply_prefix(prefix, "lm_head"), state_dict, self.lm_head, out_dims_first_in_dict=True
+        if self.lm_head is not None:
+            d.update(
+                unflatten_linear_layers(
+                    apply_prefix(prefix, "lm_head"), state_dict, self.lm_head, out_dims_first_in_dict=True
+                )
             )
-        )
         return super().from_state_dict(d, prefix)
 
     def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict:
         my_dict: StateDict = {}
         super().update_state_dict(my_dict, prefix=prefix)
-        my_dict.update(
-            flatten_linear_layers(apply_prefix(prefix, "lm_head"), self.lm_head, out_dims_first_in_dict=True)
-        )
+        if self.lm_head is not None:
+            my_dict.update(
+                flatten_linear_layers(apply_prefix(prefix, "lm_head"), self.lm_head, out_dims_first_in_dict=True)
+            )
         state_dict.update(my_dict)
         return state_dict