diff --git a/config/llama2_100M_mars.yaml b/config/llama2_100M_mars.yaml
new file mode 100644
index 000000000..2c062d816
--- /dev/null
+++ b/config/llama2_100M_mars.yaml
@@ -0,0 +1,34 @@
+data: !include data/dclm_gpt_neo.yaml
+model:
+  type: llama
+  seq_len: 4096
+  hidden_dim: 768
+  intermediate_dim: 3072
+  num_layers: 12
+  num_heads: 12
+  num_kv_heads: 12
+trainer:
+  tracker:
+    project: "levanter"
+    tags: ["pile", "llama"]
+  mp: p=f32,c=bfloat16
+  model_axis_size: 1
+  checkpointer:
+    keep:
+      - every: 1000
+    save_interval: 30m
+
+
+  train_batch_size: 1024
+  per_device_parallelism: 4 # set for v3 TPU
+  per_device_eval_parallelism: 4 # eval can use a larger per-device batch
+  num_train_steps: 50001
+optimizer:
+  learning_rate: 4E-3
+  weight_decay: 0.1
+  min_lr_ratio: 0.0
+  warmup: 2000
+  cooldown: 0.4
+  lr_schedule: constant
+  gamma: 0.025
+  type: mars
diff --git a/config/llama2_100M_muon.yaml b/config/llama2_100M_muon.yaml
new file mode 100644
index 000000000..3f8194465
--- /dev/null
+++ b/config/llama2_100M_muon.yaml
@@ -0,0 +1,34 @@
+data: !include data/dclm_gpt_neo.yaml
+model:
+  type: llama
+  seq_len: 4096
+  hidden_dim: 768
+  intermediate_dim: 3072
+  num_layers: 12
+  num_heads: 12
+  num_kv_heads: 12
+trainer:
+  tracker:
+    project: "levanter"
+    tags: ["pile", "llama"]
+  mp: p=f32,c=bfloat16
+  model_axis_size: 1
+  checkpointer:
+    keep:
+      - every: 1000
+    save_interval: 30m
+
+
+  train_batch_size: 1024
+  per_device_parallelism: 4 # set for v3 TPU
+  per_device_eval_parallelism: 4 # eval can use a larger per-device batch
+  num_train_steps: 50001
+optimizer:
+  learning_rate: 2E-2
+  weight_decay: 0
+  warmup: 0
+  cooldown: 0.1
+  lr_schedule: constant
+  min_lr_ratio: 0.0
+  max_grad_norm: 0.0
+  type: muon
diff --git a/error_loading_model.sh b/error_loading_model.sh
new file mode 100644
index 000000000..8a555be70
--- /dev/null
+++ b/error_loading_model.sh
@@ -0,0 +1,10 @@
+eval $(ssh-agent -s)
+bash infra/babysit-tpu-vm.sh muon-debug -z us-central2-b -t v4-128 --preemptible -- \
+WANDB_API_KEY=[WANDB_API_KEY] \
+bash levanter/infra/run.sh python \
+levanter/src/levanter/main/train_lm.py \
+--config_path levanter/config/llama2_100M_muon.yaml \
+--trainer.checkpointer.base_path gs://marin-us-central2/scratch/kaiyue/checkpoints/muon/llama2_100M_constant \
+--optimizer.type muon \
+--trainer.num_train_steps 10000 \
+--trainer.load_checkpoint_path gs://marin-us-central2/scratch/kaiyue/checkpoints/muon/llama2_100M_constant/tjo9vxfb/step-4000
diff --git a/src/levanter/optim/__init__.py b/src/levanter/optim/__init__.py
index 7dec2ebb4..2cd5ad781 100644
--- a/src/levanter/optim/__init__.py
+++ b/src/levanter/optim/__init__.py
@@ -5,3 +5,11 @@
     scale_by_sophia_g,
     scale_by_sophia_h,
 )
+from .muon import (
+    MuonConfig,
+    ScaleByMuonState
+)
+from .mars import (
+    MarsConfig,
+    ScaleByMarsState
+)
\ No newline at end of file
diff --git a/src/levanter/optim/mars.py b/src/levanter/optim/mars.py
new file mode 100644
index 000000000..c117a27f6
--- /dev/null
+++ b/src/levanter/optim/mars.py
@@ -0,0 +1,135 @@
+import abc
+import functools
+from dataclasses import dataclass
+from typing import Any, NamedTuple, Optional, TypeVar
+
+import equinox as eqx
+import jax
+import jaxtyping
+import optax
+from jax import numpy as jnp
+from jax.random import PRNGKey
+from jaxtyping import PRNGKeyArray
+
+import levanter.tracker
+from levanter.optim.config import HessianOptConfig, OptimizerConfig
+from levanter.optim.util import hvp, tree_gaussian_like
+from levanter.utils.jax_utils import parameter_count, tree_filter_like
+
+
+@OptimizerConfig.register_subclass("mars")
+@dataclass
+class MarsConfig(OptimizerConfig):
+    weight_decay: float = 0.1
+    beta1: float = 0.95
+    # cf https://docs.mosaicml.com/projects/composer/en/latest/api_reference/generated/composer.optim.DecoupledAdamW.html
+    # https://x.com/giffmana/status/1692641748445438301
+    beta2: float = 0.99
+    gamma: float = 0.025
+    epsilon: float = 1e-8
+    max_grad_norm: Optional[float] = 1.0
+    haps: Optional[list[int]] = None
+    schedule_list: Optional[list[str]] = None
+
+    def build(self, num_train_steps):
+        """Creates the optimizer"""
+        # indirection makes it work with optax.inject_hyperparams so we can log the learning rate
+        def _optimizer(learning_rate):
+            components = []
+
+
+            components.append(scale_by_mars(self.beta1, self.beta2, self.gamma, self.epsilon, max_grad_norm=self.max_grad_norm))
+
+            if self.weight_decay > 0:
+                components.append(optax.add_decayed_weights(self.weight_decay, self.build_weight_decay_mask()))
+
+            # - learning rate for descent
+            components.append(optax.scale(-learning_rate))
+
+            optimizer = optax.chain(*components)
+
+            return optimizer
+
+        return optax.inject_hyperparams(_optimizer)(learning_rate=self.lr_scheduler(num_train_steps))
+
+from optax import tree_utils as otu
+import jax
+import jax.numpy as jnp
+from jax import jit
+
+
+import chex
+
+class ScaleByMarsState(NamedTuple):
+    """State for the Mars algorithm."""
+    count: chex.Array  # shape=(), dtype=jnp.int32.
+    mu: optax.Updates
+    nu: optax.Updates
+    mog: optax.Updates
+
+
+def scale_by_mars(
+    b1: float = 0.9,
+    b2: float = 0.999,
+    gamma: float = 0.05,
+    eps: float = 1e-8,
+    eps_root: float = 0.0,
+    max_grad_norm: float = 0.0,
+    mu_dtype: Optional[Any] = None
+) -> optax.GradientTransformation:
+    r"""Rescale updates according to the MARS algorithm.
+    https://arxiv.org/abs/2411.10438
+    See :func:`optax.adam` for more details.
+
+    Args:
+      b1: Decay rate for the exponentially weighted average of grads.
+      b2: Decay rate for the exponentially weighted average of squared grads.
+      gamma: Controls the scale of the variance-reduction term.
+      eps: Term added to the denominator to improve numerical stability.
+      eps_root: Term added to the denominator inside the square-root to improve
+        numerical stability when backpropagating gradients through the rescaling.
+      mu_dtype: Optional dtype to be used for the first order accumulator; if
+        None then the dtype is inferred from params and updates.
+    Returns:
+      A :class:`optax.GradientTransformation` object.
+ """ + + mu_dtype = jax.dtypes.canonicalize_dtype(mu_dtype) + + def init_fn(params): + mu = otu.tree_zeros_like(params, dtype=mu_dtype) # First moment + nu = otu.tree_zeros_like(params) # Second moment + mog = otu.tree_zeros_like(params, dtype=mu_dtype) # gradient from + return ScaleByMarsState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu, mog = mog) + + def update_fn(updates, state, params=None): + c = jax.tree.map( + lambda og, g: None if g is None else g + (gamma * b1 / (1 - b1)) * (g - og), + state.mog, + updates, + is_leaf=lambda x: x is None, + ) + if max_grad_norm: + g_norm = optax.global_norm(c) + scale = jnp.minimum(1.0, max_grad_norm / (g_norm + 1e-6)) + c = jax.tree_map(lambda g: None if g is None else g * scale, + c, + is_leaf=lambda x: x is None + ) + mu = otu.tree_update_moment(c, state.mu, b1, 1) + nu = otu.tree_update_moment_per_elem_norm(c, state.nu, b2, 2) + count_inc = optax.safe_increment(state.count) + mu_hat = otu.tree_bias_correction(mu, b1, count_inc) + # Dozat 2016 https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ + # Algorithm 2 further multiplies Adam's standard nu_hat by b2. It is + # unclear why. Other Nadam implementations also omit the extra b2 factor. + nu_hat = otu.tree_bias_correction(nu, b2, count_inc) + adam_updates = jax.tree.map( + lambda m, v: None if m is None else m / (jnp.sqrt(v + eps_root) + eps), + mu_hat, + nu_hat, + is_leaf=lambda x: x is None, + ) + mu = otu.tree_cast(mu, mu_dtype) + return adam_updates, ScaleByMarsState(count=count_inc, mu=mu, nu=nu, mog = updates) + return optax.GradientTransformation(init_fn, update_fn) \ No newline at end of file diff --git a/src/levanter/optim/muon.py b/src/levanter/optim/muon.py new file mode 100644 index 000000000..81d9d9dec --- /dev/null +++ b/src/levanter/optim/muon.py @@ -0,0 +1,169 @@ +import dataclasses +from dataclasses import dataclass +from typing import NamedTuple + +import chex +import jax +import jax.numpy as jnp +import optax +from optax import tree_utils as otu + +import haliax +from haliax.nn import Linear + +from levanter.optim.config import OptimizerConfig +from levanter.optim.util import map_flattened_linear_layers +from levanter.utils.jax_utils import leaf_key_paths + + +@OptimizerConfig.register_subclass("muon") +@dataclass +class MuonConfig(OptimizerConfig): + """ + Muon optimizer configuration: Momentum Orthogonalized by Newton-Schulz. + """ + + lr: float = 0.02 + muon_to_adam_lr: float = 0.18 # Scaling factor between AdamW and Muon learning rates + momentum: float = 0.95 + nesterov: bool = True + backend_steps: int = 10 # Number of steps for Newton-Schulz orthogonalization + weight_decay: float = 0.0 + beta1: float = 0.9 + beta2: float = 0.95 + epsilon: float = 1e-8 + max_grad_norm: float = 1.0 + # adam_modules: Optional[list[str] | str] = None + # """A regex or a list of strings to identify where to mask weight. + # For nano-GPT, this field can be set as `r".*attn.*weight|.*mlp.*weight|.*token_embeddings|.*position_embeddings"`""" + # default_adam_mask: Optional[bool] = None + # """Whether to apply a default reasonable weight decay to modules not explicitly masked. None means it will if + # no weight_decay_modules are set. False means it will not. True means it will regardless of weight_decay_modules.""" + + def build(self, num_train_steps): + """ + Creates the optimizer. 
+ """ + learning_rate_schedule = self.lr_scheduler(num_train_steps) + + def optimizer(learning_rate): + adam_lr = learning_rate * self.muon_to_adam_lr + + def muon_transform(): + components = [] + # Muon seems incompatible with gradient clipping, need to investigate + # if self.max_grad_norm: + # components.append(optax.clip_by_global_norm(self.max_grad_norm)) + components.append(scale_with_muon(self.momentum, self.nesterov, self.backend_steps)) + if self.weight_decay > 0: + components.append(optax.add_decayed_weights(self.weight_decay, self.build_weight_decay_mask())) + components.append(optax.scale(-learning_rate)) + optimizer = optax.chain(*components) + return optimizer + + def adamw_transform(): + components = [] + if self.max_grad_norm: + components.append(optax.clip_by_global_norm(self.max_grad_norm)) + components.append(optax.scale_by_adam(self.beta1, self.beta2, self.epsilon)) + if self.weight_decay > 0: + components.append(optax.add_decayed_weights(self.weight_decay, self.build_weight_decay_mask())) + components.append(optax.scale(-adam_lr)) + optimizer = optax.chain(*components) + return optimizer + + transformations = { + "muon": muon_transform(), + "adamw": adamw_transform(), + } + + return optax.multi_transform(transformations, self.create_mask) + + return optax.inject_hyperparams(optimizer)(learning_rate=learning_rate_schedule) + + def create_mask(self, params): + """ + Creates a mask that labels parameters as 'muon' or 'adamw' based on their + dimensionality and module path, using AdamW for Embedding and lm_head parameters. + """ + paths = leaf_key_paths(params) + + def mask_fn(param, path): + path_str = ".".join(path) if isinstance(path, (list, tuple)) else str(path) + if "Embedding" in path_str or "lm_head" in path_str: + return "adamw" + elif isinstance(param, Linear): + # muon for linear layers + return dataclasses.replace(param, weight="muon", bias="adamw" if param.bias is not None else None) + else: + return "adamw" + + return jax.tree_util.tree_map(mask_fn, params, paths, is_leaf=lambda x: isinstance(x, Linear)) + + +class ScaleByMuonState(NamedTuple): + """State for the Mars algorithm.""" + + momentum_buffer: optax.Updates + + +def scale_with_muon(momentum=0.95, nesterov=True, steps=5): + def init_fn(params): + momentum_buffer = otu.tree_zeros_like(params) # First moment + return ScaleByMuonState(momentum_buffer=momentum_buffer) + + def update_fn(updates, state, params=None): + buf = state.momentum_buffer + buf = jax.tree.map( + lambda m, g: None if g is None else momentum * m + g, + buf, + updates, + is_leaf=lambda x: x is None, + ) + if nesterov: + updates = jax.tree.map( + lambda m, g: None if g is None else momentum * m + g, + buf, + updates, + is_leaf=lambda x: x is None, + ) + else: + updates = buf + + def transform_linear_layer(layer: haliax.nn.Linear): + assert layer.weight.ndim == 2 + + updated_weight_array = zeropower_via_newtonschulz5(layer.weight.array, steps=steps) + + scale = jnp.sqrt(jnp.maximum(1, updated_weight_array.shape[0] / updated_weight_array.shape[1])) + updated_weight_array *= scale + + updated_weight = dataclasses.replace(layer.weight, array=updated_weight_array) + + return dataclasses.replace(layer, weight=updated_weight) # type: ignore + + updates = map_flattened_linear_layers(transform_linear_layer, updates) + + return updates, ScaleByMuonState(momentum_buffer=buf) + + return optax.GradientTransformation(init_fn, update_fn) + + +def zeropower_via_newtonschulz5(X, steps=10, eps=1e-7): + """ + Newton-Schulz iteration to compute the zeroth 
+    """
+    chex.assert_rank(X, 2)
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X /= jnp.linalg.norm(X) + eps  # Ensure top singular value <= 1
+    transpose = False
+    if X.shape[0] > X.shape[1]:
+        X = X.T
+        transpose = True
+    for _ in range(steps):
+        A = X @ X.T
+        B = b * A + c * A @ A
+        X = a * X + B @ X
+    if transpose:
+        X = X.T
+    return X
diff --git a/src/levanter/optim/util.py b/src/levanter/optim/util.py
index 7fd3a41df..5eecf91d1 100644
--- a/src/levanter/optim/util.py
+++ b/src/levanter/optim/util.py
@@ -1,5 +1,12 @@
+from typing import Callable
+
 import equinox as eqx
 import jax
+from jaxtyping import PyTree
+
+import haliax
+import haliax as hax
+from haliax.tree_util import scan_aware_tree_map
 
 from levanter.utils.jax_utils import is_inexact_arrayish
 
@@ -21,3 +28,53 @@ def tree_gaussian_like(key, tree):
     g = jax.tree_util.tree_unflatten(structure, g)
 
     return g
+
+
+def map_flattened_linear_layers(
+    f: Callable[[hax.nn.Linear], hax.nn.Linear],
+    params: PyTree,
+    *,
+    or_else: Callable | None = None,
+    is_leaf: Callable | None = None,
+):
+    """
+    Apply a function to all Linear layers in a PyTree, flattening articulated input/output dims into single dims, then
+    unflattening them back into the original structure. This method also takes care of vmapping over scan layers.
+
+    The linear layers will be passed to the function `f` and the result will be used to replace the original linear layer.
+    The linear layers passed to `f` will be flattened into 2D (named) arrays, and the result will be unflattened back into the original shape.
+    The bias term, if any, will be passed as a 1D named array.
+    The weight array will not be None, but the bias array may be None.
+
+    Args:
+        f: The function to apply to each Linear layer
+        params: The PyTree of parameters
+        or_else: optional function to apply to non-Linear leaves
+        is_leaf: optional function to determine if a node is a leaf. Linears will always be considered leaves.
+
+    Returns:
+        The PyTree with the function applied to all Linear layers and the structure preserved otherwise.
+        Returned linear layers will be unflattened back to their original shape.
+ + """ + + if is_leaf is None: + is_leaf = lambda x: isinstance(x, hax.nn.Linear) + else: + _is_leaf = is_leaf + is_leaf = lambda x: isinstance(x, hax.nn.Linear) or _is_leaf(x) + + def map_fn(p): + if isinstance(p, hax.nn.Linear): + if p.weight is None: + return p + return f(p) + elif or_else is not None: + return or_else(p) + else: + return p + + flattened_linear = haliax.state_dict.flatten_linear_layers(params) + flattened_linear = scan_aware_tree_map(map_fn, flattened_linear, is_leaf=is_leaf) + + return haliax.state_dict.unflatten_linear_layers(params, flattened_linear) diff --git a/src/levanter/tensorstore_serialization.py b/src/levanter/tensorstore_serialization.py index 82471e5df..fc9155cd1 100644 --- a/src/levanter/tensorstore_serialization.py +++ b/src/levanter/tensorstore_serialization.py @@ -95,7 +95,7 @@ def path_from_key_path(key_path): def _sharding_from_leaf(leaf, axis_mapping, mesh) -> Optional[jax.sharding.Sharding]: if is_named_array(leaf): - if leaf.array is None: + if not is_jax_array_like(leaf.array): return None return hax.partitioning.sharding_for_axis(leaf.axes, axis_mapping, mesh) elif hasattr(leaf, "sharding") and getattr(leaf, "sharding") is not None: @@ -140,11 +140,11 @@ def tree_deserialize_leaves_tensorstore( manager = array_ser.GlobalAsyncCheckpointManager() shardings: PyTree[Optional[Sharding]] = jtu.tree_map( - partial(_sharding_from_leaf, axis_mapping=axis_mapping, mesh=mesh), pytree, is_leaf=is_named_array + partial(_sharding_from_leaf, axis_mapping=axis_mapping, mesh=mesh), pytree, is_leaf=_is_named_or_none ) # TODO: support ShapeDtypeStructs that are not NamedArrays - leaf_key_paths = jax_utils.leaf_key_paths(shardings, is_leaf=is_named_array) + leaf_key_paths = jax_utils.leaf_key_paths(shardings, is_leaf=_is_named_or_none) paths = _fs_paths_from_key_paths(checkpoint_dir, leaf_key_paths) paths = jtu.tree_leaves(paths, is_leaf=lambda x: x is None) @@ -157,6 +157,8 @@ def tree_deserialize_leaves_tensorstore( real_leaves = [x for x in shardings_leaves if x is not None] real_paths = [paths[i] for i in real_indices] + assert len(real_leaves) == len(real_paths), f"{len(real_leaves)} != {len(real_paths)}" + deser_leaves = manager.deserialize_with_paths(shardings=real_leaves, paths=real_paths) # now we need to recreate the original structure