From 39d46209591792d01d6de5fb6c42edaf07f73aa8 Mon Sep 17 00:00:00 2001
From: Evan Walters
Date: Sat, 9 Nov 2024 18:43:57 -0700
Subject: [PATCH] expose momentum into precond, trust region

---
 psgd_jax/kron.py      | 23 +++++++++--------------
 psgd_jax/psgd_test.py |  7 ++++++-
 pyproject.toml        |  2 +-
 3 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/psgd_jax/kron.py b/psgd_jax/kron.py
index 9e4f322..101c9a0 100644
--- a/psgd_jax/kron.py
+++ b/psgd_jax/kron.py
@@ -47,6 +47,7 @@ def scale_by_kron(
     max_size_triangular: int = 8192,
     min_ndim_triangular: int = 2,
     memory_save_mode: Optional[str] = None,
+    momentum_into_precond_update: bool = True,
     mu_dtype: Optional[Union[str, jnp.dtype]] = None,
     precond_dtype: Optional[Union[str, jnp.dtype]] = None,
     precond_update_precision: Optional[str] = "tensorfloat32",
@@ -54,7 +55,6 @@
     scanned_layers: Optional[base.Params] = None,
     lax_map_scanned_layers: bool = False,
     lax_map_batch_size: int = 8,
-    trust_region_scale: float = 1.5,
 ) -> base.GradientTransformationExtraArgs:
     """
     Implements PSGD Kron from https://github.com/lixilinx/psgd_torch.
@@ -70,6 +70,8 @@
             to set all preconditioners to be triangular, 'one_diag' sets the largest
             or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners
             to be diagonal.
+        momentum_into_precond_update: bool, whether to send momentum into preconditioner
+            update instead of raw gradients.
         mu_dtype: optional str or jnp.dtype, dtype of the momentum accumulator.
             Defaults to the same dtype as the parameters.
         precond_dtype: optional str or jnp.dtype, dtype of the preconditioner.
@@ -82,19 +84,14 @@
         lax_map_scanned_layers: bool, whether to use lax.map for scanned layers instead
             of vmap. Useful to save memory with large models.
         lax_map_batch_size: int, batch size for lax.map, see JAX docs for more info.
-        trust_region_scale: float, trust region on preconditioned grads. Normally this
-            doesn't need to be changed but if things seem unstable you can try reducing
-            this to 1.5.
 
     Returns:
         optax.GradientTransformationExtraArgs
     """
     mu_dtype = canonicalize_dtype(mu_dtype)
     precond_dtype = canonicalize_dtype(precond_dtype)
 
-    preconditioner_lr = 0.1
     preconditioner_init_scale = 1.0
-    momentum_before_precond_update = True
 
     def map_fn(do_map, fn, *args):
         """Maybe map a fn along first axis."""
@@ -229,7 +226,7 @@ def update_fn(updates: base.Updates, state: dict, params: base.Params = None):
         # maybe update preconditioner
         def update_preconditioner(key, Qs):
             with jax.default_matmul_precision(precond_update_precision):
-                if momentum_before_precond_update:
+                if momentum_into_precond_update:
                     precond_updates_in = momentum_updates
                 else:
                     precond_updates_in = updates
@@ -307,8 +304,7 @@ def _balance_Q(Q: List[jax.Array]):
             jnp.abs(x) + 1
         ) + 0.9 * jnp.tanh(x)
         precond_gs = jax.tree.map(
-            lambda x: trust_region_fn(x / trust_region_scale) * trust_region_scale,
-            precond_gs,
+            lambda x: jnp.clip(trust_region_fn(x / 1.5) * 1.5, -2, 2), precond_gs
         )
 
         # box preconditioned grads
@@ -342,6 +338,7 @@ def kron(
     max_size_triangular: int = 8192,
     min_ndim_triangular: int = 2,
     memory_save_mode: Optional[str] = None,
+    momentum_into_precond_update: bool = True,
     mu_dtype: Optional[Union[str, jnp.dtype]] = None,
     precond_dtype: Optional[Union[str, jnp.dtype]] = None,
     precond_update_precision: Optional[str] = "tensorfloat32",
@@ -349,7 +346,6 @@
     scanned_layers: Optional[base.Params] = None,
     lax_map_scanned_layers: bool = False,
     lax_map_batch_size: int = 8,
-    trust_region_scale: float = 1.5,
 ) -> base.GradientTransformationExtraArgs:
     """
     Implements PSGD Kron from https://github.com/lixilinx/psgd_torch.
@@ -369,6 +365,8 @@
             to set all preconditioners to be triangular. 'one_diag' sets only the largest
             or last dim in a layer to be diagonal, and 'all_diag' sets all preconditioners
             to be diagonal.
+        momentum_into_precond_update: bool, whether to send momentum into preconditioner
+            update instead of raw gradients.
         mu_dtype: optional str or jnp.dtype, dtype of the momentum accumulator.
             Defaults to the same dtype as the parameters.
         precond_dtype: optional str or jnp.dtype, dtype of the preconditioner.
@@ -381,9 +379,6 @@
         lax_map_scanned_layers: bool, whether to use lax.map for scanned layers instead
             of vmap. Useful to save memory with large models.
         lax_map_batch_size: int, batch size for lax.map, see JAX docs for more info.
-        trust_region_scale: float, trust region on preconditioned grads. Normally this
-            doesn't need to be changed but if things seem unstable you can try reducing
-            this to 1.5.
 
     Returns:
         optax.GradientTransformationExtraArgs
@@ -395,6 +390,7 @@ def kron(
             max_size_triangular=max_size_triangular,
             min_ndim_triangular=min_ndim_triangular,
             memory_save_mode=memory_save_mode,
+            momentum_into_precond_update=momentum_into_precond_update,
             mu_dtype=mu_dtype,
             precond_dtype=precond_dtype,
             precond_update_precision=precond_update_precision,
@@ -402,7 +398,6 @@
             scanned_layers=scanned_layers,
             lax_map_scanned_layers=lax_map_scanned_layers,
             lax_map_batch_size=lax_map_batch_size,
-            trust_region_scale=trust_region_scale,
         )
     ]
     if weight_decay > 0.0:
diff --git a/psgd_jax/psgd_test.py b/psgd_jax/psgd_test.py
index 19bebf0..e7610b0 100644
--- a/psgd_jax/psgd_test.py
+++ b/psgd_jax/psgd_test.py
@@ -177,7 +177,12 @@ def main():
     elif precond_type == "kron":
         del kwargs["precond_lr"]
         del kwargs["update_global_norm_clip"]
-        optimizer = partial(kron, memory_save_mode="one_diag", **kwargs)
+        optimizer = partial(
+            kron,
+            memory_save_mode=None,
+            momentum_into_precond_update=False,
+            **kwargs,
+        )
     else:
         optimizer = None
 
diff --git a/pyproject.toml b/pyproject.toml
index f741122..6bd3bac 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"
 
 [project]
 name = "psgd-jax"
-version = "0.2.6"
+version = "0.2.7"
 description = "An implementation of PSGD optimizer in JAX."
 readme = { file = "README.md", content-type = "text/markdown" }
 license = { file = "LICENSE" }
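
For context, a minimal usage sketch of the new momentum_into_precond_update flag exposed by this patch. This is illustrative only: the toy params and loss are made up, and it assumes kron also accepts a learning_rate argument and composes like any other optax gradient transformation, which this diff does not itself show.

# Sketch (hypothetical setup); the new flag controls whether the momentum buffer
# or the raw gradients feed the preconditioner update.
import jax
import jax.numpy as jnp
import optax
from psgd_jax.kron import kron

# Toy parameters and loss, purely for illustration.
params = {"w": jnp.ones((4, 4)), "b": jnp.zeros(4)}

def loss_fn(p):
    return jnp.sum(p["w"] ** 2) + jnp.sum(p["b"] ** 2)

opt = kron(
    learning_rate=3e-4,                 # assumed argument, not shown in this diff
    momentum_into_precond_update=True,  # True (default): momentum into precond update; False: raw grads
    memory_save_mode=None,              # or "one_diag" / "all_diag" to save memory
)
opt_state = opt.init(params)

grads = jax.grad(loss_fn)(params)
updates, opt_state = opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)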