
Commit

expose momentum into precond, trust region
evanatyourservice committed Nov 10, 2024
1 parent 8dbd645 commit 39d4620
Showing 3 changed files with 16 additions and 16 deletions.
23 changes: 9 additions & 14 deletions psgd_jax/kron.py
@@ -47,14 +47,14 @@ def scale_by_kron(
max_size_triangular: int = 8192,
min_ndim_triangular: int = 2,
memory_save_mode: Optional[str] = None,
+ momentum_into_precond_update: bool = True,
mu_dtype: Optional[Union[str, jnp.dtype]] = None,
precond_dtype: Optional[Union[str, jnp.dtype]] = None,
precond_update_precision: Optional[str] = "tensorfloat32",
precond_grads_precision: Optional[str] = None,
scanned_layers: Optional[base.Params] = None,
lax_map_scanned_layers: bool = False,
lax_map_batch_size: int = 8,
- trust_region_scale: float = 1.5,
) -> base.GradientTransformationExtraArgs:
"""
Implements PSGD Kron from https://github.com/lixilinx/psgd_torch.
@@ -70,6 +70,8 @@ def scale_by_kron(
to set all preconditioners to be triangular, 'one_diag' sets the largest
or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
to be diagonal.
+ momentum_into_precond_update: bool, whether to send momentum into preconditioner
+ update instead of raw gradients.
mu_dtype: optional str or jnp.dtype, dtype of the momentum accumulator.
Defaults to the same dtype as the parameters.
precond_dtype: optional str or jnp.dtype, dtype of the preconditioner.
@@ -82,19 +84,14 @@ def scale_by_kron(
lax_map_scanned_layers: bool, whether to use lax.map for scanned layers
instead of vmap. Useful to save memory with large models.
lax_map_batch_size: int, batch size for lax.map, see JAX docs for more info.
- trust_region_scale: float, trust region on preconditioned grads. Normally this
- doesn't need to be changed but if things seem unstable you can try reducing
- this to 1.5.
Returns:
optax.GradientTransformationExtraArgs
"""
mu_dtype = canonicalize_dtype(mu_dtype)
precond_dtype = canonicalize_dtype(precond_dtype)

preconditioner_lr = 0.1
preconditioner_init_scale = 1.0
- momentum_before_precond_update = True

def map_fn(do_map, fn, *args):
"""Maybe map a fn along first axis."""
@@ -229,7 +226,7 @@ def update_fn(updates: base.Updates, state: dict, params: base.Params = None):
# maybe update preconditioner
def update_preconditioner(key, Qs):
with jax.default_matmul_precision(precond_update_precision):
- if momentum_before_precond_update:
+ if momentum_into_precond_update:
precond_updates_in = momentum_updates
else:
precond_updates_in = updates
@@ -307,8 +304,7 @@ def _balance_Q(Q: List[jax.Array]):
jnp.abs(x) + 1
) + 0.9 * jnp.tanh(x)
precond_gs = jax.tree.map(
- lambda x: trust_region_fn(x / trust_region_scale) * trust_region_scale,
- precond_gs,
+ lambda x: jnp.clip(trust_region_fn(x / 1.5) * 1.5, -2, 2), precond_gs
)

# box preconditioned grads
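The hunk above replaces the configurable trust_region_scale with a fixed scale of 1.5 plus a hard clip at ±2. The head of trust_region_fn sits outside the visible diff, so the sign/log/tanh form below is an assumption based on the visible tail "jnp.abs(x) + 1) + 0.9 * jnp.tanh(x)"; the sketch only illustrates how the old and new mappings differ on a few values:

import jax.numpy as jnp

def trust_region_fn(x):
    # assumed full form, consistent with the tail shown in the hunk
    return 0.1 * jnp.sign(x) * jnp.log(jnp.abs(x) + 1) + 0.9 * jnp.tanh(x)

x = jnp.array([-10.0, -2.0, 0.0, 2.0, 10.0])
old = trust_region_fn(x / 1.5) * 1.5                   # previous behavior, trust_region_scale=1.5
new = jnp.clip(trust_region_fn(x / 1.5) * 1.5, -2, 2)  # new behavior: fixed scale, clipped to [-2, 2]
print(old)
print(new)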
@@ -342,14 +338,14 @@ def kron(
max_size_triangular: int = 8192,
min_ndim_triangular: int = 2,
memory_save_mode: Optional[str] = None,
+ momentum_into_precond_update: bool = True,
mu_dtype: Optional[Union[str, jnp.dtype]] = None,
precond_dtype: Optional[Union[str, jnp.dtype]] = None,
precond_update_precision: Optional[str] = "tensorfloat32",
precond_grads_precision: Optional[str] = None,
scanned_layers: Optional[base.Params] = None,
lax_map_scanned_layers: bool = False,
lax_map_batch_size: int = 8,
- trust_region_scale: float = 1.5,
) -> base.GradientTransformationExtraArgs:
"""
Implements PSGD Kron from https://github.com/lixilinx/psgd_torch.
@@ -369,6 +365,8 @@ def kron(
to set all preconditioners to be triangular. 'one_diag' sets only the largest
or last dim in a layer to be diagonal, and 'all_diag' sets all preconditioners
to be diagonal.
+ momentum_into_precond_update: bool, whether to send momentum into preconditioner
+ update instead of raw gradients.
mu_dtype: optional str or jnp.dtype, dtype of the momentum accumulator.
Defaults to the same dtype as the parameters.
precond_dtype: optional str or jnp.dtype, dtype of the preconditioner.
@@ -381,9 +379,6 @@ def kron(
lax_map_scanned_layers: bool, whether to use lax.map for scanned layers
instead of vmap. Useful to save memory with large models.
lax_map_batch_size: int, batch size for lax.map, see JAX docs for more info.
- trust_region_scale: float, trust region on preconditioned grads. Normally this
- doesn't need to be changed but if things seem unstable you can try reducing
- this to 1.5.
Returns:
optax.GradientTransformationExtraArgs
@@ -395,14 +390,14 @@ def kron(
max_size_triangular=max_size_triangular,
min_ndim_triangular=min_ndim_triangular,
memory_save_mode=memory_save_mode,
+ momentum_into_precond_update=momentum_into_precond_update,
mu_dtype=mu_dtype,
precond_dtype=precond_dtype,
precond_update_precision=precond_update_precision,
precond_grads_precision=precond_grads_precision,
scanned_layers=scanned_layers,
lax_map_scanned_layers=lax_map_scanned_layers,
lax_map_batch_size=lax_map_batch_size,
- trust_region_scale=trust_region_scale,
)
]
if weight_decay > 0.0:
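A hedged usage sketch for the kron() wrapper with the newly exposed flag; learning_rate is an assumed argument name (the leading parameters are outside this diff) and the import path is inferred from the file layout:

import jax
import jax.numpy as jnp
import optax
from psgd_jax.kron import kron  # assumed import path

opt = kron(
    learning_rate=3e-4,                  # assumed leading argument, not shown in this diff
    memory_save_mode=None,               # keep all preconditioners triangular
    momentum_into_precond_update=False,  # feed raw gradients, not momentum, to the precond update
)

# Standard optax loop; kron returns an optax.GradientTransformationExtraArgs.
params = {"w": jnp.ones((4, 4))}
grads = jax.tree.map(jnp.ones_like, params)
state = opt.init(params)
updates, state = opt.update(grads, state, params)
params = optax.apply_updates(params, updates)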
7 changes: 6 additions & 1 deletion psgd_jax/psgd_test.py
@@ -177,7 +177,12 @@ def main():
elif precond_type == "kron":
del kwargs["precond_lr"]
del kwargs["update_global_norm_clip"]
- optimizer = partial(kron, memory_save_mode="one_diag", **kwargs)
+ optimizer = partial(
+     kron,
+     memory_save_mode=None,
+     momentum_into_precond_update=False,
+     **kwargs,
+ )
else:
optimizer = None

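For context, the test presumably calls this partial later with the remaining keyword arguments; the learning_rate name and value below are placeholders, not taken from this diff:

from functools import partial
from psgd_jax.kron import kron  # assumed import path

optimizer_factory = partial(
    kron,
    memory_save_mode=None,
    momentum_into_precond_update=False,
)
opt = optimizer_factory(learning_rate=1e-3)  # hypothetical remaining kwargs supplied at call time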
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "psgd-jax"
version = "0.2.6"
version = "0.2.7"
description = "An implementation of PSGD optimizer in JAX."
readme = { file = "README.md", content-type = "text/markdown" }
license = { file = "LICENSE" }
