From 3a84009c750a5e29847ead91cbc22d4e6c29d6b5 Mon Sep 17 00:00:00 2001
From: Russell Power
Date: Sat, 18 May 2024 15:23:15 -0700
Subject: [PATCH] Add factorized llama model for testing.

---
 .gitignore                              |   1 +
 config/distill_llama3_8b.yaml           |  45 ++
 config/distill_llama3_tiny.yaml         |  41 ++
 src/levanter/layerwise_trainer.py       | 740 ++++++++++++++++++++++++
 src/levanter/main/train_distill_lm.py   | 141 +++++
 src/levanter/models/factorized_llama.py | 627 ++++++++++++++++++++
 src/levanter/models/llama.py            |  11 +-
 src/levanter/tracker/tracker.py         |  10 +-
 src/levanter/trainer_state.py           |   2 +-
 tests/test_factorized_llama.py          | 281 +++++++++
 10 files changed, 1891 insertions(+), 8 deletions(-)
 create mode 100644 config/distill_llama3_8b.yaml
 create mode 100644 config/distill_llama3_tiny.yaml
 create mode 100644 src/levanter/layerwise_trainer.py
 create mode 100644 src/levanter/main/train_distill_lm.py
 create mode 100644 src/levanter/models/factorized_llama.py
 create mode 100644 tests/test_factorized_llama.py

diff --git a/.gitignore b/.gitignore
index 835da2048..3632a7f1a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,5 @@
 /scratch
+/cache

 # Configuration for TPU launches/secrets
 .config
diff --git a/config/distill_llama3_8b.yaml b/config/distill_llama3_8b.yaml
new file mode 100644
index 000000000..3c61eadce
--- /dev/null
+++ b/config/distill_llama3_8b.yaml
@@ -0,0 +1,45 @@
+data:
+  id: dlwh/wikitext_103_detokenized
+  tokenizer: "meta-llama/Meta-Llama-3-8B"
+  cache_dir: gs://wasabi-tpu-training/wikitext-103-detokenized
+
+teacher:
+  type: llama
+  reference_checkpoint: "meta-llama/Meta-Llama-3-8B"
+  gradient_checkpointing: True
+  seq_len: 4096
+  hidden_dim: 4096
+  intermediate_dim: 14336
+  num_layers: 32
+  num_heads: 32
+  num_kv_heads: 8
+  use_flash_attention: False
+
+student:
+  type: factorized_llama
+  reference_checkpoint: "meta-llama/Meta-Llama-3-8B"
+  gradient_checkpointing: True
+  seq_len: 4096
+  hidden_dim: 4096
+  intermediate_dim: 14336
+  num_layers: 32
+  num_heads: 32
+  num_kv_heads: 8
+  use_flash_attention: False
+  factor_dim: 128
+
+trainer:
+  mp: p=bf16,c=bfloat16
+  train_batch_size: 64
+  num_train_steps: 10000
+  steps_per_eval: 5000
+  tensor_parallel_axes: ["mlp", "heads"]
+  fsdp_axis: "embed"
+  batch_axis: "batch"
+  load_checkpoint_path: "gs://wasabi-tpu-training/distill-8b/checkpoints"
+
+
+optimizer:
+  learning_rate: 1.2E-5 # set low for fine-tuning
+  weight_decay: 0.1
+  min_lr_ratio: 0.1
\ No newline at end of file
diff --git a/config/distill_llama3_tiny.yaml b/config/distill_llama3_tiny.yaml
new file mode 100644
index 000000000..3d141d4f6
--- /dev/null
+++ b/config/distill_llama3_tiny.yaml
@@ -0,0 +1,41 @@
+data:
+  id: dlwh/wikitext_103_detokenized
+  tokenizer: "meta-llama/Meta-Llama-3-8B"
+  cache_dir: gs://wasabi-tpu-training/wikitext-103-detokenized
+
+teacher:
+  type: llama
+  seq_len: 4096
+  hidden_dim: 64
+  intermediate_dim: 64
+  num_layers: 32
+  num_heads: 4
+  num_kv_heads: 2
+  use_flash_attention: True
+
+student:
+  type: factorized_llama
+  seq_len: 4096
+  hidden_dim: 64
+  intermediate_dim: 64
+  factor_dim: 16
+  num_layers: 32
+  num_heads: 4
+  num_kv_heads: 2
+  use_flash_attention: True
+
+trainer:
+  mp: p=bf16,c=bfloat16
+  train_batch_size: 256
+  num_train_steps: 10000
+  steps_per_eval: 5000
+  tensor_parallel_axes: ["mlp", "heads"]
+  fsdp_axis: "embed"
+  batch_axis: "batch"
+  load_checkpoint_path: "gs://wasabi-tpu-training/distill-tiny/checkpoints"
+
+
+optimizer:
+  learning_rate: 1.2E-5 # set low for fine-tuning
+  weight_decay: 0.1
+  min_lr_ratio: 0.1
\ No newline at end of file
diff
--git a/src/levanter/layerwise_trainer.py b/src/levanter/layerwise_trainer.py new file mode 100644 index 000000000..a9a0f1aa8 --- /dev/null +++ b/src/levanter/layerwise_trainer.py @@ -0,0 +1,740 @@ +"""Trainer for layer-wise distillation. + +This accepts a configuration with a teacher and a student model. +The student model is assumed to have the same output dimensions for each layer, +but will typically have a factorized internal structure in order to reduce the +number of parameters. +""" + +import atexit +import copy +import dataclasses +import functools +import logging as pylogging +import os +import sys +import typing +from dataclasses import dataclass +from functools import cached_property +from pathlib import Path +from typing import ( + Any, + Callable, + Dict, + List, + Mapping, + Optional, + Protocol, + Sequence, + Tuple, + TypeVar, + Union, +) + +import equinox as eqx +import haliax as hax +import jax +import jax.numpy as jnp +import jmp +import numpy as np +from draccus import field +from haliax import Axis +from haliax.partitioning import ResourceAxis, ResourceMapping, named_jit +from haliax.quantization import ( + Fp8Config, + fp8_linear_layers, +) +from haliax.types import IntScalar, Scalar +from jax.experimental import multihost_utils +from jax.sharding import Mesh +from jaxtyping import PRNGKeyArray, PyTree +from optax import GradientTransformation, OptState + +import levanter.checkpoint +import levanter.logging +import levanter.tracker +import levanter.tracker.wandb +from levanter import tracker +from levanter.checkpoint import ( + CheckpointerConfig, + discover_latest_checkpoint, + load_checkpoint, + load_checkpoint_or_initialize, +) +from levanter.config import JsonAtom +from levanter.data import ( + Dataset, + ReplicatedBatchLoader, + ShardableDataset, + ShardedBatchLoader, +) +from levanter.distributed import DistributedConfig, RayConfig +from levanter.logging import capture_time +from levanter.models.lm_model import LmExample +from levanter.tracker import TrackerConfig +from levanter.trainer_state import ( + _ensure_int_is_array, + cast_params_by_trainability, + init_optimizer_for_trainables, + saveable_training_mask, + take_train_step, + trainables_only, +) +from levanter.types import ( + ComputeLossFunction, + FilterSpec, + FilterTree, + ModuleComputeLoss, +) +from levanter.utils import cloud_utils, fsspec_utils +from levanter.utils.tree_utils import inference_mode + +logger = pylogging.getLogger(__name__) + +M = TypeVar("M") # Model +X = TypeVar("X") # Input +S = TypeVar("S") +Student = TypeVar("Student") +Teacher = TypeVar("Teacher") + +DEFAULT_JAX_CONFIG = { + "jax_threefry_partitionable": True, + "jax_softmax_custom_jvp": True, +} + + +def _per_layer_loss(inputs, layer, key, mask): + loss, teacher_x = inputs + key, layer_k = jax.random.split(key, 2) + student, teacher = layer + student_y = student(teacher_x, mask=mask, key=layer_k) + teacher_y = teacher(teacher_x, mask=mask, key=layer_k) + loss = hax.mean((teacher_y - student_y) ** 2) + return (loss, teacher_y) + + +def _layer_loss(student, teacher, batch, key, compute_axis_mapping): + with hax.axis_mapping(compute_axis_mapping): + teacher_x = teacher.embeddings.embed(batch.tokens) + student_layers = student.transformer.layers.stacked + teacher_layers = teacher.transformer.layers.stacked + initial_loss = hax.NamedArray(0.0, axes=()) + + block_axis = student.transformer.layers.Block + + def _is_block(leaf): + if hax.is_named_array(leaf): + print(leaf.shape, leaf.axes) + return hax.is_named_array(leaf) and 
leaf.axes[0] == block_axis + + _loss_fn = hax.filter_checkpoint(_per_layer_loss) + + loss, _ = hax.fold( + _loss_fn, + axis=block_axis, + is_scanned=_is_block, + )( + (initial_loss, teacher_x), + (student_layers, teacher_layers), + key=key, + mask=batch.attn_mask, + ) + return loss.scalar() + + +class TrainerState(eqx.Module): + step: IntScalar = eqx.field(converter=_ensure_int_is_array) + student: Student + teacher: Teacher # Can't be static, breaks tracing for some reason. + optimizer: GradientTransformation = eqx.field(static=True) + opt_state: OptState + training_key: PRNGKeyArray + + is_trainable: FilterTree = eqx.field(static=True) + mp: jmp.Policy = eqx.field(static=True) + + @property + def trainable_model(self) -> M: + return trainables_only(self.model, self.is_trainable) + + @property + def saveable_state(self) -> FilterTree: + return eqx.filter(self, saveable_training_mask(self, self.is_trainable)) + + @classmethod + def init( + cls, + optimizer: GradientTransformation, + student: Student, + teacher: Teacher, + *args, + key: PRNGKeyArray, + is_trainable: FilterTree = True, + mp: Optional[jmp.Policy] = None, + fp8: Fp8Config = None, + **kwargs, + ) -> "TrainerState": + if mp is not None: + student = cast_params_by_trainability(student, mp, is_trainable) + else: + mp = jmp.get_policy("f32") + + if fp8 is not None: + student = fp8_linear_layers(student, fp8) + + opt_state = init_optimizer_for_trainables(optimizer, student, is_trainable) + return cls( + 0, + student=student, + teacher=teacher, + optimizer=optimizer, + opt_state=opt_state, + training_key=key, + is_trainable=is_trainable, + mp=mp, + *args, + **kwargs, + ) + + def take_step(self: S, grads: PyTree, obj_fun: Optional[Callable[[M], Scalar]] = None) -> S: + assert isinstance(self, TrainerState) # make mypy happy + student, opt_state = take_train_step( + optimizer=self.optimizer, + model=self.student, + opt_state=self.opt_state, + grads=grads, + obj_fun=obj_fun, + is_trainable=self.is_trainable, + ) + return dataclasses.replace(self, student=student, opt_state=opt_state, step=self.step + 1) + + +def init_model( + model_init: Optional[Callable[[], M]], + checkpoint_path: Path, + axis_mapping: ResourceMapping, + device_mesh: Mesh, +): + if not checkpoint_path or not fsspec_utils.exists(checkpoint_path): + return model_init() + + checkpoint_path = discover_latest_checkpoint(checkpoint_path) + + if checkpoint_path: + loaded_model = load_checkpoint_or_initialize( + model_init, + checkpoint_path, + axis_mapping=axis_mapping, + mesh=device_mesh, + subpath="model", + do_load=True, + )() + return loaded_model + else: + return model_init() + + +class Trainer: + config: "TrainerConfig" + optimizer: GradientTransformation + is_trainable_param: PyTree[FilterSpec] + _raw_loss_function: Callable + _cmanagers: List[typing.ContextManager] = [] + + def __init__( + self, + config: "TrainerConfig", + optimizer: GradientTransformation, + loss_fn: Optional[ComputeLossFunction] = None, + ): + """ + + Args: + config: the trainer config + optimizer: the optimizer, e.g. `optax.adam(1e-3)` or produced by [levanter.optim.OptimizerConfig][] + loss_fn (Callable): the loss function. This should be a function that takes a model and some inputs and returns a + scalar loss. It should be jit-able and should not have any side effects. 
+ """ + self.config = config + self.optimizer = optimizer + self._raw_loss_function = loss_fn or ModuleComputeLoss() + if isinstance(config.tracker, Sequence): + self.tracker = levanter.tracker.CompositeTracker([c.init(self.run_id) for c in config.tracker]) + else: + self.tracker = config.tracker.init(self.run_id) + + self._raw_loss_function = loss_fn or ModuleComputeLoss() + if isinstance(config.tracker, Sequence): + self.tracker = levanter.tracker.CompositeTracker([c.init(self.run_id) for c in config.tracker]) + else: + self.tracker = config.tracker.init(self.run_id) + + self._cmanagers = [] + + @cached_property + def loss_fn(self): + """ + Wrapped loss function that casts the model to compute precision and sets the context axis mapping to compute + """ + + @functools.wraps(self._raw_loss_function) + def fn(model, *batch, **batch_kwargs): + with hax.axis_mapping(self.compute_axis_mapping): + model = self.mp.cast_to_compute(model) + return _ensure_scalar(self._raw_loss_function(model, *batch, **batch_kwargs)) + + return fn + + @property + def run_id(self) -> str: + """Returns the run id""" + assert self.config.id is not None + return self.config.id + + @property + def mp(self) -> jmp.Policy: + """Returns the mixed precision policy""" + return self.config.mp + + @property + def fp8(self) -> Optional[Fp8Config]: + if self.config.fp8 is True: + return Fp8Config() + elif self.config.fp8 is False: + return None + else: + return self.config.fp8 + + @property + def num_train_steps(self) -> int: + return self.config.num_train_steps + + @property + def parameter_axis_mapping(self) -> ResourceMapping: + return self.config.parameter_axis_mapping + + @property + def compute_axis_mapping(self) -> ResourceMapping: + return self.config.compute_axis_mapping + + @property + def device_mesh(self) -> Mesh: + return self.config.device_mesh + + @property + def TrainBatch(self): + return self.config.TrainBatch + + @property + def EvalBatch(self): + return self.config.EvalBatch + + def __enter__(self): + if len(self._cmanagers) > 0: + raise RuntimeError("Trainer is already entered") + + self._cmanagers = [ + levanter.current_tracker(self.tracker), + self.device_mesh, + hax.axis_mapping(self.parameter_axis_mapping), + ] + + for cmanager in self._cmanagers: + cmanager.__enter__() + + return self + + def __exit__(self, *args): + problems = [] + for cmanager in reversed(self._cmanagers): + try: + cmanager.__exit__(*args) + except Exception as e: + problems.append(e) + + self._cmanagers = [] + + if len(problems) > 0: + raise RuntimeError("Exception(s) occurred while exiting trainer", problems) from problems[0] + + def initial_state( + self, + training_key: PRNGKeyArray, + student_init, + teacher_init, + *, + is_trainable: PyTree[FilterSpec] = True, + ) -> TrainerState: + """ + Either loads a checkpoint or initializes a fresh trainer state. This is the recommended way to initialize + a trainer state. + + This method is smart enough to handle subclasses of TrainerState. If you want to extend TrainerState, you + can override _initialize_state_from_scratch + + Args + is_trainable: optional filter spec for the trainable parameters. This is used to filter out non-trainable + parameters for the optimizer state and for computing gradients. Non-trainable parameters are also + not checkpointed. If you don't specify this, all parameters are assumed to be trainable. 
+ + Returns: + TrainerState: the initial state, + """ + + def init_state_and_model(training_key): + student = init_model( + model_init=student_init, + checkpoint_path=self.config.load_checkpoint_path, + axis_mapping=self.parameter_axis_mapping, + device_mesh=self.device_mesh, + ) + teacher = init_model( + model_init=teacher_init, + checkpoint_path=self.config.load_checkpoint_path, + axis_mapping=self.parameter_axis_mapping, + device_mesh=self.device_mesh, + ) + # only force trainable params to param precision. Other params are cast to compute precision + state = TrainerState.init( + self.optimizer, + student=student, + teacher=teacher, + key=training_key, + is_trainable=is_trainable, + mp=self.mp, + fp8=self.fp8, + ) + return state + + trainer_state_shape = eqx.filter_eval_shape(init_state_and_model, training_key) + saveable_train_state = saveable_training_mask(trainer_state_shape, is_trainable) + + state = load_checkpoint_or_initialize( + init_state_and_model, + self.checkpoint_path, + axis_mapping=self.parameter_axis_mapping, + mesh=self.device_mesh, + is_checkpointed=saveable_train_state, + do_load=load_checkpoint, + )(training_key) + + return state + + @property + def checkpoint_path(self) -> str: + checkpoint_path = self.config.load_checkpoint_path + if checkpoint_path is None: + checkpoint_path = self.config.checkpointer.expanded_path(self.run_id) + return checkpoint_path + + def replicated_loader(self, dataset: Dataset[X], batch_axis: Axis) -> ReplicatedBatchLoader[X]: + """Creates a replicated batch loader for the given dataset. Generally you should use this + if you either be able to make a single pass over the dataset. + + Args: + dataset (Dataset): the dataset to load + batch_axis (Axis): the batch axis + + Returns: + ReplicatedBatchLoader: the batch loader + """ + return ReplicatedBatchLoader(dataset, self.device_mesh, batch_axis, self.compute_axis_mapping) + + def sharded_loader(self, dataset: ShardableDataset[X], batch_axis: Axis) -> ShardedBatchLoader[X]: + """Creates a sharded batch loader for the given dataset. Generally you should use this + for training and you don't care about epoch boundaries. + + Args: + dataset (Dataset): the dataset to load + batch_axis (Axis): the batch axis + + Returns: + ShardedBatchLoader: the batch loader + """ + return ShardedBatchLoader(dataset, self.device_mesh, batch_axis, self.compute_axis_mapping) + + def train(self, state: S, loader: ReplicatedBatchLoader[X]) -> S: + for step in range(self.num_train_steps): + batch = next(loader) + state, loss = self.train_step(state, batch) + print("Training:", step, loss) + return state + + def train_step(self, state: S, *batch: X, **batch_kwargs): + """ + Performs a single training step. + """ + with capture_time() as step_time: + loss, new_state = self._jit_train_step_fn(state, *batch, **batch_kwargs) + # force the loss so timing numbers are accurate. laziness isn't going to help here (i think?) 
+ loss = loss.item() # type: ignore + + return new_state, loss + + @cached_property + def _jit_train_step_fn(self): + return named_jit( + self._layerwise_train_step, + axis_resources=self.parameter_axis_mapping, + out_axis_resources=self.parameter_axis_mapping, + donate_args=(True,), + ) + + def _layerwise_train_step(self, state: S, batch: LmExample) -> tuple[Scalar, S]: + student = inference_mode(state.student, False) + teacher = inference_mode(state.teacher, True) + + # tokens: hax.NamedArray + # loss_mask: hax.NamedArray + # attn_mask: AttentionMask | NamedArray = AttentionMask.causal() + + # manually thread the teacher and student models + k_t, key = jax.random.split(state.training_key, 2) + loss, grad = eqx.filter_value_and_grad(_layer_loss)(student, teacher, batch, key, self.compute_axis_mapping) + new_state = state.take_step(grad) + new_state = hax.shard(new_state, self.parameter_axis_mapping) + return loss, new_state + + +def _initialize_global_tracker(config: TrackerConfig | Tuple[TrackerConfig, ...], run_id: Optional[str]): + if isinstance(config, Sequence): + tracker = levanter.tracker.CompositeTracker([c.init(run_id) for c in config]) + else: + tracker = config.init(run_id) + + levanter.tracker.set_global_tracker(tracker) + + +@dataclass +class TrainerConfig: + seed: int = 0 # random seed + mp: jmp.Policy = jmp.get_policy("f32") # mixed precision policy + fp8: Optional[bool | Fp8Config] = None + + log_dir: Path = Path("logs/") + run_base_dir: Path = Path("runs/") + id: Optional[str] = None # run id. if None, will be set to a random string + + tracker: TrackerConfig | Tuple[TrackerConfig, ...] = field(default_factory=tracker.NoopConfig) + + # TODO: refactor callbacks + profiler: bool = False + profiler_start_step: int = 5 + profiler_num_steps: int = 100 + profiler_perfetto_link: bool = False + + # config related to partitioning + + batch_axis: Optional[str] = "batch" # Batch axis for data parallel. + fsdp_axis: Optional[Union[str, List[str]]] = "embed" # Axis/Axes to use for FSDP + tensor_parallel_axes: Optional[List[str]] = None # Axes, if any, to use for tensor parallelism + + # TODO: in theory we can support tuples of physical axis names, but I don't think anyone actually uses that. + axis_resources: Mapping[str, str] = field(default_factory=dict) + """mapping from logical axis to physical axis. batch_axis, fsdp_axis, and tensor_parallel_axes are preferred""" + parameter_axis_resources: Mapping[str, str] = field(default_factory=dict) # overrides axis_mapping for parameter + """logical->physical mapping for parameter/optimizer sharding. fsdp_axis and tensor_parallel_axes are preferred""" + model_axis_size: int = 1 # how many devices to shard each model over. Data axis is the other axis + + # Config related to batch sizes + train_batch_size: int = 512 + per_device_parallelism: int = -1 + """how many examples to process in parallel on each device. -1 (default) means train_batch_size/num_devices""" + + per_device_eval_parallelism: int = -1 + """how many examples to process in parallel on each device. -1 (default) means same as per_device_parallelism""" + + # Config related to duration + num_train_steps: int = 400_000 # number of training steps + steps_per_eval: int = 1_000 # how often to evaluate + max_eval_batches: Optional[int] = None # max number of batches to evaluate on. None means all batches + + checkpointer: CheckpointerConfig = field(default_factory=CheckpointerConfig) + load_checkpoint: Optional[bool] = None + """if None (default), we'll load a checkpoint if it exists. 
If true, we must load a checkpoint""" + load_checkpoint_path: Optional[str] = None + """can be a parent (to find latest) or a specific checkpoint. if None, will set to checkpointer.base_path.""" + initialize_from: Optional[str] = None # Levanter trainer checkpoint to initialize from + + jax_config: Dict[str, JsonAtom] = field( + default_factory=lambda: copy.deepcopy(DEFAULT_JAX_CONFIG) + ) # config to pass to jax.config.update + + distributed: DistributedConfig = DistributedConfig() + ray: RayConfig = field(default_factory=RayConfig) + + # whether or not to require an accelerator (e.g. TPU or GPU). + # default depends on the platform: on macos False, else True + require_accelerator: Optional[bool] = None + + # whether or not to shutdown the tpu at exit. If a float, shutdown after that many seconds. True = 5 minutes + shutdown_at_exit: Union[bool, float] = False + + @property + def TrainBatch(self): + return Axis("batch", self.train_batch_size) + + @property + def EvalBatch(self): + return Axis("batch", self.eval_batch_size) + + @property + def microbatch_size(self): + return self.per_device_parallelism * self.data_axis_size + + def initialize(self): + """Initializes jax, logging, setting the run name/id in the process""" + self._initialize_jax_config() + # Can't do full logging setup until we've initialized jax b/c we use jax for rank id + pylogging.basicConfig(level=pylogging.INFO) + self.distributed.initialize() + self._validate_and_set_defaults() + + id = self._maybe_set_id() + levanter.logging.init_logging(self.log_dir, f"{id}.log") + _initialize_global_tracker(self.tracker, id) + + self.ray.initialize() + + if self.require_accelerator is None: + self.require_accelerator = not sys.platform.startswith("darwin") + + if self.require_accelerator: + if jax.default_backend() == "cpu": + raise RuntimeError("No accelerator found. 
Please run on a TPU or GPU.") + + if self.shutdown_at_exit is not False: + if isinstance(self.shutdown_at_exit, bool): + self.shutdown_at_exit = 5.0 * 60 + logger.info(f"At end of run, shutting down TPU VM in {self.shutdown_at_exit} seconds") + atexit.register(cloud_utils.shutdown_tpu_vm, self.shutdown_at_exit) + + @cached_property + def device_mesh(self) -> Mesh: + devices = jax.devices() + devices = np.array(devices).reshape(self.data_axis_size, self.model_axis_size) + return Mesh(devices, (ResourceAxis.DATA, ResourceAxis.MODEL)) + + @property + def eval_batch_size(self): + return self.per_device_eval_parallelism * self.data_axis_size + + @property + def data_axis_size(self): + """size of the data parallel/batch parallel axis.""" + assert jax.device_count() % self.model_axis_size == 0 + return jax.device_count() // self.model_axis_size + + @cached_property + def compute_axis_mapping(self) -> ResourceMapping: + """Mapping from logical axis to physical axis for compute.""" + axes_to_return = dict(self.axis_resources) + + tp_axes = self.tensor_parallel_axes or [] + if tp_axes and len(axes_to_return) > 0: + logger.warning(f"tensor parallelism axes {tp_axes} will override axis_resources {axes_to_return}") + for axis in tp_axes: + axes_to_return[axis] = ResourceAxis.MODEL + + if self.batch_axis is not None: + axes_to_return[self.batch_axis] = ResourceAxis.DATA + + return axes_to_return + + @cached_property + def parameter_axis_mapping(self) -> ResourceMapping: + mapping = dict(self.compute_axis_mapping) + + for axis, resource in self.parameter_axis_resources.items(): + mapping[axis] = resource + + if isinstance(self.fsdp_axis, str): + mapping[self.fsdp_axis] = ResourceAxis.DATA + elif isinstance(self.fsdp_axis, list): + for axis in self.fsdp_axis: + mapping[axis] = ResourceAxis.DATA + + return mapping + + def _initialize_jax_config(self): + for key, value in self.jax_config.items(): + jax.config.update(key, value) + + def _maybe_set_id(self): + # always do this so we don't get weird hangs if the id isn't set right + # for random ids, we want to ensure that all hosts have the same id + # NB: do NOT use the run seed here. we want the run id to be independent of the seed + seed = np.random.randint(0, 2**31 - 1) + seed = multihost_utils.broadcast_one_to_all(jax.numpy.array(seed, dtype=np.int32)).item() + + # RUN ID comes from a few places: the config, the environment, or wandb, or a random string + if self.id is None: + # TODO: this doesn't work with wandb sweeps. 
need to reconcile when we merge + if "RUN_ID" in os.environ: + self.id = os.environ["RUN_ID"] + elif self.tracker and self.tracker.id: + self.id = self.tracker.id + else: + # wandb run ids are 8 characters [a-z0-9], which we'll emulate here + # we also want to ensure that all hosts have the same run id + # we do this by syncing a random seed across all hosts and then using that to generate the run id + gen = np.random.default_rng(seed) + self.id = "".join(gen.choice(list("abcdefghijklmnopqrstuvwxyz0123456789"), 8)) + + logger.info(f"Setting run id to {self.id}") + + return self.id + + # we can't do this in post_init because we don't want to call jax.device_count before calling distributed.initialize + def _validate_and_set_defaults(self): + if jax.device_count() % self.model_axis_size != 0: + raise ValueError( + f"num_devices ({jax.device_count()}) is not divisible by model_axis_size ({self.model_axis_size})" + ) + + if ( + jax.local_device_count() % self.model_axis_size != 0 + and self.model_axis_size % jax.local_device_count() != 0 + ): + raise ValueError("either model_axis_size or local_device_count must be divisible by the other") + + assert self.train_batch_size != -1 or self.per_device_parallelism != -1 + + if self.per_device_parallelism == -1: + self.per_device_parallelism = self.train_batch_size // self.data_axis_size + + if self.train_batch_size == -1: + self.train_batch_size = self.per_device_parallelism * self.data_axis_size + + # validate size of per_device_parallelism + if self.train_batch_size % (self.per_device_parallelism * self.data_axis_size) != 0: + raise ValueError( + f"train_batch_size ({self.train_batch_size}) must be divisible by per_device_parallelism *" + f" data_axis_size ({self.per_device_parallelism}, {self.data_axis_size})" + ) + + if self.per_device_eval_parallelism == -1: + self.per_device_eval_parallelism = self.per_device_parallelism + + +class AllConfig(Protocol): + trainer: TrainerConfig + + +def initialize(config: TrainerConfig | AllConfig): + """Initializes jax, logging, setting the run name/id in the process. 
Also initializes tracking and saves config + as hyperparameters and an artifact""" + if isinstance(config, TrainerConfig): + trainer_config = config + else: + trainer_config = config.trainer + + trainer_config.initialize() + levanter.tracker.log_configuration(config) + + +def _ensure_scalar(x: hax.types.Scalar | hax.NamedArray) -> hax.types.Scalar: + if isinstance(x, hax.NamedArray): + return x.scalar() + else: + return x diff --git a/src/levanter/main/train_distill_lm.py b/src/levanter/main/train_distill_lm.py new file mode 100644 index 000000000..4326fa82d --- /dev/null +++ b/src/levanter/main/train_distill_lm.py @@ -0,0 +1,141 @@ +import dataclasses +import gc +import logging +import os +from dataclasses import dataclass, field +from typing import Optional, Union + +import jax.random as jrandom +import jax.numpy as jnp + +import haliax as hax +from haliax import Axis +from haliax.partitioning import named_jit, round_axis_for_partitioning, ResourceMapping + +import levanter +from levanter import callbacks +from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback +from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig +from levanter.models.factorized_llama import FactorizedLlamaConfig +from levanter.models.lm_model import LmConfig +from levanter.optim import AdamConfig, OptimizerConfig +from levanter.layerwise_trainer import Trainer, TrainerConfig +from levanter.utils.jax_utils import parameter_count + + +logger = logging.getLogger(__name__) + + +@dataclass +class TrainDistillLmConfig: + data: Union[LMDatasetConfig, LMMixtureDatasetConfig] = None + trainer: TrainerConfig = field(default_factory=TrainerConfig) + teacher: LmConfig = field(default_factory=FactorizedLlamaConfig) + student: LmConfig = field(default_factory=FactorizedLlamaConfig) + optimizer: OptimizerConfig = field(default_factory=AdamConfig) + + fcm_prob: float = 0.0 # forgetful context masking prob. recommended 0.15 + + init_from_hf: bool = False + hf_save_path: Optional[str] = None + update_hessian_steps: int = 10 + + +def main(config: TrainDistillLmConfig): + tokenizer = config.data.the_tokenizer + + levanter.initialize(config) + optimizer = config.optimizer.build(config.trainer.num_train_steps) + + # Using the trainer as a context manager does 3 things: + # 1. Sets the device mesh + # 2. Sets the axis mapping (for fsdp) + # 3. 
Sets the global metrics tracker + with Trainer(config.trainer, optimizer) as trainer: + # randomness in jax is tightly controlled by "keys" which are the states of the random number generators + # this makes deterministic training pretty easy + seed = config.trainer.seed + data_key, loader_key, student_key, teacher_key, training_key = jrandom.split(jrandom.PRNGKey(seed), 5) + + # We have two axis_mappings: one for storing the model and optimizer states, and one for compute + # This allows Zero-3-style parameter sharding, where we shard the parameters and optimizer state across the mesh + compute_axis_mapping = trainer.compute_axis_mapping + parameter_axis_mapping = trainer.parameter_axis_mapping + + print("Parameters", parameter_axis_mapping) + print("Compute", compute_axis_mapping) + + # some axes we need + Batch = config.trainer.TrainBatch + EvalBatch = config.trainer.EvalBatch + Pos = config.teacher.Pos + KeyPos = config.teacher.KeyPos + + tagged_eval_datasets = config.data.tagged_eval_sets(Pos.size) + train_dataset = CausalLmDataset( + config.data.train_set(Pos.size), Pos, KeyPos, ignore_index=config.data.ignore_token_id + ) + + # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to + # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of + # tokens: gpt-2 has 50257, for example. So we round up. + vocab_size = len(tokenizer) + Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping) + if vocab_size != Vocab.size: + logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning") + + def _load_model_from_hf(model_config: LmConfig): + # this is some unpleasant code to allow us to initialize from a hf checkpoint. If this is your first read through, + # I recommend skipping it for now + assert isinstance(model_config, HFCompatConfig) + converter = model_config.default_hf_checkpoint_converter + if hasattr(tokenizer, "vocab") and tokenizer.vocab != converter.tokenizer.vocab: + logger.warning("The tokenizers appear to be different. 
You may want to check this.") + + converter = converter.replaced(tokenizer=tokenizer) + # initialize from an hf pretrained model + logger.info(f"Initializing model from HF checkpoint '{converter.reference_checkpoint}'") + gc.collect() + model = converter.load_pretrained( + model_config, axis_mapping=parameter_axis_mapping, dtype=trainer.mp.compute_dtype + ) + model = named_jit(trainer.mp.cast_to_param, parameter_axis_mapping)(model) + return model + + state = trainer.initial_state( + training_key, + student_init=lambda: config.student.build(Vocab, key=student_key), + teacher_init=lambda: config.teacher.build(Vocab, key=teacher_key), + ) + + if int(state.step) == 0 and config.init_from_hf: + state = dataclasses.replace(state, teacher=None, student=None) + gc.collect() + teacher = _load_model_from_hf(config.teacher) + student = _load_model_from_hf(config.student) + state = dataclasses.replace(state, teacher=teacher, student=student) + + levanter.tracker.log_summary( + { + "teacher_parameter_count": parameter_count(state.teacher), + "student_parameter_count": parameter_count(state.student), + } + ) + + train_loader = iter(trainer.sharded_loader(train_dataset, Batch)) + + if int(state.step) > 0: + # step is after the batch, so we need to seek to step + # TODO: implement iter_data.seek(resume_step +1) + import tqdm + + for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"): + next(train_loader) + + ## OK, actually run training! + trainer.train(state, train_loader) + # checkpointer.on_step(last_step, force=True) + + +if __name__ == "__main__": + levanter.config.main(main)() diff --git a/src/levanter/models/factorized_llama.py b/src/levanter/models/factorized_llama.py new file mode 100644 index 000000000..562a3b3e8 --- /dev/null +++ b/src/levanter/models/factorized_llama.py @@ -0,0 +1,627 @@ +import dataclasses +from dataclasses import dataclass +from typing import Callable, Dict, Optional, Tuple, Type, Union + +import equinox as eqx +import haliax as hax +import haliax.nn as hnn +import jax +import jax.numpy as jnp +import jax.random as jrandom +import numpy as np +from haliax import Axis, AxisSpec, NamedArray +from haliax.jax_utils import maybe_rng_split, named_call, shaped_rng_split +from haliax.nn.scan import Stacked +from haliax.util import ensure_tuple +from jaxtyping import PRNGKeyArray, PyTree + +from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig +from levanter.compat.torch_serialization import ( + StateDict, + StateDictSerializationMixin, + apply_prefix, + flatten_linear_layers, + stack_state_dict, + unflatten_linear_layers, + unstack_state_dict, +) +from levanter.logging import silence_transformer_nag +from levanter.models.attention import AttentionMask, dot_product_attention +from levanter.models.gpt2 import ACT2FN +from levanter.models.llama import LlamaRMSNorm +from levanter.models.lm_model import LmConfig, LmHeadModel +from levanter.types import BlockFoldable +from levanter.utils.jax_utils import leaf_key_paths + +silence_transformer_nag() +from transformers import LlamaConfig as HfLlamaConfig # noqa: E402 +from transformers import PretrainedConfig as HfConfig # noqa: E402 + + +@LmConfig.register_subclass("factorized_llama") +@dataclass(frozen=True) +class FactorizedLlamaConfig(HFCompatConfig): + """Config for FactorizedLlamaModel + + Args: + seq_len (int, optional): maximum length of the input sequence. Defaults to 2048. + hidden_dim (int, optional): dimension of the hidden state. Defaults to 4096. 
+ intermediate_dim (int, optional): dimension of the intermediate state. Defaults to 11008. + num_layers (int, optional): number of hidden layers in the Transformer encoder. Defaults to 32. + num_heads (int, optional): number of attention heads for each attention layer. Defaults to 32. + num_kv_heads (int, optional): number of attention heads for keys and values in each attention layer. + Setting to 1 means MQA. Setting to num_heads means MHA. Otherwise GQA. + Note that num_heads must be divisible by this number. Defaults to 32. + activation_function (str, optional): activation function for the hidden layer. Defaults to "silu". + rope_scaling (Dict, optional): dict containing the scaling configuration for the Rotary Positional Embedding. + """ + + reference_checkpoint: Optional[str] = None + + seq_len: int = 2048 + hidden_dim: int = 4096 + factor_dim: int = 512 + intermediate_dim: int = 11008 + num_layers: int = 32 + num_heads: int = 32 + num_kv_heads: int = 32 + activation_function: str = "silu" + initializer_range: float = 0.02 + layer_norm_epsilon: float = 1e-5 + + # Attention-related config + upcast_attn: bool = False + use_flash_attention: bool = True + flash_attention_block_size: Optional[int] = None + + gradient_checkpointing: bool = True + gradient_checkpointing_block_size: int = 5 + scan_layers: bool = True + + use_bias: bool = False + rope_scaling: Optional[dict] = None + + # Axis + Pos = property(lambda self: Axis(name="position", size=self.seq_len)) + KeyPos = property(lambda self: self.Pos.alias("key_position")) + Factor = property(lambda self: Axis(name="factor", size=self.factor_dim)) + Embed = property(lambda self: Axis(name="embed", size=self.hidden_dim)) + Heads = property(lambda self: Axis(name="heads", size=self.num_heads)) + KVHeads = property(lambda self: Axis(name="kv_heads", size=self.num_kv_heads)) + Layers = property(lambda self: Axis(name="layers", size=self.num_layers)) + Mlp = property(lambda self: Axis(name="mlp", size=self.intermediate_dim)) + HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_dim // self.num_heads)) + + def __post_init__(self): + assert ( + self.num_heads % self.num_kv_heads == 0 + ), f"num_heads={self.num_heads} not divisible by num_kv_heads={self.num_kv_heads}." + + @property + def default_hf_checkpoint_converter(self) -> HFCheckpointConverter["FactorizedLlamaConfig"]: # type: ignore + assert self.reference_checkpoint, "Must specify HF model id to convert from." + return HFCheckpointConverter( + self.__class__, # type: ignore + self.reference_checkpoint, + trust_remote_code=True, + HfConfigClass=HfLlamaConfig, + ) + + @classmethod + def from_hf_config(cls, hf_config: HfConfig): + return FactorizedLlamaConfig( + seq_len=hf_config.max_position_embeddings, + hidden_dim=hf_config.hidden_size, + intermediate_dim=hf_config.intermediate_size, + num_layers=hf_config.num_hidden_layers, + num_heads=hf_config.num_attention_heads, + num_kv_heads=hf_config.num_key_value_heads, + activation_function=hf_config.hidden_act, + initializer_range=hf_config.initializer_range, + layer_norm_epsilon=hf_config.rms_norm_eps, + rope_scaling=hf_config.rope_scaling, + factor_dim=getattr(hf_config, "factor_dim", 512), + ) + + def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None) -> HfLlamaConfig: + """Convert to HuggingFace's FactorizedLlamaConfig + + Args: + vocab_size (int, optional): Vocabulary size of the tokenizer. Defaults to 32000. + config_overrides (dict, optional): Overrides for the config. Defaults to None. 
+ + Returns: + HfLlamaConfig: HuggingFace's FactorizedLlamaConfig + """ + if config_overrides is None: + config_overrides = {} + + return HfLlamaConfig( + max_position_embeddings=self.seq_len, + hidden_size=self.hidden_dim, + intermediate_size=self.intermediate_dim, + num_hidden_layers=self.num_layers, + num_attention_heads=self.num_heads, + num_key_value_heads=self.num_kv_heads, + hidden_act=self.activation_function, + initializer_range=self.initializer_range, + rms_norm_eps=self.layer_norm_epsilon, + rope_scaling=self.rope_scaling, + factor_dim=self.factor_dim, + vocab_size=vocab_size, + **config_overrides, + ) + + @property + def model_type(cls) -> Type["FactorizedLlamaLMHeadModel"]: + return FactorizedLlamaLMHeadModel + + +def low_rank_approximation(matrix, rank): + """ + Approximates the input matrix using Singular Value Decomposition (SVD) to a lower rank. + + Args: + matrix (jax.numpy.ndarray): Input matrix of shape (N, M) + rank (int): Desired rank of the approximation (H) + + Returns: + jax.numpy.ndarray, jax.numpy.ndarray: Two matrices of shape (N, H) and (H, M) + """ + from jax.numpy.linalg import svd + + # Perform SVD + U, S, Vh = svd(matrix, full_matrices=False) + + S = S[..., :rank] + # Truncate U, S, and Vh to the desired rank + U_truncated = U[..., :rank] + if len(S.shape) == 1: + S_truncated = jnp.diag(S) + else: + S_truncated = jax.vmap(jnp.diag, 0)(S[..., :rank]) + + Vh_truncated = Vh[..., :rank, :] # Note: Vh is already the conjugate transpose of V + + # Reconstruct the low-rank approximation + down_proj = U_truncated @ S_truncated # Shape (N, H) + up_proj = Vh_truncated # Shape (H, M) + + return down_proj, up_proj + + +class FactorizedLinear(StateDictSerializationMixin, eqx.Module): + """Factorized Linear Layer""" + + down_proj: hnn.Linear + up_proj: hnn.Linear + out_first: bool = eqx.static_field() + + @staticmethod + def init( + Out: Axis, In: Axis, Hidden: Axis, *, key, use_bias: bool = False, out_first: bool = False + ) -> "FactorizedLinear": + assert Hidden.size <= np.prod([a.size for a in ensure_tuple(Out)]), ( + "Hidden size must be less than or equal to output size.", + Hidden, + Out, + ) + + k_down, k_up = jrandom.split(key, 2) + down_proj = hnn.Linear.init(Out=Hidden, In=In, key=k_up, use_bias=use_bias) + up_proj = hnn.Linear.init(Out=Out, In=Hidden, key=k_down, use_bias=use_bias) + return FactorizedLinear(down_proj, up_proj, out_first=out_first) + + def __call__(self, x: NamedArray, *, key=None) -> NamedArray: + k_down, k_up = maybe_rng_split(key, 2) + return self.up_proj(self.down_proj(x, key=k_up), key=k_down) + + def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): + d = state_dict.copy() + + # Initial factorized linear with the SVD approximation of the original input matrix. 
+ weights = state_dict[prefix + ".weight"] + down_proj, up_proj = low_rank_approximation(weights, self.down_proj.Out.size) + if self.out_first: + d[prefix + ".down_proj.weight"] = up_proj + d[prefix + ".up_proj.weight"] = down_proj + else: + d[prefix + ".down_proj.weight"] = down_proj + d[prefix + ".up_proj.weight"] = up_proj + + d.update( + unflatten_linear_layers( + apply_prefix(prefix, "down_proj"), + d, + self.down_proj, + out_dims_first_in_dict=self.out_first, + ) + ) + d.update( + unflatten_linear_layers( + apply_prefix(prefix, "up_proj"), + d, + self.up_proj, + out_dims_first_in_dict=self.out_first, + ) + ) + + return super().from_state_dict(d, prefix) + + def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: + # We override the default state dict generation, as we don't want to output + # weights for the factorized linear layers. + # super().update_state_dict(my_dict, prefix=prefix) + my_dict: StateDict = {} + + my_dict.update( + flatten_linear_layers(prefix + ".down_proj", self.down_proj, out_dims_first_in_dict=self.out_first) + ) + my_dict.update(flatten_linear_layers(prefix + ".up_proj", self.up_proj, out_dims_first_in_dict=self.out_first)) + + if self.out_first: + my_dict[prefix + ".weight"] = jnp.transpose( + jnp.dot( + jnp.transpose(my_dict[prefix + ".down_proj.weight"]), + jnp.transpose(my_dict[prefix + ".up_proj.weight"]), + ) + ) + else: + my_dict[prefix + ".weight"] = jnp.dot( + my_dict[prefix + ".down_proj.weight"], + my_dict[prefix + ".up_proj.weight"], + ) + + del my_dict[prefix + ".down_proj.weight"] + del my_dict[prefix + ".up_proj.weight"] + + state_dict.update(my_dict) + return state_dict + + +class FactorizedLlamaMlp(eqx.Module, StateDictSerializationMixin): + """Multi-layer Perceptron + In comparison with GPT2, FactorizedLlamaMlp adds an up-proj that multiplies with activated gate_proj, + before down-proj. 
+ """ + + gate_proj: FactorizedLinear # projection from Embed to Mlp + up_proj: FactorizedLinear # projection from Embed to Mlp + down_proj: FactorizedLinear # projection from Mlp to Embed + act: Callable = eqx.static_field() + + @staticmethod + def init( + Embed: Axis, Mlp: Axis, Factor: Axis, activation_fn: Union[str, Callable], *, key, use_bias: bool = False + ) -> "FactorizedLlamaMlp": + k_fc, k_up_proj, k_down_proj = jrandom.split(key, 3) + gate_proj = FactorizedLinear.init( + Out=Mlp, In=Embed, Hidden=Factor, key=k_fc, use_bias=use_bias, out_first=True + ) + up_proj = FactorizedLinear.init( + Out=Mlp, In=Embed, Hidden=Factor, key=k_up_proj, use_bias=use_bias, out_first=True + ) + down_proj = FactorizedLinear.init( + Out=Embed, In=Mlp, Hidden=Factor, key=k_down_proj, use_bias=use_bias, out_first=True + ) + if isinstance(activation_fn, str): + activation_fn = ACT2FN[activation_fn] + act = activation_fn # type: ignore + return FactorizedLlamaMlp(gate_proj, up_proj, down_proj, act) + + @named_call + def __call__(self, x: NamedArray, *, key=None) -> NamedArray: + k_gate, k_up, k_down = maybe_rng_split(key, 3) + hidden_states = self.gate_proj(x, key=k_gate) + hidden_states = self.act(hidden_states) + hidden_states = hidden_states * self.up_proj(x, key=k_up) + outputs = self.down_proj(hidden_states, key=k_down) + return outputs + + +class FactorizedLlamaAttention(StateDictSerializationMixin, eqx.Module): + config: FactorizedLlamaConfig = eqx.static_field() + q_proj: FactorizedLinear # projection from Embed to query + k_proj: FactorizedLinear # projection from Embed to key + v_proj: FactorizedLinear # projection from Embed to value + o_proj: FactorizedLinear # projection from Heads to output + + @staticmethod + def init(config: FactorizedLlamaConfig, *, key) -> "FactorizedLlamaAttention": + use_bias = config.use_bias + Embed = config.Embed + QHeadsPerGroup = hax.Axis("q_heads_per_group", config.num_heads // config.num_kv_heads) + + k_q, k_k, k_v, k_o = jrandom.split(key, 4) + q_proj = FactorizedLinear.init( + In=Embed, + Hidden=config.Factor, + Out=(config.KVHeads, QHeadsPerGroup, config.HeadSize), + key=k_q, + use_bias=use_bias, + ) + k_proj = FactorizedLinear.init( + In=Embed, + Hidden=config.Factor, + Out=(config.KVHeads, config.HeadSize), + key=k_k, + use_bias=use_bias, + out_first=True, + ) + v_proj = FactorizedLinear.init( + In=Embed, + Hidden=config.Factor, + Out=(config.KVHeads, config.HeadSize), + key=k_v, + use_bias=use_bias, + out_first=True, + ) + o_proj = FactorizedLinear.init( + In=(config.Heads, config.HeadSize), Hidden=config.Factor, Out=Embed, key=k_o, use_bias=use_bias + ) + return FactorizedLlamaAttention(config, q_proj, k_proj, v_proj, o_proj) + + @named_call + def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *, key=None) -> NamedArray: + key_q, key_k, key_v, key_o = maybe_rng_split(key, 4) + + # reorder heads and position for better training throughput + q = self.q_proj(x, key=key_q).rearrange((..., "kv_heads", "q_heads_per_group", "position", "head_size")) + k = self.k_proj(x, key=key_k).rearrange((..., "kv_heads", "position", "head_size")) + v = self.v_proj(x, key=key_v).rearrange((..., "kv_heads", "position", "head_size")) + + cos, sin = llama_rotary_pos_emb(self.config.HeadSize, x.resolve_axis("position")) + q, k = _apply_rotary_pos_emb(q, k, cos, sin) + + k = k.rename({"position": "key_position"}) + v = v.rename({"position": "key_position"}) + + c = self.config + attn_output = dot_product_attention( + "position", + "key_position", + 
"head_size", + q, + k, + v, + mask, + attention_dtype=jnp.float32 if self.config.upcast_attn else x.dtype, + use_flash=c.use_flash_attention, + flash_block_size=c.flash_attention_block_size, + ) + + attn_output = attn_output.flatten_axes(("kv_heads", "q_heads_per_group"), "heads") + attn_output = attn_output.astype(x.dtype) + + attn_output = self.o_proj(attn_output, key=key_o) + return attn_output + + +class FactorizedLlamaDecoderLayer(StateDictSerializationMixin, eqx.Module): + config: FactorizedLlamaConfig = eqx.static_field() + self_attn: FactorizedLlamaAttention + mlp: FactorizedLlamaMlp + input_layernorm: LlamaRMSNorm + post_attention_layernorm: LlamaRMSNorm + + @staticmethod + def init(config: FactorizedLlamaConfig, *, key) -> "FactorizedLlamaDecoderLayer": + k_attn, k_mlp = jrandom.split(key, 2) + + attn = FactorizedLlamaAttention.init(config=config, key=k_attn) + mlp = FactorizedLlamaMlp.init( + Embed=config.Embed, + Mlp=config.Mlp, + Factor=config.Factor, + activation_fn=config.activation_function, + key=k_mlp, + use_bias=config.use_bias, + ) + ln_1 = LlamaRMSNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias) + ln_2 = LlamaRMSNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias) + + return FactorizedLlamaDecoderLayer(config, attn, mlp, ln_1, ln_2) + + @named_call + def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *, key=None) -> NamedArray: + k_attn, k_mlp = maybe_rng_split(key, 2) + # self attention and skip connection + residual = x + x = self.input_layernorm(x) + attn_output = self.self_attn(x=x, mask=mask, key=k_attn) + x = residual + attn_output + + # MLP and skip connection + residual = x + x = self.post_attention_layernorm(x) + mlp_output = self.mlp(x, key=k_mlp) + output = residual + mlp_output + return output + + +class FactorizedLlamaTransformer(StateDictSerializationMixin, eqx.Module): + config: FactorizedLlamaConfig = eqx.static_field() + layers: BlockFoldable[FactorizedLlamaDecoderLayer] + norm: LlamaRMSNorm + + @staticmethod + def init(config: FactorizedLlamaConfig, *, key) -> "FactorizedLlamaTransformer": + S = Stacked + if not config.scan_layers: + from haliax.nn.scan import BlockSeq + + S = BlockSeq + + layers = S.init( + config.Layers, FactorizedLlamaDecoderLayer, gradient_checkpointing=config.gradient_checkpointing + )( + config, + key=shaped_rng_split(key, config.num_layers), + ) + ln_f = LlamaRMSNorm.init(config.Embed, eps=config.layer_norm_epsilon, use_bias=config.use_bias) + + return FactorizedLlamaTransformer(config, layers, ln_f) + + @named_call + def __call__(self, x: NamedArray, attn_mask: Optional[NamedArray | AttentionMask], *, key) -> NamedArray: + keys = maybe_rng_split(key, self.config.num_layers) if key is not None else None + x = self.layers.fold(x, mask=attn_mask, key=keys) + x = self.norm(x) + + return x + + def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): + if isinstance(self.layers, Stacked): + state_dict = stack_state_dict(state_dict, prefix=apply_prefix(prefix, "layers")) + + out = super().from_state_dict(state_dict, prefix=prefix) + return out + + def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: + my_state_dict: StateDict = {} + super().update_state_dict(my_state_dict, prefix=prefix) + + if isinstance(self.layers, Stacked): + stacked_dict = unstack_state_dict(my_state_dict, prefix=apply_prefix(prefix, "layers")) + state_dict.update(stacked_dict) + else: + 
state_dict.update(my_state_dict) + + return state_dict + + +class FactorizedLlamaEmbedding(StateDictSerializationMixin, eqx.Module): + """Similar to GPT2 Embedding, except that: + - FactorizedLlama doesn't have position embedding in the Embedding layer. + - FactorizedLlama doesn't use dropout. + """ + + Vocab: Axis = eqx.static_field() + config: FactorizedLlamaConfig = eqx.static_field() + token_embeddings: NamedArray + + @staticmethod + def init(Vocab: Axis, config: FactorizedLlamaConfig, *, key) -> "FactorizedLlamaEmbedding": + token_embeddings = hax.random.normal(key, (Vocab, config.Embed)) + return FactorizedLlamaEmbedding(Vocab, config, token_embeddings) + + @named_call + def embed(self, input_ids, *args): + input_embeds = self.token_embeddings.take("vocab", input_ids) + x = input_embeds + return x + + def unembed(self, x: NamedArray): + return hax.dot("embed", x, self.token_embeddings) + + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: + return {"token_embeddings": "model.embed_tokens.weight"} + + def resize_embeddings(self, new_size: int, key: Optional[PRNGKeyArray] = None): + new_weights = hax.tree_util.resize_axis(self.token_embeddings, self.Vocab, new_size, key=key) + return dataclasses.replace(self, Vocab=self.Vocab.resize(new_size), token_embeddings=new_weights) + + +class FactorizedLlamaLMHeadModel(eqx.Module, LmHeadModel[FactorizedLlamaConfig], StateDictSerializationMixin): + transformer: FactorizedLlamaTransformer + embeddings: FactorizedLlamaEmbedding + lm_head: FactorizedLinear + + @property + def config(self): + return self.transformer.config + + @property + def vocab_size(self) -> int: + return self.Vocab.size + + @property + def Vocab(self) -> Axis: + return self.embeddings.Vocab + + @classmethod + def init(cls, Vocab: Axis, config: FactorizedLlamaConfig, *, key) -> "FactorizedLlamaLMHeadModel": + k_t, k_emb = jrandom.split(key, 2) + transformer = FactorizedLlamaTransformer.init(config, key=k_t) + embeddings = FactorizedLlamaEmbedding.init(Vocab, config, key=k_emb) + lm_head = FactorizedLinear.init( + In=config.Embed, Hidden=config.Factor, Out=Vocab, key=k_emb, use_bias=False, out_first=True + ) + return FactorizedLlamaLMHeadModel(transformer, embeddings, lm_head) + + def __call__( + self, + input_ids: NamedArray, + attn_mask: Optional[Union[NamedArray, AttentionMask]] = None, + *, + key=None, + ) -> NamedArray: + """ + Args: + input_ids (NamedArray): [batch, position] + Indices of input sequence tokens in the vocabulary. + attn_mask (Union[NamedArray, AttentionMask], optional): [batch, position] + Mask to avoid performing attention on the padding token indices of the encoder input. 
+ The attn_mask from training pipeline may be an AttentionMask object instead of NamedArray + """ + k_t, k_head = maybe_rng_split(key, 2) + x = self.embeddings.embed(input_ids) + x = self.transformer(x, attn_mask=attn_mask, key=k_t) + lm_logits = self.lm_head(x, key=k_head) + return lm_logits + + def resize_vocab(self, new_size: int, key=None) -> "LmHeadModel[FactorizedLlamaConfig]": + new_Vocab = self.Vocab.resize(new_size) + k1, k2 = maybe_rng_split(key, 2) + new_embeddings = self.embeddings.resize_embeddings(new_size, key=k1) + new_lm_matrix = hax.tree_util.resize_axis(self.lm_head.weight, self.Vocab, new_size, key=k2) + new_lm_head = dataclasses.replace(self.lm_head, Out=new_Vocab, weight=new_lm_matrix) + + return dataclasses.replace(self, embeddings=new_embeddings, lm_head=new_lm_head) + + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: + return {"transformer": "model", "embeddings": None} + + def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): + return super().from_state_dict(state_dict, prefix) + + def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: + super().update_state_dict(state_dict, prefix=prefix) + + +def _rotate_half(x: NamedArray) -> NamedArray: + """Rotates half of the hidden dims of the input and concatenates them.""" + HeadSize = x.axes[-1] + x1 = x[HeadSize, : HeadSize.size // 2] + x2 = x[HeadSize, HeadSize.size // 2 :] + out = hax.concatenate(HeadSize, (-x2, x1)) + return out + + +def _apply_rotary_pos_emb( + q: NamedArray, # [batch, position, kv_heads, q_heads_per_group, head_size] + k: NamedArray, # [batch, position, kv_heads, head_size] + cos: NamedArray, # [position, head_size] + sin: NamedArray, # [position, head_size] +) -> Tuple[NamedArray, NamedArray]: + """Applies rotary position embedding to q and k.""" + q_embed = q * cos + _rotate_half(q) * sin + k_embed = k * cos + _rotate_half(k) * sin + return q_embed, k_embed + + +def llama_rotary_pos_emb(HeadSize: Axis, Pos: Axis, base: int = 10000) -> Tuple[NamedArray, NamedArray]: + with jax.ensure_compile_time_eval(): + HeadHalfSize = HeadSize.resize(HeadSize.size // 2) + inv_freq: NamedArray = 1.0 / (base ** (hax.arange(HeadHalfSize, step=2) / HeadSize.size)) + + position_ids: NamedArray = hax.arange(Pos) + + freqs = position_ids * inv_freq.broadcast_axis(Pos) + # This is different from the paper but aligns with HF implementation: + # It uses a different permutation in order to obtain the same calculation + emb = hax.concatenate(HeadSize, (freqs, freqs)) + cos = hax.cos(emb) + sin = hax.sin(emb) + # This is different from the paper but aligns with HF implementation: + return cos, sin diff --git a/src/levanter/models/llama.py b/src/levanter/models/llama.py index c0e1ca45a..b023face2 100644 --- a/src/levanter/models/llama.py +++ b/src/levanter/models/llama.py @@ -54,6 +54,7 @@ class LlamaConfig(HFCompatConfig): activation_function (str, optional): activation function for the hidden layer. Defaults to "silu". rope_scaling (Dict, optional): dict containing the scaling configuration for the Rotary Positional Embedding. """ + reference_checkpoint: str = "meta-llama/Llama-2-7b-hf" seq_len: int = 2048 hidden_dim: int = 4096 @@ -93,11 +94,11 @@ def __post_init__(self): self.num_heads % self.num_kv_heads == 0 ), f"num_heads={self.num_heads} not divisible by num_kv_heads={self.num_kv_heads}." 
- @cached_classproperty - def default_hf_checkpoint_converter(cls) -> HFCheckpointConverter["LlamaConfig"]: # type: ignore + @property + def default_hf_checkpoint_converter(self) -> HFCheckpointConverter["LlamaConfig"]: # type: ignore return HFCheckpointConverter( - cls, # type: ignore - "meta-llama/Llama-2-7b-hf", + self.__class__, # type: ignore + self.reference_checkpoint, trust_remote_code=True, tokenizer="hf-internal-testing/llama-tokenizer", HfConfigClass=HfLlamaConfig, @@ -501,7 +502,7 @@ def __call__( Mask to avoid performing attention on the padding token indices of the encoder input. The attn_mask from training pipeline may be an AttentionMask object instead of NamedArray """ - k_t, k_head = maybe_rng_split(key, 2) + k_t, k_head = jrandom.split(key, 2) x = self.embeddings.embed(input_ids) x = self.transformer(x, attn_mask=attn_mask, key=k_t) lm_logits = self.lm_head(x, key=k_head) diff --git a/src/levanter/tracker/tracker.py b/src/levanter/tracker/tracker.py index 8b6816f17..ad81c43ac 100644 --- a/src/levanter/tracker/tracker.py +++ b/src/levanter/tracker/tracker.py @@ -82,8 +82,10 @@ def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optio tracker.log_artifact(artifact_path, name=name, type=type) +@dataclasses.dataclass class TrackerConfig(draccus.PluginRegistry, abc.ABC): discover_packages_path = "levanter.tracker" + id: Optional[str] = None @abc.abstractmethod def init(self, run_id: Optional[str]) -> Tracker: @@ -94,8 +96,10 @@ def default_choice_name(cls) -> Optional[str]: return "wandb" +@dataclasses.dataclass class NoopTracker(Tracker): name: str = "noop" + id: Optional[str] = None def log_hyperparameters(self, hparams: dict[str, Any]): pass @@ -104,7 +108,9 @@ def log(self, metrics: dict[str, Any], *, step, commit: Optional[bool] = None): pass def log_summary(self, metrics: dict[str, Any]): - pass + print("Summary:") + for k, v in metrics.items(): + print(f"-- {k}: {v}") def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optional[str] = None): pass @@ -114,4 +120,4 @@ def log_artifact(self, artifact_path, *, name: Optional[str] = None, type: Optio @dataclasses.dataclass class NoopConfig(TrackerConfig): def init(self, run_id: Optional[str]) -> Tracker: - return NoopTracker() + return NoopTracker(id=run_id) diff --git a/src/levanter/trainer_state.py b/src/levanter/trainer_state.py index 15800bd17..976c802eb 100644 --- a/src/levanter/trainer_state.py +++ b/src/levanter/trainer_state.py @@ -165,7 +165,7 @@ def saveable_training_mask(trainer_state: S, is_trainable_param: FilterTree = Tr is_trainable_param = make_floating_point_trainable_filter(is_trainable_param) trainer_state = jax.tree_util.tree_map(lambda x: True, trainer_state) - saveable_state = dataclasses.replace(trainer_state, model=is_trainable_param) # type: ignore + saveable_state = dataclasses.replace(trainer_state, student=is_trainable_param, teacher=False) # type: ignore return saveable_state # type: ignore diff --git a/tests/test_factorized_llama.py b/tests/test_factorized_llama.py new file mode 100644 index 000000000..61c9b5197 --- /dev/null +++ b/tests/test_factorized_llama.py @@ -0,0 +1,281 @@ +import tempfile + +import draccus +import equinox as eqx +import jax +import numpy as np +import pytest +import transformers +from jax import random + +import haliax as hax + +from levanter.models.attention import AttentionMask +from levanter.models.factorized_llama import ( + FactorizedLlamaAttention, + FactorizedLlamaConfig, + FactorizedLlamaDecoderLayer, + 
+    FactorizedLlamaLMHeadModel,
+)
+from levanter.utils.jax_utils import parameter_count
+from test_utils import check_load_config, check_model_works_with_seqlen, parameterize_with_configs, skip_if_no_torch
+
+import logging
+
+logging.basicConfig(level=logging.INFO)
+
+
+def _get_config(use_flash=False, num_kv_heads=4, seq_len=128) -> FactorizedLlamaConfig:
+    rope_scaling = {
+        "type": "linear",
+        "factor": 2.0,
+    }
+    return FactorizedLlamaConfig(
+        seq_len=seq_len,
+        hidden_dim=16 * 4,
+        num_heads=4,
+        factor_dim=8,
+        num_kv_heads=num_kv_heads,
+        rope_scaling=rope_scaling,
+        gradient_checkpointing=False,  # disable for tests so debugging is easier
+        use_flash_attention=use_flash,
+        flash_attention_block_size=8 if use_flash else None,
+    )
+
+
+@skip_if_no_torch
+def test_factor_llama_config():
+    # load HF config and convert to levanter config
+    hf_config = transformers.LlamaConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
+    llama_config = FactorizedLlamaConfig.from_hf_config(hf_config)
+
+    # convert back to HF config
+    config_overrides = {
+        "_name_or_path": hf_config._name_or_path,
+        "architectures": hf_config.architectures,
+        "torch_dtype": hf_config.torch_dtype,
+    }
+    new_hf_config = llama_config.to_hf_config(
+        vocab_size=hf_config.vocab_size,
+        config_overrides=config_overrides,
+    )
+
+    # assert the content in new_hf_config is the same as hf_config
+    for k in new_hf_config.__dict__.keys():
+        if k in ["_commit_hash", "transformers_version"]:
+            continue
+        assert getattr(new_hf_config, k) == getattr(
+            hf_config, k
+        ), f"{k} {getattr(new_hf_config, k)} != {getattr(hf_config, k)}"
+
+
+@skip_if_no_torch
+@pytest.mark.parametrize("use_flash", [True, False])
+@pytest.mark.parametrize("num_kv_heads", [1, 2, 4])
+def test_factor_llama_attention(use_flash, num_kv_heads):
+    import torch
+    from transformers.models.llama.modeling_llama import LlamaAttention as HFLlamaAttention
+
+    config = _get_config(use_flash=use_flash, num_kv_heads=num_kv_heads)
+
+    attention = FactorizedLlamaAttention.init(config=config, key=random.PRNGKey(0))
+
+    state = attention.to_state_dict()
+    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}
+    hf_attention = HFLlamaAttention(config.to_hf_config(32000))
+    hf_attention.load_state_dict(state, strict=True)
+
+    x, mask = _get_random_inputs(config)
+    x_torch = torch.from_numpy(np.array(x.array))
+    batch_size = x_torch.shape[0]
+    explicit_mask = torch.from_numpy(np.array(mask.materialize(config.Pos, config.KeyPos).array))
+    mask_torch = explicit_mask.broadcast_to((batch_size, 1, -1, -1))
+
+    # the torch mask is really a bias, so we need to invert it and make it a big negative number
+    mask_torch = (mask_torch == 0).float() * -1e9
+
+    out = attention(x, mask)
+    position_ids = torch.arange(config.Pos.size).reshape(1, -1)
+    hf_out = hf_attention(x_torch, position_ids=position_ids, attention_mask=mask_torch)
+    hf_out = hf_out[0].detach().cpu().numpy()
+    out = np.array(out.array)
+
+    assert np.isclose(
+        hf_out, out, rtol=1e-3, atol=1e-3
+    ).all(), f"Diff: {hf_out - out}, Max: {np.max(np.abs(hf_out - out))}"
+
+
+def test_factor_llama_param_counts_dont_change_with_seqlen():
+    model = FactorizedLlamaLMHeadModel.init(hax.Axis("v", 2048), _get_config(seq_len=128), key=random.PRNGKey(0))
+    model2 = FactorizedLlamaLMHeadModel.init(hax.Axis("v", 2048), _get_config(seq_len=256), key=random.PRNGKey(0))
+    assert parameter_count(model) == parameter_count(model2)
+
+
+@skip_if_no_torch
+@pytest.mark.parametrize("num_kv_heads", [1, 2, 4])
+def test_factor_llama_decoder_layer(num_kv_heads):
+    import torch
+    from transformers.models.llama.modeling_llama import LlamaDecoderLayer as HFLlamaDecoderLayer
+
+    llama_config = _get_config(num_kv_heads=num_kv_heads)
+    key = random.PRNGKey(0)
+    llama_decoder_layer = FactorizedLlamaDecoderLayer.init(config=llama_config, key=key)
+
+    state = llama_decoder_layer.to_state_dict()
+    state = {k: torch.from_numpy(np.array(v)) for k, v in state.items()}
+    print(state.keys())
+    hf_decoder_layer = HFLlamaDecoderLayer(llama_config.to_hf_config(32000), layer_idx=0)
+    hf_decoder_layer.load_state_dict(state, strict=True)
+
+    x, mask = _get_random_inputs(llama_config)
+    x_torch = torch.from_numpy(np.array(x.array))
+    batch_size = x_torch.shape[0]
+    explicit_mask = torch.from_numpy(np.array(mask.materialize(llama_config.Pos, llama_config.KeyPos).array))
+    mask_torch = explicit_mask.broadcast_to((batch_size, 1, -1, -1))
+    mask_torch = (mask_torch == 0).float() * -1e9
+
+    position_ids = torch.arange(llama_config.Pos.size).reshape(1, -1)
+
+    out = llama_decoder_layer(x, mask)
+    hf_out = hf_decoder_layer(x_torch, position_ids=position_ids, attention_mask=mask_torch)
+
+    assert np.isclose(
+        hf_out[0].detach().cpu().numpy(), np.array(out.array), rtol=1e-2, atol=1e-2
+    ).all(), f"{hf_out[0]} != {out}"
+
+
+@pytest.mark.parametrize("num_kv_heads", [1, 2, 4])
+def test_factor_llama_lm_head_model(num_kv_heads):
+    llama_config = _get_config(num_kv_heads=num_kv_heads)
+    Batch = hax.Axis("batch", 2)
+    Vocab = hax.Axis("vocab", 1000)
+    Pos = llama_config.Pos
+    input_ids = hax.random.randint(random.PRNGKey(0), (Batch, Pos), 0, Vocab.size)
+    mask = AttentionMask.causal()
+
+    llama_model = FactorizedLlamaLMHeadModel.init(Vocab=Vocab, config=llama_config, key=random.PRNGKey(0))
+    out = llama_model(input_ids, mask)
+    assert out.array.shape == (Batch.size, Pos.size, Vocab.size)
+
+
+@pytest.mark.parametrize("use_flash", [True, False])
+@pytest.mark.parametrize("num_kv_heads", [1, 2, 4])
+def test_factor_llama_lm_head_model_bwd(use_flash, num_kv_heads):
+    llama_config = _get_config(use_flash=use_flash, num_kv_heads=num_kv_heads)
+    Batch = hax.Axis("batch", 2)
+    Vocab = hax.Axis("vocab", 1000)
+    Pos = llama_config.Pos
+    input_ids = hax.random.randint(random.PRNGKey(0), (Batch, Pos), 0, Vocab.size)
+    mask = AttentionMask.causal()
+
+    llama_model = FactorizedLlamaLMHeadModel.init(Vocab=Vocab, config=llama_config, key=random.PRNGKey(0))
+
+    def f(llama_model, input_ids, mask):
+        out = llama_model(input_ids, mask)
+        return hax.sum(out).scalar()
+
+    _, grads = eqx.filter_value_and_grad(f)(llama_model, input_ids, mask)
+
+
+@skip_if_no_torch
+@pytest.mark.parametrize("scan_layers", [True, False])
+@pytest.mark.parametrize("num_kv_heads", [1, 2, 4])
+def test_factor_llama_roundtrip(scan_layers, num_kv_heads):
+    import torch
+    from transformers import AutoModelForCausalLM, LlamaForCausalLM
+
+    config = FactorizedLlamaConfig(
+        seq_len=128,
+        hidden_dim=64,
+        factor_dim=16,
+        num_heads=4,
+        num_kv_heads=num_kv_heads,
+        gradient_checkpointing=False,
+        num_layers=4,
+        scan_layers=scan_layers,
+    )
+    converter = config.default_hf_checkpoint_converter
+
+    Vocab = hax.Axis("vocab", 1000)
+    hf_config = config.to_hf_config(Vocab.size)
+
+    # Make input and attn_mask
+    input = hax.random.randint(random.PRNGKey(0), config.Pos, 0, Vocab.size)
+    attn_mask = AttentionMask.causal()
+    input_torch = torch.from_numpy(np.array(input.array)).to(torch.int32).unsqueeze(0)
+
+    torch.random.manual_seed(0)
+
+    torch_model = LlamaForCausalLM(hf_config)
+    torch_model.eval()
+
+    torch_out = torch_model(input_torch)
+    torch_out = torch_out.logits[0].detach().cpu().numpy()
+    torch_out = jax.nn.softmax(torch_out, axis=-1)
+
+    tmpdir = tempfile.mkdtemp()
+    print("Temp dir: ", tmpdir)
+    torch_model.save_pretrained(f"{tmpdir}/torch_model")
+
+    model = converter.load_pretrained(
+        FactorizedLlamaLMHeadModel, f"{tmpdir}/torch_model", resize_vocab_to_match_tokenizer=False
+    )
+
+    def compute(input):
+        model_output = model(input, attn_mask=attn_mask)
+        return hax.nn.softmax(model_output, axis=model.Vocab)
+
+    compute = jax.jit(compute)
+    jax_out = compute(input).array
+
+    assert torch_out.shape == jax_out.shape, f"{torch_out.shape} != {jax_out.shape}"
+    assert np.isclose(torch_out, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out} != {jax_out}"
+
+    converter.save_pretrained(model, f"{tmpdir}/lev_model", save_reference_code=False)
+    torch_model2 = AutoModelForCausalLM.from_pretrained(f"{tmpdir}/lev_model")
+    torch_model2.eval()
+
+    torch_out2 = torch_model2(input_torch)
+    torch_out2 = torch_out2.logits[0].detach().cpu().numpy()
+    torch_out2 = jax.nn.softmax(torch_out2, axis=-1)
+    assert torch_out2.shape == jax_out.shape, f"{torch_out2.shape} != {jax_out.shape}"
+    assert np.isclose(torch_out2, np.array(jax_out), rtol=1e-2, atol=1e-2).all(), f"{torch_out2} != {jax_out}"
+
+
+def _get_random_inputs(config: FactorizedLlamaConfig):
+    Embed = config.Embed
+    Pos = config.Pos
+    Batch = hax.Axis("batch", 2)
+    x = hax.random.normal(random.PRNGKey(0), (Batch, Pos, Embed))
+    mask = AttentionMask.causal()
+    return x, mask
+
+
+@parameterize_with_configs("distill_llama*.yaml")
+def test_factor_llama_configs(config_file):
+    from levanter.main.train_distill_lm import TrainDistillLmConfig
+
+    check_load_config(config_class=TrainDistillLmConfig, config_file=config_file)
+
+
+@parameterize_with_configs("distill_llama*.yaml")
+def test_factor_llama_trainer_init(config_file):
+    from levanter.main.train_distill_lm import TrainDistillLmConfig, main
+
+    config = draccus.parse(TrainDistillLmConfig, config_file, args=[])
+    main(config)
+
+
+@pytest.mark.parametrize("num_kv_heads", [1, 2])
+def test_pass_different_length_seq(num_kv_heads):
+    config = FactorizedLlamaConfig(
+        seq_len=64,
+        hidden_dim=64,
+        factor_dim=16,
+        intermediate_dim=32,
+        num_layers=2,
+        num_heads=2,
+        num_kv_heads=num_kv_heads,
+        use_flash_attention=True,
+    )
+    check_model_works_with_seqlen(FactorizedLlamaLMHeadModel, config, 16)