diff --git a/README.md b/README.md index 7e4238e32..caa67a74c 100644 --- a/README.md +++ b/README.md @@ -104,7 +104,7 @@ If you're using a TPU, more complete documentation for setting that up is availa As a kind of hello world, here's how you can train a GPT-2 "nano"-sized model on a small dataset. ```bash -python -m levanter.main.train_lm --config_path config/gpt2_nano.yaml +python -m levanter.main.train_lm --config_path config/llama2_100M_kron_test.yaml # alternatively, if you didn't use -e and are in a different directory python -m levanter.main.train_lm --config_path gpt2_nano diff --git a/config/llama2_100M_kron_test.yaml b/config/llama2_100M_kron_test.yaml new file mode 100644 index 000000000..9993da34d --- /dev/null +++ b/config/llama2_100M_kron_test.yaml @@ -0,0 +1,34 @@ +data: + id: openwebtext +model: + type: llama + seq_len: 4096 + hidden_dim: 768 + intermediate_dim: 3072 + num_layers: 12 + num_heads: 12 + num_kv_heads: 12 +trainer: + tracker: + project: "levanter" + tags: ["pile", "llama"] + mp: p=f32,c=bfloat16 + model_axis_size: 1 + checkpointer: + keep: + - every: 1000 + save_interval: 30m + + + train_batch_size: 1024 + per_device_parallelism: 32 # set for v3 TPU + per_device_eval_parallelism: 32 # set a larger batch size for eval + num_train_steps: 50001 +optimizer: + learning_rate: 3E-4 + weight_decay: 0.1 + warmup: 2000 + cooldown: 0.1 + lr_schedule: constant + min_lr_ratio: 0.0 + type: kron diff --git a/config/llama2_100M_mars.yaml b/config/llama2_100M_mars.yaml new file mode 100644 index 000000000..2c062d816 --- /dev/null +++ b/config/llama2_100M_mars.yaml @@ -0,0 +1,34 @@ +data: !include data/dclm_gpt_neo.yaml +model: + type: llama + seq_len: 4096 + hidden_dim: 768 + intermediate_dim: 3072 + num_layers: 12 + num_heads: 12 + num_kv_heads: 12 +trainer: + tracker: + project: "levanter" + tags: ["pile", "llama"] + mp: p=f32,c=bfloat16 + model_axis_size: 1 + checkpointer: + keep: + - every: 1000 + save_interval: 30m + + + train_batch_size: 1024 + per_device_parallelism: 4 # set for v3 TPU + per_device_eval_parallelism: 4 # set a larger batch size for eval + num_train_steps: 50001 +optimizer: + learning_rate: 4E-3 # set low for fine-tuning + weight_decay: 0.1 + min_lr_ratio: 0.0 + warmup: 2000 + cooldown: 0.4 + lr_schedule: constant + gamma: 0.025 + type: mars diff --git a/config/llama2_100M_muon.yaml b/config/llama2_100M_muon.yaml new file mode 100644 index 000000000..3f8194465 --- /dev/null +++ b/config/llama2_100M_muon.yaml @@ -0,0 +1,34 @@ +data: !include data/dclm_gpt_neo.yaml +model: + type: llama + seq_len: 4096 + hidden_dim: 768 + intermediate_dim: 3072 + num_layers: 12 + num_heads: 12 + num_kv_heads: 12 +trainer: + tracker: + project: "levanter" + tags: ["pile", "llama"] + mp: p=f32,c=bfloat16 + model_axis_size: 1 + checkpointer: + keep: + - every: 1000 + save_interval: 30m + + + train_batch_size: 1024 + per_device_parallelism: 4 # set for v3 TPU + per_device_eval_parallelism: 4 # set a larger batch size for eval + num_train_steps: 50001 +optimizer: + learning_rate: 2E-2 # set low for fine-tuning + weight_decay: 0 + warmup: 0 + cooldown: 0.1 + lr_schedule: constant + min_lr_ratio: 0.0 + max_grad_norm: 0.0 + type: muon diff --git a/src/levanter/data/audio.py b/src/levanter/data/audio.py index b2235e863..46e72b210 100644 --- a/src/levanter/data/audio.py +++ b/src/levanter/data/audio.py @@ -193,7 +193,7 @@ def decode(x): def doc_iterator(self, split: str) -> Iterator[Tuple[np.ndarray, int, str]]: if self.id is not None: - data = datasets.load_dataset(self.id, 
split=split, name=self.name, streaming=self.stream) + data = datasets.load_dataset(self.id, split=split, name=self.name, streaming=self.stream, trust_remote_code=True) for doc in data: yield (doc[self.audio_key]["array"], doc[self.audio_key]["sampling_rate"], doc[self.text_key]) else: @@ -385,7 +385,7 @@ def _has_validation_set(self): if self.id is not None: dataset = datasets.load_dataset( - self.id, name=self.name, streaming=self.stream, split=self.validation_split + self.id, name=self.name, streaming=self.stream, split=self.validation_split, trust_remote_code=True ) try: next(iter(dataset)) diff --git a/src/levanter/data/sharded_datasource.py b/src/levanter/data/sharded_datasource.py index 9dca9b618..30a7727d6 100644 --- a/src/levanter/data/sharded_datasource.py +++ b/src/levanter/data/sharded_datasource.py @@ -253,7 +253,7 @@ def open_shard_at_row(self, shard_name: str, row: int) -> Iterator[dict]: def _load_dataset(self): # obnoxiously, the dataset loading stuff doesn't work with ray because of multiprocessing # so we have to do this hacky thing where we load the dataset in the worker - return datasets.load_dataset(self.id, split=self.split, streaming=self.streaming, **self.kwargs) + return datasets.load_dataset(self.id, split=self.split, streaming=self.streaming, trust_remote_code=True, **self.kwargs) class TextUrlDataSource(ShardedDataSource[str]): diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 13c7ea44b..3c71e94df 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -594,7 +594,7 @@ def get_shard_source(self, split) -> Optional[ShardedDataSource[str]]: def doc_iterator(self, split: str): if self.id is not None: - dataset = datasets.load_dataset(self.id, name=self.name, streaming=self.stream) + dataset = datasets.load_dataset(self.id, name=self.name, streaming=self.stream, trust_remote_code=True) data = dataset[split] for doc in data: yield doc[self.text_key] @@ -1065,7 +1065,7 @@ def _has_validation_set(self): return True if self.id is not None: - dataset = datasets.load_dataset(self.id, name=self.name, streaming=self.stream, split="validation") + dataset = datasets.load_dataset(self.id, name=self.name, streaming=self.stream, split="validation", trust_remote_code=True) try: next(iter(dataset)) return True diff --git a/src/levanter/optim/__init__.py b/src/levanter/optim/__init__.py index 7dec2ebb4..64a51ea2c 100644 --- a/src/levanter/optim/__init__.py +++ b/src/levanter/optim/__init__.py @@ -5,3 +5,12 @@ scale_by_sophia_g, scale_by_sophia_h, ) +from .muon import ( + MuonConfig, + ScaleByMuonState +) +from .mars import ( + MarsConfig, + ScaleByMarsState +) +from .kron import KronConfig diff --git a/src/levanter/optim/kron.py b/src/levanter/optim/kron.py new file mode 100644 index 000000000..57006059b --- /dev/null +++ b/src/levanter/optim/kron.py @@ -0,0 +1,818 @@ +from dataclasses import dataclass +from typing import Optional, Union + +import jax.numpy as jnp +import optax + +from levanter.optim.config import OptimizerConfig + + +@OptimizerConfig.register_subclass("kron") +@dataclass +class KronConfig(OptimizerConfig): + """Configuration for PSGD Kron optimizer. + + Attributes: + beta1: Momentum parameter. 0.9 or 0.95 are common values. + weight_decay: Weight decay coefficient. + max_grad_norm: Unused. + preconditioner_update_probability: Final probability of updating the preconditioner. Default + is 0.03 (update every 33 steps). 
The `precond_update_prob_schedule` holds probability at + 1.0 for `update_prob_flat_start` steps before annealing exponentially down to this + value within ~3000 steps. Training is slower while updates are done every step, but + training speeds up after update probability decays. + update_prob_flat_start: Number of steps to keep update probability at 1.0 before annealing. + Default value of 500 works well, but increasing this to 1000 or 2000 can benefit training. + However, this slows down training. A good balance is to keep update probability at 1.0 during + initial loss drop, then when you notice loss start to plateau, the preconditioner is mostly + learned and update probability can be decayed for faster training. + max_size_triangular: Max size for dim's preconditioner to be triangular. + min_ndim_triangular: Minimum number of dimensions a layer needs to have triangular preconditioners. + memory_save_mode: Memory saving mode for preconditioners. Options: + - None: All preconditioners are triangular (default) + - 'one_diag': Largest/last dim per layer uses diagonal preconditioner + - 'all_diag': All preconditioners are diagonal + preconditioner_lr: Learning rate for preconditioner. + preconditioner_init_scale: Scale for preconditioner initialization. + mu_dtype: Dtype of the momentum buffer. Defaults to same dtype as parameters. + precond_dtype: Dtype of the preconditioners. Defaults to 'float32'. + precond_update_precision: Precision for matmul during preconditioner update. + Options: 'bfloat16', 'tensorfloat32', 'float32'. + precond_grads_precision: Precision for matmul during preconditioning grads. + Options: 'bfloat16', 'tensorfloat32', 'float32'. + lax_map_scanned_layers: Whether to use lax.map for scanned layers instead of vmap. + Useful to save memory with large models. + lax_map_batch_size: Batch size for lax.map, see JAX docs for more info. 
+ """ + # some of these are changed from kron defaults to better suit levanter + beta1: float = 0.9 + weight_decay: float = 0.1 + max_grad_norm: Optional[float] = None + preconditioner_update_probability: float = 0.03 + update_prob_flat_start: int = 1000 + max_size_triangular: int = 16384 + min_ndim_triangular: int = 2 + memory_save_mode: Optional[str] = None + preconditioner_lr: float = 0.1 + preconditioner_init_scale: float = 1.0 + mu_dtype: Optional[Union[str, jnp.dtype]] = None + precond_dtype: Optional[Union[str, jnp.dtype]] = None + precond_update_precision: Optional[str] = "tensorfloat32" + precond_grads_precision: Optional[str] = None + lax_map_scanned_layers: bool = False + lax_map_batch_size: int = 8 + + def build(self, num_train_steps): + """Creates the optimizer.""" + if self.max_grad_norm is not None and jax.process_index() == 0: + print("WARNING: max_grad_norm is unused in PSGD Kron optimizer") + + def _optimizer(learning_rate) -> optax.GradientTransformation: + components = [] + components.append( + scale_by_kron_for_levanter( + b1=self.beta1, + preconditioner_update_probability=precond_update_prob_schedule( + min_prob=self.preconditioner_update_probability, + flat_start=self.update_prob_flat_start, + ), + max_size_triangular=self.max_size_triangular, + min_ndim_triangular=self.min_ndim_triangular, + memory_save_mode=self.memory_save_mode, + preconditioner_lr=self.preconditioner_lr, + preconditioner_init_scale=self.preconditioner_init_scale, + mu_dtype=self.mu_dtype, + precond_dtype=self.precond_dtype, + precond_update_precision=self.precond_update_precision, + precond_grads_precision=self.precond_grads_precision, + lax_map_scanned_layers=self.lax_map_scanned_layers, + lax_map_batch_size=self.lax_map_batch_size, + ) + ) + # PSGD's output should be RMS=1.0, so we can clip at 1.1 in case of incoming + # gradient spike. This is better than clipping incoming grads because that would + # get rid of valuable information for the preconditioner. + components.append(optax.clip_by_block_rms(1.1)) + if self.weight_decay > 0: + components.append( + optax.add_decayed_weights( + self.weight_decay, self.build_weight_decay_mask() + ) + ) + components.append(optax.scale_by_learning_rate(learning_rate)) + return optax.chain(*components) + + return optax.inject_hyperparams(_optimizer)( + learning_rate=self.lr_scheduler(num_train_steps) + ) + + +"""PSGD Kron""" +from typing import Any, List, Optional, Union, Callable +from functools import partial +import string +import numpy as np + +import jax +from jax import vmap +import jax.numpy as jnp +from jax.sharding import PartitionSpec +from jax.lax import with_sharding_constraint +from optax import tree_utils as otu +from optax._src import base, transform +from optax._src.numerics import safe_int32_increment +from optax._src.utils import canonicalize_dtype +from optax._src.combine import chain + +import haliax as hax + + +def precond_update_prob_schedule( + max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500 +): + """Anneal preconditioner update probability during beginning of training. + + PSGD benefits from more preconditioner updates at the beginning of training, + but once the preconditioner is learned the update probability can drop low. + + This schedule is an exponential anneal with a flat start. Default settings keep + update probability at 1.0 for 250 steps then exponentially anneal down to + `min_prob` by 4000 steps. Default settings work well for most models and + training regimes. 
+ """ + + def _schedule(n): + """Exponential anneal with flat start.""" + return jnp.minimum( + jnp.maximum(max_prob * jnp.exp(-decay * (n - flat_start)), min_prob), + max_prob, + ) + + return _schedule + + +def scale_by_kron_for_levanter( + b1: float = 0.9, + preconditioner_update_probability: Union[ + float, Callable[[int], float] + ] = precond_update_prob_schedule(), + max_size_triangular: int = 8192, + min_ndim_triangular: int = 2, + memory_save_mode: Optional[str] = None, + momentum_into_precond_update: bool = True, + preconditioner_lr: float = 0.1, + preconditioner_init_scale: float = 1.0, + mu_dtype: Optional[Union[str, jnp.dtype]] = None, + precond_dtype: Optional[Union[str, jnp.dtype]] = None, + precond_update_precision: Optional[str] = "tensorfloat32", + precond_grads_precision: Optional[str] = None, + lax_map_scanned_layers: bool = False, + lax_map_batch_size: int = 8, +) -> base.GradientTransformation: + """ + Implements PSGD Kron from https://github.com/lixilinx/psgd_torch. + + A simple version of scale_by_kron that is focused on working only within levanter + with FSDP sharding for preconditioners. + + Args: + b1: float, momentum parameter. + preconditioner_update_probability: float, probability of updating the + preconditioner. Default anneals from 1.0 to 0.03 by 4000 steps. + max_size_triangular: int, max size for dim's preconditioner to be triangular. + min_ndim_triangular: int, minimum number of dimensions a layer needs to have + triangular preconditioners. + memory_save_mode: optional str, None, 'one_diag', or 'all_diag', None is default + to set all preconditioners to be triangular, 'one_diag' sets the largest + or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners + to be diagonal. + momentum_into_precond_update: bool, whether to send momentum into preconditioner + update instead of raw gradients. + preconditioner_lr: float, learning rate for preconditioner. + preconditioner_init_scale: float, scale for preconditioner initialization. + mu_dtype: optional str or jnp.dtype, dtype of the momentum accumulator. + Defaults to the same dtype as the parameters. + precond_dtype: optional str or jnp.dtype, dtype of the preconditioner. + precond_update_precision: str, precision for matmul during preconditioner update, + 'bfloat16', 'tensorfloat32', 'float32'. + precond_grads_precision: str, precision for matmul during preconditioning grads, + 'bfloat16', 'tensorfloat32', 'float32'. + lax_map_scanned_layers: bool, whether to use lax.map for scanned layers + instead of vmap. Useful to save memory with large models. + lax_map_batch_size: int, batch size for lax.map, see JAX docs for more info. 
+ + Returns: + optax.GradientTransformation + """ + mu_dtype = canonicalize_dtype(mu_dtype) + precond_dtype = canonicalize_dtype(precond_dtype) + + def map_fn(do_map, fn, *args): + """Maybe map a fn along first axis.""" + if do_map: + if lax_map_scanned_layers: + return jax.lax.map( + lambda xs: fn(*xs), + xs=args, + batch_size=lax_map_batch_size if lax_map_batch_size > 1 else None, + ) + else: + return vmap(fn)(*args) + else: + return fn(*args) + + def init_fn(params): + def fsdp_size(): + mesh = hax.partitioning._get_mesh() + fsdp_axis_name = hax.partitioning.ResourceAxis.DATA + fsdp_axis = mesh.axis_names.index(fsdp_axis_name) + fsdp_size = mesh.devices.shape[fsdp_axis] + return fsdp_size + + # grab scanned layers and params sharding + scanned_layers_ = jax.tree.map( + lambda x: ( + jax.tree.map(lambda _: True, x, is_leaf=lambda x: isinstance(x, jax.Array)) + if isinstance(x, hax.nn.Stacked) + else jax.tree.map(lambda _: False, x, is_leaf=lambda x: isinstance(x, jax.Array)) + ), + params, + is_leaf=lambda x: isinstance(x, hax.nn.Stacked), + ) + params_sharding_ = hax.partitioning.infer_resource_partitions(params) + params_sharding_ = jax.tree.map(lambda x: x.spec, params_sharding_) + + params, params_structure = jax.tree.flatten(params, is_leaf=lambda x: isinstance(x, jax.Array)) + scanned_layers_ = params_structure.flatten_up_to(scanned_layers_) + params_sharding_ = jax.tree.leaves(params_sharding_, is_leaf=lambda x: isinstance(x, PartitionSpec)) + # print(f"kron params: {jax.tree.map(lambda x: x.shape, params)}") + # print(f"kron scanned_layers_: {scanned_layers_}") + # print(f"kron params_sharding_: {params_sharding_}") + + # momentum + mu = None + if b1 > 0: + mu = jax.tree.map(lambda x: jnp.zeros_like(x, dtype=mu_dtype), params) + mu = with_sharding_constraint(mu, params_sharding_) + + # preconditioners + Qs = [ + _init_Q_exprs( + t[0] if s else t, + preconditioner_init_scale, + max_size_triangular, + min_ndim_triangular, + memory_save_mode, + precond_dtype, + )[0] + for t, s in zip(jax.tree.leaves(params), jax.tree.leaves(scanned_layers_)) + ] + # broadcast for scanned layers + def shard_q(q, s): + q_shape_no_s = q.shape[int(s):] + if len(q_shape_no_s) > 1 and q_shape_no_s[0] % fsdp_size() == 0: + return with_sharding_constraint( + q, PartitionSpec(None, 'data') if s else PartitionSpec('data') + ) + else: + return with_sharding_constraint(q, PartitionSpec(None)) + + Qs = [ + ( + jax.tree.map( + lambda d: shard_q(jnp.repeat(jnp.expand_dims(d, 0), t.shape[0], axis=0), s), q + ) + if s + else q + ) + for q, t, s in zip( + Qs, jax.tree.leaves(params), jax.tree.leaves(scanned_layers_) + ) + ] + + # Calculate sizes for nu (preconditioner) and mu (momentum) + Qs_n_elements = sum([q.size for q in jax.tree.leaves(Qs)]) + Qs_size_MB = sum( + [q.size * q.dtype.itemsize / (2**20) for q in jax.tree.leaves(Qs)] + ) + if jax.process_index() == 0: + print( + f"PSGD Preconditioners size: {Qs_n_elements} elements, " + f"{Qs_size_MB:.2f} MB" + ) + if mu is not None: + mu_n_elements = sum([p.size for p in jax.tree.leaves(mu)]) + mu_size_MB = sum( + [p.size * p.dtype.itemsize / (2**20) for p in jax.tree.leaves(mu)] + ) + if jax.process_index() == 0: + print( + f"PSGD Momentum size: {mu_n_elements} elements, {mu_size_MB:.2f} MB" + ) + + # initial state + return dict( + count=jnp.zeros([], jnp.int32), + mu=mu, + Qs_preconditioners=Qs, + update_counter=jnp.zeros([], jnp.int32), + ) + + def update_fn(updates: base.Updates, state: dict, params: base.Params = None): + del params + count_inc = 
safe_int32_increment(state["count"]) + key = jax.random.fold_in(jax.random.PRNGKey(42), state["count"]) + + def fsdp_size(): + mesh = hax.partitioning._get_mesh() + fsdp_axis_name = hax.partitioning.ResourceAxis.DATA + fsdp_axis = mesh.axis_names.index(fsdp_axis_name) + fsdp_size = mesh.devices.shape[fsdp_axis] + return fsdp_size + + # grab scanned layers and params sharding + scanned_layers_ = jax.tree.map( + lambda x: ( + jax.tree.map(lambda _: True, x, is_leaf=lambda x: isinstance(x, jax.Array)) + if isinstance(x, hax.nn.Stacked) + else jax.tree.map(lambda _: False, x, is_leaf=lambda x: isinstance(x, jax.Array)) + ), + updates, + is_leaf=lambda x: isinstance(x, hax.nn.Stacked), + ) + params_sharding_ = hax.partitioning.infer_resource_partitions(updates) + params_sharding_ = jax.tree.map(lambda x: x.spec, params_sharding_) + + updates, grads_structure = jax.tree.flatten(updates, is_leaf=lambda x: isinstance(x, jax.Array)) + scanned_layers_ = grads_structure.flatten_up_to(scanned_layers_) + params_sharding_ = jax.tree.leaves(params_sharding_, is_leaf=lambda x: isinstance(x, PartitionSpec)) + Qs = state["Qs_preconditioners"] + # print(f"kron updates: {jax.tree.map(lambda x: x.shape, updates)}") + # print(f"kron scanned_layers_: {scanned_layers_}") + # print(f"kron params_sharding_: {params_sharding_}") + + update_prob_in = preconditioner_update_probability + if isinstance(preconditioner_update_probability, Callable): + update_prob_in = preconditioner_update_probability(count_inc) + + # momentum + mu = None + momentum_updates = updates + if state["mu"] is not None: + mu = otu.tree_update_moment(updates, state["mu"], b1, 1) + mu = with_sharding_constraint(mu, params_sharding_) + momentum_updates = otu.tree_bias_correction(mu, b1, count_inc) + momentum_updates = with_sharding_constraint(momentum_updates, params_sharding_) + + # get einsum expressions + expressions = [ + _init_Q_exprs( + t[0] if s else t, + preconditioner_init_scale, + max_size_triangular, + min_ndim_triangular, + memory_save_mode, + precond_dtype, + existing_Q=jax.tree.map(lambda d: d[0], Q) if s else Q, + ) + for t, s, Q in zip(updates, scanned_layers_, Qs) + ] + + # qs sharding + def get_q_sharding(q, s): + q_shape_no_s = q.shape[int(s):] + if len(q_shape_no_s) > 1 and q_shape_no_s[0] % fsdp_size() == 0: + return PartitionSpec(None, 'data') if s else PartitionSpec('data') + else: + return PartitionSpec(None) + + qs_sharding_ = [[get_q_sharding(q, s)for q in Q] for Q, s in zip(Qs, scanned_layers_)] + + # maybe update preconditioner + def update_preconditioner(key, Qs): + with jax.default_matmul_precision(precond_update_precision): + if momentum_into_precond_update: + precond_updates_in = momentum_updates + else: + precond_updates_in = updates + + # balance preconditioners about every 100 updates + def balance_Qs(Qs: List[List[jax.Array]]): + def _balance_Q(Q: List[jax.Array]): + norms = jnp.array( + [jnp.max(jnp.abs(q)) for q in Q], dtype=jnp.float32 + ) + gmean = jnp.prod(norms) ** (1 / len(norms)) + to_mul = gmean / norms + return [q * x.astype(q.dtype) for q, x in zip(Q, to_mul)] + + return [ + map_fn(s, _balance_Q, Q) if len(Q) > 1 else Q + for Q, s in zip(Qs, scanned_layers_) + ] + + key, subkey = jax.random.split(key) + do_balances = jax.random.uniform(subkey) < 0.01 + Qs = jax.lax.cond(do_balances, balance_Qs, lambda qs: qs, Qs) + Qs = with_sharding_constraint(Qs, qs_sharding_) + + # create random vectors + key, subkey = jax.random.split(key) + Vs_keys = jax.random.split(subkey, len(precond_updates_in)) + Vs = [ + 
jax.random.normal(k, shape=g.shape, dtype=g.dtype) + for k, g in zip(Vs_keys, precond_updates_in) + ] + Vs = with_sharding_constraint(Vs, params_sharding_) + + # damp based on machine precision (f32 probably enough) + damp_eps = jnp.sqrt(jnp.finfo(jnp.float32).eps) + precond_updates_in = jax.tree.map( + lambda g, v: g + damp_eps.astype(g.dtype) * jnp.mean(jnp.abs(g)) * v, + precond_updates_in, + Vs, + ) + + # form conjB + conjBs = [ + map_fn(s, _conjB, Q, g, v) + for s, Q, g, v in zip(scanned_layers_, Qs, precond_updates_in, Vs) + ] + conjBs = with_sharding_constraint(conjBs, params_sharding_) + + # update Qs + new_Qs = [ + map_fn( + s, + partial( + _update_precond, exprs=exprs, precond_lr=preconditioner_lr + ), + Q, + g, + conjb, + ) + for s, exprs, Q, g, conjb in zip( + scanned_layers_, expressions, Qs, precond_updates_in, conjBs + ) + ] + new_Qs = with_sharding_constraint(new_Qs, qs_sharding_) + + new_Qs = otu.tree_cast(new_Qs, precond_dtype) + return new_Qs + + # update preconditioner deterministically + update_counter_inc = safe_int32_increment(state["update_counter"]) + do_update = update_counter_inc >= 1 / update_prob_in + update_counter_inc = jnp.where(do_update, 0, update_counter_inc) + key, subkey = jax.random.split(key) + Qs = jax.lax.cond(do_update, update_preconditioner, lambda _, qs: qs, subkey, Qs) + Qs = with_sharding_constraint(Qs, qs_sharding_) + + # precondition gradients + with jax.default_matmul_precision(precond_grads_precision): + precond_gs = [ + map_fn(s, partial(_precond_grad, exprs=exprs), Q, g) + for s, exprs, Q, g in zip( + scanned_layers_, expressions, Qs, momentum_updates + ) + ] + precond_gs = with_sharding_constraint(precond_gs, params_sharding_) + + # unflatten pytrees + precond_gs = grads_structure.unflatten(precond_gs) + + # dtypes and new state + mu = otu.tree_cast(mu, mu_dtype) + Qs = otu.tree_cast(Qs, precond_dtype) + state = dict( + count=count_inc, + mu=mu, + Qs_preconditioners=Qs, + update_counter=update_counter_inc, + ) + + return precond_gs, state + + return base.GradientTransformation(init_fn, update_fn) + + +def kron( + learning_rate: Union[float, Callable[[int], float]] = 0.001, + b1: float = 0.9, + weight_decay: float = 0.0, + weight_decay_mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, + preconditioner_update_probability: Union[ + float, Callable[[int], float] + ] = precond_update_prob_schedule(), + max_size_triangular: int = 8192, + min_ndim_triangular: int = 2, + memory_save_mode: Optional[str] = None, + momentum_into_precond_update: bool = True, + preconditioner_lr: float = 0.1, + preconditioner_init_scale: float = 1.0, + mu_dtype: Optional[Union[str, jnp.dtype]] = None, + precond_dtype: Optional[Union[str, jnp.dtype]] = None, + precond_update_precision: Optional[str] = "tensorfloat32", + precond_grads_precision: Optional[str] = None, + lax_map_scanned_layers: bool = False, + lax_map_batch_size: int = 8, +) -> base.GradientTransformation: + """ + Implements PSGD Kron from https://github.com/lixilinx/psgd_torch. + + Args: + learning_rate: float or callable, learning rate. + b1: float, momentum parameter. + weight_decay: float, weight decay. + weight_decay_mask: optional Any or callable, pytree of bool same structure + as params with weight decay applied to True elements. + preconditioner_update_probability: float, probability of updating the + preconditioner. Default anneals from 1.0 to 0.03 by 4000 steps. + max_size_triangular: int, max size for dim's preconditioner to be triangular. 
+ min_ndim_triangular: int, minimum number of dimensions a layer needs to have + triangular preconditioners. + memory_save_mode: optional str, None, 'one_diag', or 'all_diag', None is default + to set all preconditioners to be triangular. 'one_diag' sets only the largest + or last dim in a layer to be diagonal, and 'all_diag' sets all preconditioners + to be diagonal. + momentum_into_precond_update: bool, whether to send momentum into preconditioner + update instead of raw gradients. + preconditioner_lr: float, learning rate for preconditioner. + preconditioner_init_scale: float, scale for preconditioner initialization. + mu_dtype: optional str or jnp.dtype, dtype of the momentum accumulator. + Defaults to the same dtype as the parameters. + precond_dtype: optional str or jnp.dtype, dtype of the preconditioner. + precond_update_precision: str, precision for matmul during preconditioner update, + 'bfloat16', 'tensorfloat32', 'float32'. + precond_grads_precision: str, precision for matmul during preconditioning grads, + 'bfloat16', 'tensorfloat32', 'float32'. + lax_map_scanned_layers: bool, whether to use lax.map for scanned layers + instead of vmap. Useful to save memory with large models. + lax_map_batch_size: int, batch size for lax.map, see JAX docs for more info. + + Returns: + optax.GradientTransformation + """ + optimizer = [ + scale_by_kron_for_levanter( + b1=b1, + preconditioner_update_probability=preconditioner_update_probability, + max_size_triangular=max_size_triangular, + min_ndim_triangular=min_ndim_triangular, + memory_save_mode=memory_save_mode, + momentum_into_precond_update=momentum_into_precond_update, + preconditioner_lr=preconditioner_lr, + preconditioner_init_scale=preconditioner_init_scale, + mu_dtype=mu_dtype, + precond_dtype=precond_dtype, + precond_update_precision=precond_update_precision, + precond_grads_precision=precond_grads_precision, + lax_map_scanned_layers=lax_map_scanned_layers, + lax_map_batch_size=lax_map_batch_size, + ) + ] + if weight_decay > 0.0: + optimizer.append(transform.add_decayed_weights(weight_decay, weight_decay_mask)) + optimizer.append(transform.scale_by_learning_rate(learning_rate)) + return chain(*optimizer) + + +def _add_tiny(x): + return x + jnp.finfo(x.dtype).tiny + + +def _norm_lower_bound(A: jax.Array): + """Returns a cheap lower bound for the spectral norm of A. + + Numerical results on random matrices with a wide range of distributions and + sizes suggest, norm(A) <= sqrt(2) * norm_lower_bound(A). Looks to be a very + tight lower bound. 
+ """ + max_abs = jnp.max(jnp.abs(A)) + + def calc(A): + A = A / max_abs + aa = A * A + + aa_sum0 = jnp.sum(aa, axis=0) + aa_sum1 = jnp.sum(aa, axis=1) + i = jnp.argmax(aa_sum0, 0) + j = jnp.argmax(aa_sum1, 0) + value0 = jax.lax.dynamic_index_in_dim(aa_sum0, i, 0, keepdims=False) + value1 = jax.lax.dynamic_index_in_dim(aa_sum1, j, 0, keepdims=False) + + def gt_branch(): + x = jax.lax.dynamic_index_in_dim(A, i, 1, keepdims=False) + x = x @ A + return max_abs * jnp.linalg.norm((x / jnp.linalg.norm(x)) @ A.T) + + def le_branch(): + x = jax.lax.dynamic_index_in_dim(A, j, 0, keepdims=False) + x = A @ x + return max_abs * jnp.linalg.norm(A.T @ (x / jnp.linalg.norm(x))) + + return jax.lax.cond(value0 > value1, gt_branch, le_branch) + + def no_calc(_): + return max_abs + + return jax.lax.cond(max_abs > 0, calc, no_calc, A) + + +def _init_Q_exprs( + t, scale, max_size, min_ndim_triangular, memory_save_mode, dtype, existing_Q=None +): + """For a scalar or tensor `t`, we initialize its preconditioner `Q` and + reusable contraction expressions for updating `Q` and preconditioning gradient. + """ + letters = string.ascii_lowercase + string.ascii_uppercase + + shape = t.shape + if len(shape) == 0: # scalar + Q = ( + [scale * jnp.ones_like(t, dtype=dtype)] + if existing_Q is None + else existing_Q + ) + exprA = ",->" + exprGs = [",->"] + exprP = ",,->" + else: # tensor + if len(shape) > 13: + raise ValueError( + f"Got tensor with dim {len(t.shape)}; Einstein runs out of letters!" + ) + + scale = scale ** (1 / len(shape)) + + if memory_save_mode is None: + dim_diag = [False for _ in shape] + elif memory_save_mode == "one_diag": + rev_sorted_dims = np.argsort(shape)[::-1] + dim_diag = [False for _ in shape] + dim_diag[rev_sorted_dims[0]] = True + elif memory_save_mode == "all_diag": + dim_diag = [True for _ in shape] + else: + raise ValueError( + f"Invalid memory_save_mode: {memory_save_mode}, must be one of " + "[None, 'one_diag', 'all_diag']" + ) + + Q = [] if existing_Q is None else existing_Q + piece1A, piece2A, piece3A = ([], "", "") + exprGs = [] + piece1P, piece2P, piece3P, piece4P = ([], [], "", "") + for i, (size, dim_d) in enumerate(zip(shape, dim_diag)): + if ( + size == 1 + or size > max_size + or len(shape) < min_ndim_triangular + or dim_d + ): + # use diagonal matrix as preconditioner for this dim + if existing_Q is None: + Q.append(scale * jnp.ones(size, dtype=dtype)) + + piece1A.append(letters[i]) + piece2A = piece2A + letters[i] + piece3A = piece3A + letters[i] + + piece1 = "".join( + [ + (letters[i + 13] if j == i else letters[j]) + for j in range(len(shape)) + ] + ) + exprGs.append(piece1 + "," + piece1 + "->" + letters[i + 13]) + + piece1P.append(letters[i + 13]) + piece2P.append(letters[i + 13]) + piece3P = piece3P + letters[i + 13] + piece4P = piece4P + letters[i + 13] + else: + # use triangular matrix as preconditioner for this dim + if existing_Q is None: + def fsdp_size(): + mesh = hax.partitioning._get_mesh() + fsdp_axis_name = hax.partitioning.ResourceAxis.DATA + fsdp_axis = mesh.axis_names.index(fsdp_axis_name) + fsdp_size = mesh.devices.shape[fsdp_axis] + return fsdp_size + + new_q = scale * jnp.eye(size, dtype=dtype) + if new_q.shape[0] % fsdp_size() == 0: + new_q = with_sharding_constraint(new_q, PartitionSpec('data')) + Q.append(new_q) + + piece1A.append(letters[i] + letters[i + 13]) + piece2A = piece2A + letters[i + 13] + piece3A = piece3A + letters[i] + + piece1 = "".join( + [ + (letters[i + 13] if j == i else letters[j]) + for j in range(len(shape)) + ] + ) + piece2 = 
"".join( + [ + (letters[i + 26] if j == i else letters[j]) + for j in range(len(shape)) + ] + ) + exprGs.append( + piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26] + ) + + a, b, c = (letters[i], letters[i + 13], letters[i + 26]) + piece1P.append(a + b) + piece2P.append(a + c) + piece3P = piece3P + c + piece4P = piece4P + b + + exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A + exprP = ( + ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P + ) + + exprGs = tuple(exprGs) + if existing_Q is not None: + return exprA, exprGs, exprP + return [Q, (exprA, exprGs, exprP)] + + +def _solve_triangular_right(X, A): + """Compute X @ inv(A). + + A triangular solve has roughly the same complexity as a matmul. + """ + X_ndim = X.ndim + if X_ndim < 2: + X = X[None, :] + + dtype_in = jnp.promote_types(A.dtype, X.dtype) + A, X = A.astype(dtype_in), X.astype(dtype_in) + leading_dims = 0 + if X.ndim > 2: + leading_dims = X.ndim - 2 + solve_fn = partial(jax.lax.linalg.triangular_solve, left_side=False, lower=False) + for _ in range(leading_dims): + solve_fn = vmap(solve_fn, in_axes=(None, 0)) + solution = solve_fn(A, X) + + if X_ndim < 2: + return solution[0] + return solution + + +def _conjB(Q, G, V): + """Compute conjB.""" + order = G.ndim + p = list(range(order)) + conjB = jnp.transpose(V, p[1:] + p[:1]) + for i, q in enumerate(Q): + conjB = conjB / q if q.ndim < 2 else _solve_triangular_right(conjB, q) + if i < order - 1: + conjB = jnp.swapaxes(conjB, i, order - 1) + return conjB + + +def _update_precond(Q, G, conjB, exprs, precond_lr): + """Compute A and update Q.""" + exprA, exprGs, _ = exprs + + A = jnp.einsum(exprA, *Q, G) + + def _update_single_q(i, q): + term1 = jnp.einsum(exprGs[i], A, A) + term2 = jnp.einsum(exprGs[i], conjB, conjB) + + if q.ndim < 2: + q -= ( + precond_lr + / _add_tiny(jnp.max(jnp.abs(term1 + term2))) + * (term1 - term2) + * q + ) + else: + # main place I've found so far that needs specific sharding constraint is + # here on terms with transposed q sharding + term1 = with_sharding_constraint(term1, PartitionSpec(None, 'data')) + term2 = with_sharding_constraint(term2, PartitionSpec(None, 'data')) + + q -= ( + precond_lr + / _add_tiny(_norm_lower_bound(term1 + term2)) + * jnp.triu(term1 - term2) + @ q + ) + return q + + return [_update_single_q(i, q) for i, q in enumerate(Q)] + + +def _precond_grad(Q, G, exprs): + """Precondition gradient G with preconditioner Q.""" + exprP = exprs[-1] + return jnp.einsum(exprP, *Q, *Q, G) diff --git a/src/levanter/optim/mars.py b/src/levanter/optim/mars.py new file mode 100644 index 000000000..c117a27f6 --- /dev/null +++ b/src/levanter/optim/mars.py @@ -0,0 +1,135 @@ +import abc +import functools +from dataclasses import dataclass +from typing import Any, NamedTuple, Optional, TypeVar + +import equinox as eqx +import jax +import jaxtyping +import optax +from jax import numpy as jnp +from jax.random import PRNGKey +from jaxtyping import PRNGKeyArray + +import levanter.tracker +from levanter.optim.config import HessianOptConfig, OptimizerConfig +from levanter.optim.util import hvp, tree_gaussian_like +from levanter.utils.jax_utils import parameter_count, tree_filter_like + + +@OptimizerConfig.register_subclass("mars") +@dataclass +class MarsConfig(OptimizerConfig): + weight_decay: float = 0.1 + beta1: float = 0.95 + # cf https://docs.mosaicml.com/projects/composer/en/latest/api_reference/generated/composer.optim.DecoupledAdamW.html + # https://x.com/giffmana/status/1692641748445438301 + beta2: float 
= 0.99 + gamma: float = 0.025 + epsilon: float = 1e-8 + max_grad_norm: Optional[float] = 1.0 + haps: Optional[list[int]] = None + schedule_list: Optional[list[str]] = None + + def build(self, num_train_steps): + """Creates the optimizer""" + # indirection makes it work with optax.inject_hyperparams so we can log the learning rate + def _optimizer(learning_rate): + components = [] + + + components.append(scale_by_mars(self.beta1, self.beta2, self.gamma, self.epsilon, max_grad_norm = self.max_grad_norm)) + + if self.weight_decay > 0: + components.append(optax.add_decayed_weights(self.weight_decay, self.build_weight_decay_mask())) + + # - learning rate for descent + components.append(optax.scale(-learning_rate)) + + optimizer = optax.chain(*components) + + return optimizer + + return optax.inject_hyperparams(_optimizer)(learning_rate=self.lr_scheduler(num_train_steps)) + +from optax import tree_utils as otu +import jax +import jax.numpy as jnp +from jax import jit + + +import chex + +class ScaleByMarsState(NamedTuple): + """State for the Mars algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: optax.Updates + nu: optax.Updates + mog: optax.Updates + + +def scale_by_mars( + b1: float = 0.9, + b2: float = 0.999, + gamma: float = 0.05, + eps: float = 1e-8, + eps_root: float = 0.0, + max_grad_norm: float = 0.0, + mu_dtype: Optional[Any] = None +) -> optax.GradientTransformation: + r"""Rescale updates according to the MARS algorithm. + https://arxiv.org/abs/2411.10438 + See :func:optax.adam for more details. + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + gamma: control the scale of variance reduction + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + mu_dtype: Optional dtype to be used for the first order accumulator; if + None then the dtype is inferred from params and updates. + Returns: + A :class:optax.GradientTransformation object. + """ + + mu_dtype = jax.dtypes.canonicalize_dtype(mu_dtype) + + def init_fn(params): + mu = otu.tree_zeros_like(params, dtype=mu_dtype) # First moment + nu = otu.tree_zeros_like(params) # Second moment + mog = otu.tree_zeros_like(params, dtype=mu_dtype) # gradient from + return ScaleByMarsState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu, mog = mog) + + def update_fn(updates, state, params=None): + c = jax.tree.map( + lambda og, g: None if g is None else g + (gamma * b1 / (1 - b1)) * (g - og), + state.mog, + updates, + is_leaf=lambda x: x is None, + ) + if max_grad_norm: + g_norm = optax.global_norm(c) + scale = jnp.minimum(1.0, max_grad_norm / (g_norm + 1e-6)) + c = jax.tree_map(lambda g: None if g is None else g * scale, + c, + is_leaf=lambda x: x is None + ) + mu = otu.tree_update_moment(c, state.mu, b1, 1) + nu = otu.tree_update_moment_per_elem_norm(c, state.nu, b2, 2) + count_inc = optax.safe_increment(state.count) + mu_hat = otu.tree_bias_correction(mu, b1, count_inc) + # Dozat 2016 https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ + # Algorithm 2 further multiplies Adam's standard nu_hat by b2. It is + # unclear why. Other Nadam implementations also omit the extra b2 factor. 
+ nu_hat = otu.tree_bias_correction(nu, b2, count_inc) + adam_updates = jax.tree.map( + lambda m, v: None if m is None else m / (jnp.sqrt(v + eps_root) + eps), + mu_hat, + nu_hat, + is_leaf=lambda x: x is None, + ) + mu = otu.tree_cast(mu, mu_dtype) + return adam_updates, ScaleByMarsState(count=count_inc, mu=mu, nu=nu, mog = updates) + return optax.GradientTransformation(init_fn, update_fn) \ No newline at end of file diff --git a/src/levanter/optim/muon.py b/src/levanter/optim/muon.py new file mode 100644 index 000000000..81d9d9dec --- /dev/null +++ b/src/levanter/optim/muon.py @@ -0,0 +1,169 @@ +import dataclasses +from dataclasses import dataclass +from typing import NamedTuple + +import chex +import jax +import jax.numpy as jnp +import optax +from optax import tree_utils as otu + +import haliax +from haliax.nn import Linear + +from levanter.optim.config import OptimizerConfig +from levanter.optim.util import map_flattened_linear_layers +from levanter.utils.jax_utils import leaf_key_paths + + +@OptimizerConfig.register_subclass("muon") +@dataclass +class MuonConfig(OptimizerConfig): + """ + Muon optimizer configuration: Momentum Orthogonalized by Newton-Schulz. + """ + + lr: float = 0.02 + muon_to_adam_lr: float = 0.18 # Scaling factor between AdamW and Muon learning rates + momentum: float = 0.95 + nesterov: bool = True + backend_steps: int = 10 # Number of steps for Newton-Schulz orthogonalization + weight_decay: float = 0.0 + beta1: float = 0.9 + beta2: float = 0.95 + epsilon: float = 1e-8 + max_grad_norm: float = 1.0 + # adam_modules: Optional[list[str] | str] = None + # """A regex or a list of strings to identify where to mask weight. + # For nano-GPT, this field can be set as `r".*attn.*weight|.*mlp.*weight|.*token_embeddings|.*position_embeddings"`""" + # default_adam_mask: Optional[bool] = None + # """Whether to apply a default reasonable weight decay to modules not explicitly masked. None means it will if + # no weight_decay_modules are set. False means it will not. True means it will regardless of weight_decay_modules.""" + + def build(self, num_train_steps): + """ + Creates the optimizer. 
+ """ + learning_rate_schedule = self.lr_scheduler(num_train_steps) + + def optimizer(learning_rate): + adam_lr = learning_rate * self.muon_to_adam_lr + + def muon_transform(): + components = [] + # Muon seems incompatible with gradient clipping, need to investigate + # if self.max_grad_norm: + # components.append(optax.clip_by_global_norm(self.max_grad_norm)) + components.append(scale_with_muon(self.momentum, self.nesterov, self.backend_steps)) + if self.weight_decay > 0: + components.append(optax.add_decayed_weights(self.weight_decay, self.build_weight_decay_mask())) + components.append(optax.scale(-learning_rate)) + optimizer = optax.chain(*components) + return optimizer + + def adamw_transform(): + components = [] + if self.max_grad_norm: + components.append(optax.clip_by_global_norm(self.max_grad_norm)) + components.append(optax.scale_by_adam(self.beta1, self.beta2, self.epsilon)) + if self.weight_decay > 0: + components.append(optax.add_decayed_weights(self.weight_decay, self.build_weight_decay_mask())) + components.append(optax.scale(-adam_lr)) + optimizer = optax.chain(*components) + return optimizer + + transformations = { + "muon": muon_transform(), + "adamw": adamw_transform(), + } + + return optax.multi_transform(transformations, self.create_mask) + + return optax.inject_hyperparams(optimizer)(learning_rate=learning_rate_schedule) + + def create_mask(self, params): + """ + Creates a mask that labels parameters as 'muon' or 'adamw' based on their + dimensionality and module path, using AdamW for Embedding and lm_head parameters. + """ + paths = leaf_key_paths(params) + + def mask_fn(param, path): + path_str = ".".join(path) if isinstance(path, (list, tuple)) else str(path) + if "Embedding" in path_str or "lm_head" in path_str: + return "adamw" + elif isinstance(param, Linear): + # muon for linear layers + return dataclasses.replace(param, weight="muon", bias="adamw" if param.bias is not None else None) + else: + return "adamw" + + return jax.tree_util.tree_map(mask_fn, params, paths, is_leaf=lambda x: isinstance(x, Linear)) + + +class ScaleByMuonState(NamedTuple): + """State for the Mars algorithm.""" + + momentum_buffer: optax.Updates + + +def scale_with_muon(momentum=0.95, nesterov=True, steps=5): + def init_fn(params): + momentum_buffer = otu.tree_zeros_like(params) # First moment + return ScaleByMuonState(momentum_buffer=momentum_buffer) + + def update_fn(updates, state, params=None): + buf = state.momentum_buffer + buf = jax.tree.map( + lambda m, g: None if g is None else momentum * m + g, + buf, + updates, + is_leaf=lambda x: x is None, + ) + if nesterov: + updates = jax.tree.map( + lambda m, g: None if g is None else momentum * m + g, + buf, + updates, + is_leaf=lambda x: x is None, + ) + else: + updates = buf + + def transform_linear_layer(layer: haliax.nn.Linear): + assert layer.weight.ndim == 2 + + updated_weight_array = zeropower_via_newtonschulz5(layer.weight.array, steps=steps) + + scale = jnp.sqrt(jnp.maximum(1, updated_weight_array.shape[0] / updated_weight_array.shape[1])) + updated_weight_array *= scale + + updated_weight = dataclasses.replace(layer.weight, array=updated_weight_array) + + return dataclasses.replace(layer, weight=updated_weight) # type: ignore + + updates = map_flattened_linear_layers(transform_linear_layer, updates) + + return updates, ScaleByMuonState(momentum_buffer=buf) + + return optax.GradientTransformation(init_fn, update_fn) + + +def zeropower_via_newtonschulz5(X, steps=10, eps=1e-7): + """ + Newton-Schulz iteration to compute the zeroth 
power / orthogonalization of G. + """ + chex.assert_rank(X, 2) + a, b, c = (3.4445, -4.7750, 2.0315) + X /= jnp.linalg.norm(X) + eps # Ensure top singular value <= 1 + transpose = False + if X.shape[0] > X.shape[1]: + X = X.T + transpose = True + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + if transpose: + X = X.T + return X diff --git a/src/levanter/optim/util.py b/src/levanter/optim/util.py index 7fd3a41df..5eecf91d1 100644 --- a/src/levanter/optim/util.py +++ b/src/levanter/optim/util.py @@ -1,5 +1,12 @@ +from typing import Callable + import equinox as eqx import jax +from jaxtyping import PyTree + +import haliax +import haliax as hax +from haliax.tree_util import scan_aware_tree_map from levanter.utils.jax_utils import is_inexact_arrayish @@ -21,3 +28,53 @@ def tree_gaussian_like(key, tree): g = jax.tree_util.tree_unflatten(structure, g) return g + + +def map_flattened_linear_layers( + f: Callable[[hax.nn.Linear], hax.nn.Linear], + params: PyTree, + *, + or_else: Callable | None = None, + is_leaf: Callable | None = None, +): + """ + Apply a function to all Linear layers in a PyTree, flattening articulated input/output dims into single dims, then + unflattening them back into the original structure. This method also takes care of vmapping over scan layers. + + The linear layers will be passed to the function `f` and the result will be used to replace the original linear layer. + The linear layers passed to `f` will be flattened into 2D (named) arrays, and the result will be unflattened back into the original shape. + The bias term, if any, will be passed as a 1D named arrays. + The weight array will not be None, but the bias array may be None. + + Args: + f: The function to apply to each Linear layer + params: The PyTree of parameters + or_else: optional function to apply to non-Linear leaves + is_leaf: optional function to determine if a node is a leaf. Linears will always be considered leaves. + + Returns: + The PyTree with the function applied to all Linear layers and the structure preserved otherwise. + returned linear layers will be unfattened back to their original shape. + + """ + + if is_leaf is None: + is_leaf = lambda x: isinstance(x, hax.nn.Linear) + else: + _is_leaf = is_leaf + is_leaf = lambda x: isinstance(x, hax.nn.Linear) or _is_leaf(x) + + def map_fn(p): + if isinstance(p, hax.nn.Linear): + if p.weight is None: + return p + return f(p) + elif or_else is not None: + return or_else(p) + else: + return p + + flattened_linear = haliax.state_dict.flatten_linear_layers(params) + flattened_linear = scan_aware_tree_map(map_fn, flattened_linear, is_leaf=is_leaf) + + return haliax.state_dict.unflatten_linear_layers(params, flattened_linear)
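
---

A few illustrative notes on this patch follow. First, the new configs select the optimizers purely via `optimizer.type` (`kron`, `mars`, or `muon`), which resolves to the `OptimizerConfig` subclasses registered above. A rough Python-side equivalent, sketched from the YAML keys in this patch (treat the exact field names and defaults as assumptions, not a tested API):

```python
from levanter.optim import KronConfig  # exported by the __init__.py change above

# Roughly mirrors config/llama2_100M_kron_test.yaml (type: kron).
config = KronConfig(
    learning_rate=3e-4,
    weight_decay=0.1,
    warmup=2000,
    lr_schedule="constant",
    min_lr_ratio=0.0,
)
optimizer = config.build(num_train_steps=50_001)  # an optax gradient transformation
```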
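`precond_update_prob_schedule` controls how often Kron's preconditioner is refreshed: flat at 1.0 for `flat_start` steps, then an exponential decay down to `min_prob`. A standalone restatement of the schedule (same formula as in `kron.py`, shown here only to illustrate the numbers the docstring quotes):

```python
import jax.numpy as jnp


def precond_update_prob_schedule(max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500):
    def _schedule(n):
        # Flat at max_prob until flat_start, then exponential decay clamped at min_prob.
        return jnp.minimum(
            jnp.maximum(max_prob * jnp.exp(-decay * (n - flat_start)), min_prob),
            max_prob,
        )

    return _schedule


schedule = precond_update_prob_schedule()
for step in [0, 500, 1_500, 4_000, 10_000]:
    print(step, float(schedule(jnp.asarray(step, jnp.float32))))
# ~1.0 through step 500, ~0.37 at step 1500, and ~0.03 from roughly step 4000 on,
# i.e. the preconditioner ends up being updated about once every 33 steps.
```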
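`scale_by_mars` first forms a variance-reduced gradient, `c_t = g_t + (gamma * beta1 / (1 - beta1)) * (g_t - g_{t-1})`, optionally clips it by global norm, and only then runs Adam-style moment updates on `c_t`. A tiny numeric illustration of that correction term, using the `MarsConfig` defaults from this patch and made-up gradient values:

```python
import jax.numpy as jnp

b1, gamma = 0.95, 0.025            # MarsConfig defaults above
g_prev = jnp.array([0.10, -0.20])  # previous gradient (the `mog` buffer)
g_curr = jnp.array([0.12, -0.25])  # current gradient

c = g_curr + (gamma * b1 / (1 - b1)) * (g_curr - g_prev)
print(c)  # [0.1295, -0.27375]: the gradient is nudged along its most recent change
```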
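`MuonConfig.build` routes parameters to two inner optimizers with `optax.multi_transform`: `create_mask` labels `Linear` weights `"muon"` and everything else (embeddings, `lm_head`, biases) `"adamw"`. The toy below shows just that routing mechanism with stand-in transforms and a hypothetical two-entry parameter dict; it is not the patch's actual label tree:

```python
import jax.numpy as jnp
import optax

params = {"token_embeddings": jnp.zeros((8, 4)), "mlp_weight": jnp.zeros((4, 4))}
labels = {"token_embeddings": "adamw", "mlp_weight": "muon"}  # stand-in for create_mask

tx = optax.multi_transform(
    {"muon": optax.sgd(2e-2, momentum=0.95), "adamw": optax.adamw(3e-4)},  # toy inner optimizers
    labels,
)
state = tx.init(params)
grads = {"token_embeddings": jnp.ones((8, 4)), "mlp_weight": jnp.ones((4, 4))}
updates, state = tx.update(grads, state, params)
```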
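Finally, `zeropower_via_newtonschulz5` is what makes Muon's update approximately orthogonal: for a 2D matrix it pushes the singular values toward 1 while keeping the singular vectors, roughly computing `U @ V.T` of the input. Below is a pure-function restatement of the same iteration plus a quick sanity check; the exact singular-value spread is approximate, since the quintic coefficients favour speed over tight convergence:

```python
import jax
import jax.numpy as jnp


def zeropower_via_newtonschulz5(X, steps=10, eps=1e-7):
    # Same iteration as in src/levanter/optim/muon.py, without in-place updates.
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = X / (jnp.linalg.norm(X) + eps)  # Frobenius normalization: top singular value <= 1
    transpose = X.shape[0] > X.shape[1]
    if transpose:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transpose else X


G = jax.random.normal(jax.random.PRNGKey(0), (256, 1024))
O = zeropower_via_newtonschulz5(G, steps=10)
s = jnp.linalg.svd(O, compute_uv=False)
print(float(s.min()), float(s.max()))  # singular values should land in a band around 1
```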