diff --git a/src/levanter/models/qwen.py b/src/levanter/models/qwen.py
new file mode 100644
index 000000000..807a768ad
--- /dev/null
+++ b/src/levanter/models/qwen.py
@@ -0,0 +1,282 @@
+from dataclasses import dataclass
+from typing import Dict, Optional, Type
+
+import equinox as eqx
+import jax.numpy as jnp
+import jax.random as jrandom
+
+import haliax as hax
+import haliax.nn as hnn
+from haliax import Axis, NamedArray
+from haliax.jax_utils import maybe_rng_split, named_call, shaped_rng_split
+from haliax.nn.scan import Stacked
+from haliax.state_dict import ModuleWithStateDictSerialization
+
+from levanter.compat.hf_checkpoints import HFCheckpointConverter
+from levanter.logging import silence_transformer_nag
+from levanter.models.attention import AttentionMask, dot_product_attention
+from levanter.models.llama import LlamaConfig, LlamaEmbedding, LlamaMlp, LlamaRMSNorm, LlamaTransformer
+from levanter.models.lm_model import LmConfig, LmHeadModel
+from levanter.models.rotary import RotaryEmbeddingsConfig
+from levanter.types import BlockFoldable
+from levanter.utils.flop_utils import lm_flops_per_token
+
+
+silence_transformer_nag()
+from transformers import PretrainedConfig as HfConfig  # noqa: E402
+from transformers import Qwen2Config as HfQwenConfig  # noqa: E402
+
+
+@LmConfig.register_subclass("qwen")
+@dataclass(frozen=True)
+class QwenConfig(LlamaConfig):
+    """Extends LlamaConfig with Qwen-specific features."""
+
+    use_sliding_window: bool = False
+    sliding_window: Optional[int] = None
+    max_window_layers: int = 0  # sliding-window attention applies only from this layer index onward
+
+    def __post_init__(self):
+        assert (
+            self.num_heads % self.num_kv_heads == 0
+        ), f"num_heads={self.num_heads} not divisible by num_kv_heads={self.num_kv_heads}."
+
+    def hf_checkpoint_converter(self) -> HFCheckpointConverter["QwenConfig"]:  # type: ignore
+        return HFCheckpointConverter(
+            self.__class__,
+            reference_checkpoint=self.reference_checkpoint,
+            trust_remote_code=True,
+            tokenizer=self.tokenizer if self.tokenizer else self.reference_checkpoint,
+            HfConfigClass=HfQwenConfig,
+        )
+
+    @classmethod
+    def from_hf_config(cls, hf_config: HfConfig):
+        rope_theta = hf_config.rope_theta
+        rope_config = RotaryEmbeddingsConfig.from_hf_config(rope_theta, hf_config.rope_scaling)
+        return QwenConfig(
+            seq_len=hf_config.max_position_embeddings,
+            hidden_dim=hf_config.hidden_size,
+            intermediate_dim=hf_config.intermediate_size,
+            num_layers=hf_config.num_hidden_layers,
+            num_heads=hf_config.num_attention_heads,
+            num_kv_heads=hf_config.num_key_value_heads,
+            use_sliding_window=getattr(hf_config, "use_sliding_window", False),
+            sliding_window=getattr(hf_config, "sliding_window", None),
+            max_window_layers=getattr(hf_config, "max_window_layers", 0),
+            activation_function=hf_config.hidden_act,
+            initializer_range=hf_config.initializer_range,
+            layer_norm_epsilon=hf_config.rms_norm_eps,
+            tie_word_embeddings=hf_config.tie_word_embeddings,
+            rope=rope_config,
+        )
+
+    def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None) -> HfQwenConfig:
+        if config_overrides is None:
+            config_overrides = {}
+
+        rope_theta, rope_scaling = self.rope.to_hf_config()
+
+        return HfQwenConfig(
+            max_position_embeddings=self.seq_len,
+            hidden_size=self.hidden_dim,
+            intermediate_size=self.intermediate_dim,
+            num_hidden_layers=self.num_layers,
+            num_attention_heads=self.num_heads,
+            num_key_value_heads=self.num_kv_heads,
+            use_sliding_window=self.use_sliding_window,
+            sliding_window=self.sliding_window,
+            max_window_layers=self.max_window_layers,
+            hidden_act=self.activation_function,
+            initializer_range=self.initializer_range,
+            rms_norm_eps=self.layer_norm_epsilon,
+            tie_word_embeddings=self.tie_word_embeddings,
+            vocab_size=vocab_size,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            **config_overrides,
+        )
+
+    @property
+    def model_type(self) -> Type["QwenLMHeadModel"]:
+        return QwenLMHeadModel
+
+    def flops_per_token(self, vocab_size: int):
+        return lm_flops_per_token(
+            hidden_dim=self.hidden_dim,
+            intermediate_dim=self.intermediate_dim,
+            num_layers=self.num_layers,
+            num_kv_heads=self.num_kv_heads,
+            num_heads=self.num_heads,
+            seq_len=self.seq_len,
+            vocab_size=vocab_size,
+            glu=True,
+        )
+
+
+# Modified attention class for Qwen
+class QwenAttention(eqx.Module):
+    config: QwenConfig = eqx.static_field()
+    q_proj: hnn.Linear
+    k_proj: hnn.Linear
+    v_proj: hnn.Linear
+    o_proj: hnn.Linear
+
+    @staticmethod
+    def init(config: QwenConfig, *, key) -> "QwenAttention":
+        Embed = config.Embed
+        QHeadsPerGroup = hax.Axis("q_heads_per_group", config.num_heads // config.num_kv_heads)
+
+        k_q, k_k, k_v, k_o = jrandom.split(key, 4)
+        q_proj = hnn.Linear.init(
+            In=Embed,
+            Out=(config.KVHeads, QHeadsPerGroup, config.HeadSize),
+            key=k_q,
+            use_bias=True,  # Qwen always uses bias in attention
+            out_first=True,
+        )
+        k_proj = hnn.Linear.init(
+            In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_k, use_bias=True, out_first=True
+        )
+        v_proj = hnn.Linear.init(
+            In=Embed, Out=(config.KVHeads, config.HeadSize), key=k_v, use_bias=True, out_first=True
+        )
+        o_proj = hnn.Linear.init(
+            In=(config.Heads, config.HeadSize),
+            Out=Embed,
+            key=k_o,
+            use_bias=False,  # Qwen doesn't use bias in o_proj
+            out_first=True,
+        )
+        return QwenAttention(config, q_proj, k_proj, v_proj, o_proj)
+
+    @named_call
+    def __call__(
+        self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], layer_idx: int = 0, *, key=None
+    ) -> NamedArray:
+        key_q, key_k, key_v, key_o = maybe_rng_split(key, 4)
+
+        # QKV projections
+        q = self.q_proj(x, key=key_q).rearrange((..., "kv_heads", "q_heads_per_group", "position", "head_size"))
+        k = self.k_proj(x, key=key_k).rearrange((..., "kv_heads", "position", "head_size"))
+        v = self.v_proj(x, key=key_v).rearrange((..., "kv_heads", "position", "head_size"))
+
+        # Apply rotary embeddings
+        rot_embs = self.config.rope.build(self.config.HeadSize, q.resolve_axis("position"))
+        q, k = rot_embs(self.config.HeadSize, q, k)
+
+        k = k.rename({"position": "key_position"})
+        v = v.rename({"position": "key_position"})
+
+        # Sliding-window attention would apply here (configured and at/past max_window_layers),
+        # but it is not implemented yet, so fail loudly rather than silently falling back.
+        if (
+            self.config.use_sliding_window
+            and self.config.sliding_window is not None
+            and layer_idx >= self.config.max_window_layers
+        ):
+            raise ValueError("Sliding Window Attention is not currently supported.")
+
+        # Perform attention
+        attn_output = dot_product_attention(
+            "position",
+            "key_position",
+            "head_size",
+            q,
+            k,
+            v,
+            mask,
+            attention_dtype=jnp.float32 if self.config.upcast_attn else x.dtype,
+            use_flash=self.config.use_flash_attention,
+            attn_backend=self.config.attn_backend,
+            flash_block_size=self.config.flash_attention_block_size,
+        )
+
+        attn_output = attn_output.flatten_axes(("kv_heads", "q_heads_per_group"), "heads")
+        attn_output = attn_output.astype(x.dtype)
+
+        attn_output = self.o_proj(attn_output, key=key_o)
+        return attn_output
+
+
+# Modified decoder layer for Qwen
+class QwenDecoderLayer(eqx.Module):
+    config: QwenConfig = eqx.static_field()
+    self_attn: QwenAttention
+    mlp: LlamaMlp  # Can reuse Llama MLP as structure is similar
+    input_layernorm: LlamaRMSNorm
+    post_attention_layernorm: LlamaRMSNorm
+
+    @staticmethod
+    def init(config: QwenConfig, *, key) -> "QwenDecoderLayer":
+        k_attn, k_mlp = jrandom.split(key, 2)
+
+        attn = QwenAttention.init(config, key=k_attn)
+        mlp = LlamaMlp.init(
+            config.Embed,
+            config.Mlp,
+            config.activation_function,
+            key=k_mlp,
+            use_bias=config.use_bias,
+        )
+        ln_1 = config.mk_LayerNorm(config.Embed)
+        ln_2 = config.mk_LayerNorm(config.Embed)
+
+        return QwenDecoderLayer(config, attn, mlp, ln_1, ln_2)
+
+    @named_call
+    def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *, key=None) -> NamedArray:
+        k_attn, k_mlp = maybe_rng_split(key, 2)
+
+        residual = x
+        x = self.input_layernorm(x)
+        attn_output = self.self_attn(x=x, mask=mask, key=k_attn)
+        x = residual + attn_output
+
+        residual = x
+        x = self.post_attention_layernorm(x)
+        mlp_output = self.mlp(x, key=k_mlp)
+        output = residual + mlp_output
+        return output
+
+
+# Modified transformer for Qwen
+class QwenTransformer(LlamaTransformer):
+    config: QwenConfig = eqx.static_field()
+    layers: BlockFoldable[QwenDecoderLayer]
+    norm: LlamaRMSNorm
+
+    @staticmethod
+    def init(config: QwenConfig, *, key) -> "QwenTransformer":
+        S = Stacked
+        if not config.scan_layers:
+            from haliax.nn.scan import BlockSeq
+
+            S = BlockSeq
+
+        # Initialize the decoder layers, scanned (Stacked) unless scan_layers is disabled
+        layers = S.init(config.Layers, QwenDecoderLayer, gradient_checkpointing=config.gradient_checkpointing)(
+            config,
+            key=shaped_rng_split(key, config.num_layers),
+        )
+
+        ln_f = config.mk_LayerNorm(config.Embed)
+        return QwenTransformer(config, layers, ln_f)
+
+
+# Modified LM head model for Qwen
+class QwenLMHeadModel(LmHeadModel[QwenConfig], ModuleWithStateDictSerialization):
+    transformer: QwenTransformer
+    embeddings: LlamaEmbedding  # Can reuse Llama embeddings
+    lm_head: Optional[hnn.Linear]
+
+    @classmethod
+    def init(cls, Vocab: Axis, config: QwenConfig, *, key) -> "QwenLMHeadModel":
+        k_t, k_emb = jrandom.split(key, 2)
+        transformer = QwenTransformer.init(config, key=k_t)
+        embeddings = LlamaEmbedding.init(Vocab, config, key=k_emb)
+        if config.tie_word_embeddings:
+            lm_head = None
+        else:
+            lm_head = hnn.Linear.init(In=config.Embed, Out=Vocab, key=k_emb, use_bias=False, out_first=True)
+
+        return QwenLMHeadModel(transformer, embeddings, lm_head)
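A minimal usage sketch of the classes added above, for reviewers. This is not part of the diff: the tiny axis sizes are made up, any `QwenConfig`/`LlamaConfig` fields not listed fall back to their defaults, and it only uses calls that appear in this PR.

```python
# Hypothetical smoke test for the new Qwen classes (not included in this PR).
import jax.random as jrandom

import haliax as hax

from levanter.models.attention import AttentionMask
from levanter.models.qwen import QwenConfig, QwenLMHeadModel

# A deliberately tiny config; remaining fields keep their LlamaConfig defaults.
config = QwenConfig(
    seq_len=64, hidden_dim=16, intermediate_dim=32, num_layers=2, num_heads=4, num_kv_heads=2
)
Vocab = hax.Axis("vocab", 128)

model = QwenLMHeadModel.init(Vocab, config, key=jrandom.PRNGKey(0))
tokens = hax.random.randint(jrandom.PRNGKey(1), config.Pos, 0, Vocab.size)
logits = model(tokens, attn_mask=AttentionMask.causal())  # logits as a NamedArray over position and vocab
```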
"transformers_version": "4.32.0", + "use_cache": true, + "use_dynamic_ntk": true, + "use_flash_attn": "auto", + "use_logn_attn": true, + "vocab_size": 151936 + } + """ + ) + qwen_config: Qwen2Config = Qwen2Config.from_dict(qwen_cfg) + qwen_config.hidden_size = 16 + qwen_config.intermediate_size = 64 + qwen_config.num_attention_heads = 4 + qwen_config.head_dim = 4 + qwen_config.num_hidden_layers = 4 + qwen_config.num_key_value_heads = 2 + qwen_config.max_position_embeddings = 128 + qwen_config.vocab_size = vocab_size + return qwen_config + + +@skip_if_no_torch +def test_qwen_roundtrip(): + import torch + from transformers import Qwen2ForCausalLM + + Vocab = hax.Axis("vocab", 1000) + hf_config = get_config(Vocab.size) + + converter = QwenConfig().hf_checkpoint_converter() + + config = QwenConfig.from_hf_config(hf_config) + + # Make input and attn_mask + input = hax.random.randint(random.PRNGKey(0), config.Pos, 0, Vocab.size) + attn_mask = AttentionMask.causal() + input_torch = torch.from_numpy(np.array(input.array)).to(torch.int32).unsqueeze(0) + + torch.random.manual_seed(0) + + torch_model = Qwen2ForCausalLM(hf_config) + torch_model.eval() + + torch_out = torch_model(input_torch) + torch_out = torch_out.logits[0].detach().cpu().numpy() + # torch_out = jax.nn.softmax(torch_out, axis=-1) + + with tempfile.TemporaryDirectory() as tmpdir: + torch_model.save_pretrained(f"{tmpdir}/torch_model") + + model = converter.load_pretrained( + QwenLMHeadModel, ref=f"{tmpdir}/torch_model", resize_vocab_to_match_tokenizer=False + ) + + @hax.named_jit + def compute(model, input): + model_output = model(input, attn_mask=attn_mask) + return model_output + + jax_out = compute(model, input).array + + assert torch_out.shape == jax_out.shape, f"{torch_out.shape} != {jax_out.shape}" + assert np.isclose(torch_out, np.array(jax_out), rtol=1e-4, atol=1e-4).all(), f"{torch_out} != {jax_out}" + + # now we're going to magnify the model parameters enough that differences should actualy show up + jax_out = compute(model, input).array + + converter.save_pretrained(model, f"{tmpdir}/lev_model", save_reference_code=False) + torch_model2 = Qwen2ForCausalLM.from_pretrained(f"{tmpdir}/lev_model") + torch_model2.eval() + + torch_out2 = torch_model2(input_torch) + torch_out2 = torch_out2.logits[0].detach().cpu().numpy() + assert torch_out2.shape == jax_out.shape, f"{torch_out2.shape} != {jax_out.shape}" + np.testing.assert_allclose(torch_out2, jax_out, rtol=1e-5, atol=1e-5)