small changes/moving to remote
evanatyourservice committed Dec 22, 2024
1 parent ee747c0 commit 59f2c10
Showing 1 changed file with 11 additions and 26 deletions.
37 changes: 11 additions & 26 deletions src/levanter/optim/kron.py
@@ -18,7 +18,7 @@ class KronConfig(OptimizerConfig):
weight_decay: Weight decay coefficient.
max_grad_norm: Optional gradient norm clipping value.
normalize_grads: Whether to normalize the incoming gradients to unit norm layer-wise.
- Can help with stability.
+ Can help with stability but likely not necessary in this scenario.
preconditioner_update_probability: Final probability of updating the preconditioner. Default
is 0.05 (update every 20 steps). The `precond_update_prob_schedule` holds probability at
1.0 for `update_prob_flat_start` steps before annealing exponentially down to this
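For reference, the annealing behavior described here (probability held at 1.0 for `update_prob_flat_start` steps, then decayed exponentially toward the final value) can be sketched as follows. This is only an illustration of the schedule's shape, not the exact `precond_update_prob_schedule` in kron.py, and the decay rate is an assumed value:

    import jax.numpy as jnp

    def example_update_prob_schedule(step, flat_start=1000, final_prob=0.05, decay=0.002):
        # Hold the preconditioner update probability at 1.0 for `flat_start` steps,
        # then anneal exponentially, never dropping below `final_prob`.
        # The decay rate is an illustrative assumption, not kron.py's actual constant.
        annealed = jnp.exp(-decay * jnp.maximum(step - flat_start, 0.0))
        return jnp.clip(annealed, final_prob, 1.0)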
@@ -50,20 +50,14 @@ class KronConfig(OptimizerConfig):
lax_map_batch_size: Batch size for lax.map, see JAX docs for more info.
merge_small_dims: Whether to merge small dimensions to improve preconditioner efficiency.
target_merged_dim_size: Target size of merged dimensions.
- partition_grads_into_blocks: Whether to partition grads into chunks of size block_size
- for efficiency.
- block_size: Block size to use for partitioning grads.
params_sharding: Pytree same structure as params of jax.sharding.PartitionSpec.
- preconditioner_sharding: PartitionSpec for preconditioner matrices. Best practice is to
- shard first dimension across fsdp-like mesh axis, or largest/most common axis in params.
- Example: PartitionSpec('fsdp') or PartitionSpec('fsdp', 'tp').
"""
# some of these are changed from kron defaults to better suit levanter
beta1: float = 0.9
weight_decay: float = 0.1
max_grad_norm: Optional[float] = 0.0
- normalize_grads: bool = True
- preconditioner_update_probability: float = 0.03
+ normalize_grads: bool = False
+ preconditioner_update_probability: float = 0.05
update_prob_flat_start: int = 1000
max_size_triangular: int = 25000
min_ndim_triangular: int = 2
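As a rough illustration of the dimension-merging idea described above (the actual merging logic in kron.py is not shown in this diff): small adjacent axes are folded together so the preconditioner works on fewer, larger matrices, keeping merged sizes near `target_merged_dim_size`:

    import jax.numpy as jnp

    # Illustration only: a gradient with several small leading axes is reshaped so the
    # preconditioner sees two moderate axes instead of four tiny ones. Shapes are made up.
    g = jnp.zeros((4, 8, 16, 256))
    merged = g.reshape(4 * 8 * 16, 256)  # (512, 256), both dims well under target_merged_dim_size
    print(merged.shape)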
@@ -72,27 +72,19 @@ class KronConfig(OptimizerConfig):
preconditioner_init_scale: float = 1.0
mu_dtype: Optional[Union[str, jnp.dtype]] = None
precond_dtype: Optional[Union[str, jnp.dtype]] = None
precond_update_precision: Optional[str] = "float32"
precond_update_precision: Optional[str] = "tensorfloat32"
precond_grads_precision: Optional[str] = None
scanned_layers: Optional[optax.Params] = None
lax_map_scanned_layers: bool = False
lax_map_batch_size: int = 8
merge_small_dims: bool = True
target_merged_dim_size: int = 8192
- partition_grads_into_blocks: bool = True
- block_size: int = 256
params_sharding: Optional[Any] = None
- preconditioner_sharding: Optional[tuple[str | None, str | None]] = None

def build(self, num_train_steps):
"""Creates the optimizer."""

def _optimizer(learning_rate) -> optax.GradientTransformation:
- precond_partition_spec = (
- PartitionSpec(*self.preconditioner_sharding)
- if self.preconditioner_sharding is not None
- else None
- )
components = []
if self.max_grad_norm and not self.normalize_grads:
components.append(optax.clip_by_global_norm(self.max_grad_norm))
@@ -116,14 +116,15 @@ def _optimizer(learning_rate) -> optax.GradientTransformation:
scanned_layers=self.scanned_layers,
lax_map_scanned_layers=self.lax_map_scanned_layers,
lax_map_batch_size=self.lax_map_batch_size,
- # merge_small_dims=self.merge_small_dims,
- # target_merged_dim_size=self.target_merged_dim_size,
- # partition_grads_into_blocks=self.partition_grads_into_blocks,
- # block_size=self.block_size,
- # params_sharding=self.params_sharding,
- # preconditioner_sharding=precond_partition_spec,
+ merge_small_dims=self.merge_small_dims,
+ target_merged_dim_size=self.target_merged_dim_size,
+ params_sharding=self.params_sharding,
)
)
+ # PSGD's output should be RMS=1.0, so we can clip at 1.1 in case of a
+ # gradient spike. This is better than clipping the incoming grads, which would
+ # discard information the preconditioner needs.
+ components.append(optax.clip_by_block_rms(1.1))
if self.weight_decay > 0:
components.append(
optax.add_decayed_weights(
@@ -143,11 +143,9 @@ def _optimizer(learning_rate) -> optax.GradientTransformation:
import string
import numpy as np

- import chex
import jax
from jax import vmap
import jax.numpy as jnp
- import flax.linen as nn
from optax import tree_utils as otu
from optax._src import base, transform
from optax._src.numerics import safe_int32_increment
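Taken together, build() assembles a standard optax chain: optional global-norm clipping (skipped when grads are normalized layer-wise), the Kron preconditioning transform, a block-RMS clip at 1.1, decoupled weight decay, and learning-rate scaling. A condensed sketch of that composition is below; optax.scale_by_adam stands in for the Kron transform (whose call is only partially shown in this diff) so the example stays self-contained:

    import optax

    def example_kron_like_chain(learning_rate=3e-4, weight_decay=0.1, max_grad_norm=1.0):
        # Condensed sketch of the chain assembled in build(); scale_by_adam is only a
        # placeholder for the Kron preconditioning transform.
        components = [
            optax.clip_by_global_norm(max_grad_norm),  # only added when not normalizing grads
            optax.scale_by_adam(),                     # stand-in for the Kron transform
            optax.clip_by_block_rms(1.1),              # PSGD output is ~RMS 1.0; guard against spikes
            optax.add_decayed_weights(weight_decay),
            optax.scale(-learning_rate),               # descend along the update direction
        ]
        return optax.chain(*components)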
