diff --git a/src/levanter/models/diva.py b/src/levanter/models/diva.py index 49bea6589..323a82ce2 100644 --- a/src/levanter/models/diva.py +++ b/src/levanter/models/diva.py @@ -119,6 +119,7 @@ class DivaConfig(HFCompatConfig, ASRConfig): ) prefix = property(lambda self: hax.named(self.pre_audio_prompt, axis="position")) suffix = property(lambda self: hax.named(self.pre_text_prompt, axis="position")) + Embed = property(lambda self: self.dec_config.Embed) Pos = property(lambda self: Axis(name="position", size=self.max_length)) AudioPos = property(lambda self: self.enc_config.AudioPos) KeyPos = property(lambda self: self.Pos.alias("key_position")) @@ -204,12 +205,14 @@ def init( ) if init_from_submodels: - llm: Union[LlamaLMHeadModel | MistralLMHeadModel | GemmaLMHeadModel] = HFCheckpointConverter( - type(config.dec_config), config.reference_decoder - ).load_pretrained( - config.dec_config.model_type, - config.reference_decoder, - config.dec_config, + llm: Union[LlamaLMHeadModel | MistralLMHeadModel | GemmaLMHeadModel] = ( + HFCheckpointConverter(type(config.dec_config), config.reference_decoder) + .load_pretrained( + config.dec_config.model_type, + config.reference_decoder, + config.dec_config, + ) + .resize_vocab(Vocab.size) ) # type: ignore[assignment] whisper: WhisperModel = HFCheckpointConverter( WhisperConfig, config.reference_encoder, ignore_prefix="model" @@ -219,8 +222,7 @@ def init( config.enc_config, ) # type: ignore[assignment] encoder = whisper.encoder - # connector = whisper.decoder - connector = WhisperDecoder.init(config.enc_config, key=k_connector) + connector = whisper.decoder decoder = llm mean_embedding = hax.mean(llm.embeddings.token_embeddings.weight, llm.embeddings.Vocab) projection = dataclasses.replace( diff --git a/src/levanter/models/qwen.py b/src/levanter/models/qwen.py index 807a768ad..604d270e4 100644 --- a/src/levanter/models/qwen.py +++ b/src/levanter/models/qwen.py @@ -15,7 +15,14 @@ from levanter.compat.hf_checkpoints import HFCheckpointConverter from levanter.logging import silence_transformer_nag from levanter.models.attention import AttentionMask, dot_product_attention -from levanter.models.llama import LlamaConfig, LlamaEmbedding, LlamaMlp, LlamaRMSNorm, LlamaTransformer +from levanter.models.llama import ( + LlamaConfig, + LlamaEmbedding, + LlamaMlp, + LlamaRMSNorm, + LlamaTransformer, + LlamaLMHeadModel, +) from levanter.models.lm_model import LmConfig, LmHeadModel from levanter.models.rotary import RotaryEmbeddingsConfig from levanter.types import BlockFoldable @@ -264,7 +271,7 @@ def init(config: QwenConfig, *, key) -> "QwenTransformer": # Modified LM head model for Qwen -class QwenLMHeadModel(LmHeadModel[QwenConfig], ModuleWithStateDictSerialization): +class QwenLMHeadModel(LlamaLMHeadModel, ModuleWithStateDictSerialization): transformer: QwenTransformer embeddings: LlamaEmbedding # Can reuse Llama embeddings lm_head: Optional[hnn.Linear] diff --git a/src/levanter/models/whisper.py b/src/levanter/models/whisper.py index ad768fbe1..a3d1ff5c0 100644 --- a/src/levanter/models/whisper.py +++ b/src/levanter/models/whisper.py @@ -131,8 +131,8 @@ class WhisperMlp(eqx.Module): @staticmethod def init(Embed: Axis, Mlp: Axis, activation_fn, *, key, use_bias: bool = True) -> "WhisperMlp": k_fc, k_proj = haliax.jax_utils.maybe_rng_split(key, 2) - fc1 = hnn.Linear.init(Out=Mlp, In=Embed, key=k_fc, use_bias=use_bias, out_first=False) - fc2 = hnn.Linear.init(Out=Embed, In=Mlp, key=k_proj, use_bias=use_bias, out_first=False) + fc1 = hnn.Linear.init(Out=Mlp, In=Embed, key=k_fc, use_bias=use_bias) + fc2 = hnn.Linear.init(Out=Embed, In=Mlp, key=k_proj, use_bias=use_bias) if isinstance(activation_fn, str): activation_fn = ACT2FN[activation_fn] act = activation_fn # type: ignore @@ -164,10 +164,30 @@ def init(Heads: Axis, HeadSize: Axis, config: WhisperConfig, *, key) -> "Whisper Embed = config.Embed k_q, k_k, k_v, k_out = haliax.jax_utils.maybe_rng_split(key, 4) - q_proj = hnn.Linear.init(In=Embed, Out=(Heads, HeadSize), key=k_q, use_bias=use_bias, out_first=False) - k_proj = hnn.Linear.init(In=Embed, Out=(Heads, HeadSize), key=k_k, use_bias=False, out_first=False) - v_proj = hnn.Linear.init(In=Embed, Out=(Heads, HeadSize), key=k_v, use_bias=use_bias, out_first=False) - out_proj = hnn.Linear.init(In=(Heads, HeadSize), Out=Embed, key=k_out, use_bias=use_bias, out_first=False) + q_proj = hnn.Linear.init( + In=Embed, + Out=(Heads, HeadSize), + key=k_q, + use_bias=use_bias, + ) + k_proj = hnn.Linear.init( + In=Embed, + Out=(Heads, HeadSize), + key=k_k, + use_bias=False, + ) + v_proj = hnn.Linear.init( + In=Embed, + Out=(Heads, HeadSize), + key=k_v, + use_bias=use_bias, + ) + out_proj = hnn.Linear.init( + In=(Heads, HeadSize), + Out=Embed, + key=k_out, + use_bias=use_bias, + ) return WhisperAttention(config, q_proj, k_proj, v_proj, out_proj, inference=False)