diff --git a/CHANGELOG.md b/CHANGELOG.md
index 28e6c02..0d27a2f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+## [v0.4.1](https://github.com/stockeh/mlx-optimizers/releases/tag/v0.4.1) - 2024-11-22
+
+### Fixed
+- optim: ADOPT update to use clipping (arXiv v2)
+
 ## [v0.4.0](https://github.com/stockeh/mlx-optimizers/releases/tag/v0.4.0) - 2024-11-20
 
 ### Added
diff --git a/docs/src/_static/media/compare-cifar10-blank.png b/docs/src/_static/media/compare-cifar10-blank.png
index 87a3ae9..b13b841 100644
Binary files a/docs/src/_static/media/compare-cifar10-blank.png and b/docs/src/_static/media/compare-cifar10-blank.png differ
diff --git a/docs/src/_static/media/rosenbrock_ADOPT.png b/docs/src/_static/media/rosenbrock_ADOPT.png
index c1380c0..e84bd59 100644
Binary files a/docs/src/_static/media/rosenbrock_ADOPT.png and b/docs/src/_static/media/rosenbrock_ADOPT.png differ
diff --git a/examples/networks/compare.py b/examples/networks/compare.py
index 498af3f..43d1eb5 100644
--- a/examples/networks/compare.py
+++ b/examples/networks/compare.py
@@ -59,8 +59,7 @@ def get_optimizers(args):
             {
                 "learning_rate": get_cosine_schedule(3e-3, 1e-6, n_warmup, decay_steps),
                 "weight_decay": weight_decay,
-                "learning_rate_1d": get_cosine_schedule(3e-3, 1e-6, n_warmup, decay_steps),
-                "weight_decay_1d": weight_decay,
+                "optimize_1d": True,
             },
         ),
     ]
diff --git a/examples/rosenbrock.py b/examples/rosenbrock.py
index 1023de9..ef14fb9 100644
--- a/examples/rosenbrock.py
+++ b/examples/rosenbrock.py
@@ -117,7 +117,7 @@ def execute_experiments(optimizers, objective, func, plot_func, initial_state):
         (optim.QHAdam, 0, 0.5, {}),
         (optim.DiffGrad, 0, 0.5, {}),
         (optim.MADGRAD, 0, 0.5, {}),
-        (optim.ADOPT, 0, 0.25, {}),
+        (optim.ADOPT, 0, 0.5, {}),
         (optim.Lamb, 0, 0.25, {}),
         (optim.Muon, 0, 0.2, {"alternate_optimizer": AdamW(learning_rate=0.0842)}),  # fixed lr
         (optim.Shampoo, 0, 2, {}),
diff --git a/mlx_optimizers/adopt.py b/mlx_optimizers/adopt.py
index 1117be6..f1d6fbe 100644
--- a/mlx_optimizers/adopt.py
+++ b/mlx_optimizers/adopt.py
@@ -1,4 +1,4 @@
-from typing import Callable, List, Union
+from typing import Callable, List, Optional, Union
 
 import mlx.core as mx
 from mlx.optimizers import Optimizer
@@ -9,10 +9,10 @@ class ADOPT(Optimizer):
 
     .. math::
 
-        v_0 &= g_0^2, m_1 = g_1 / \max{\sqrt{v_0}, \epsilon} \\
-        \theta_{t} &= \theta_{t-1} - \eta m_{t-1} \\
+        m_0 &= \mathbf{0}, \quad v_0 = g_0^2 \\
+        m_t &= \beta_1 m_{t-1} + (1 - \beta_1) \text{clip} \left( \frac{g_t}{\max(\sqrt{v_{t-1}}, \epsilon)}, c_t \right) \\
+        \theta_{t} &= \theta_{t-1} - \eta m_t \\
         v_{t} &= \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
-        m_{t+1} &= \beta_1 m_{t} + (1 - \beta_1) (g_{t+1} / \max{\sqrt{v_t}, \epsilon})
 
     [1] Taniguchi, Shohei, et al., 2024. ADOPT: Modified Adam Can Converge with Any
     :math:`\beta_2` with the Optimal Rate. NeurIPS 2024.
@@ -25,6 +25,9 @@ class ADOPT(Optimizer):
           :math:`(\beta_1, \beta_2)` used for computing running averages of the
           gradient and its square. Default: ``(0.9, 0.9999)``
         weight_decay (float, optional): The weight decay. Default: ``0.0``
+        decouple (bool, optional): Use AdamW-style decoupled weight decay if ``True``. Default: ``False``
+        clip_lambda (callable, optional): The clipping function :math:`c_t` for the
+          normalized gradient. Set to ``None`` for the previous, unclipped behavior. Default: ``step**0.25``
         eps (float, optional): The term :math:`\epsilon` added to the denominator to improve numerical stability.
          Default: ``1e-6``
 
@@ -36,6 +39,8 @@ def __init__(
         self,
         learning_rate: Union[float, Callable[[mx.array], mx.array]],
         betas: List[float] = [0.9, 0.9999],
         weight_decay: float = 0.0,
+        decouple: bool = False,
+        clip_lambda: Optional[Callable[[mx.array], mx.array]] = lambda step: mx.power(step, 0.25),
         eps: float = 1e-6,
     ):
         super().__init__()
@@ -43,37 +48,45 @@
         self._maybe_schedule("learning_rate", learning_rate)
         self.betas = betas
         self.weight_decay = weight_decay
+        self.decouple = decouple
+        self.clip_lambda = clip_lambda
         self.eps = eps
 
     def init_single(self, parameter: mx.array, state: dict):
         """Initialize optimizer state"""
-        state["m"] = []
+        state["m"] = mx.zeros_like(parameter)
         state["c"] = 0
 
     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         """Performs a single optimization step, updating :math:`m` and :math:`v`"""
-        state["c"] += 1
-        if self.weight_decay != 0:
-            gradient = gradient + self.weight_decay * parameter
+        weight_decay = self.weight_decay
+        decouple = self.decouple
 
-        if state["c"] == 1:
+        if weight_decay != 0 and not decouple:
+            gradient = gradient + weight_decay * parameter
+
+        if state["c"] == 0:
             state["v"] = mx.square(gradient)
+            state["c"] += 1
             return parameter
 
         lr = self.learning_rate.astype(gradient.dtype)
+
+        if weight_decay != 0 and decouple:
+            parameter = parameter - lr * weight_decay * parameter
+
         b1, b2 = self.betas
-        eps = self.eps
 
         m = state["m"]
         v = state["v"]
 
-        denom = mx.maximum(mx.sqrt(v), eps)
-
-        if state["c"] == 2:
-            m = gradient / denom
-        else:
-            m = b1 * m + (1 - b1) * gradient / denom
+        denom = mx.maximum(mx.sqrt(v), self.eps)
+        normed_grad = gradient / denom
+        if self.clip_lambda is not None:
+            clip = self.clip_lambda(self.step - 1)
+            normed_grad = mx.clip(normed_grad, -clip, clip)
+        m = b1 * m + (1 - b1) * normed_grad
 
         parameter = parameter - lr * m
 
         state["v"] = b2 * v + (1 - b2) * mx.square(gradient)
diff --git a/mlx_optimizers/version.py b/mlx_optimizers/version.py
index eca97dc..008b370 100644
--- a/mlx_optimizers/version.py
+++ b/mlx_optimizers/version.py
@@ -2,7 +2,7 @@
 _MINOR = "4"
 # On main and in a nightly release the patch should be one ahead of the last
 # released build.
-_PATCH = "0"
+_PATCH = "1"
 # This is mainly for nightly builds which have the suffix ".dev$DATE". See
 # https://semver.org/#is-v123-a-semantic-version for the semantics.
 _SUFFIX = ""
diff --git a/tests/test_basic.py b/tests/test_basic.py
index 0d2049b..01b1fe0 100644
--- a/tests/test_basic.py
+++ b/tests/test_basic.py
@@ -36,7 +36,7 @@ def beale(xy):
         (optim.QHAdam, {"learning_rate": 0.25}, 150),
         (optim.DiffGrad, {"learning_rate": 0.3}, 150),
         (optim.MADGRAD, {"learning_rate": 0.03}, 150),
-        (optim.ADOPT, {"learning_rate": 0.17}, 150),
+        (optim.ADOPT, {"learning_rate": 0.5}, 250),
         (
             optim.Muon,  # using alternate for ndim < 2
             {
diff --git a/tests/test_neuralnet.py b/tests/test_neuralnet.py
index e3ab01d..a2e6e34 100644
--- a/tests/test_neuralnet.py
+++ b/tests/test_neuralnet.py
@@ -49,6 +49,7 @@ def generate_moons(n_samples: int = 100, noise: float = 0.2):
         ),
         (optim.MADGRAD, {"learning_rate": 0.01}, 50),
         (optim.ADOPT, {"learning_rate": 0.03}, 50),
+        (optim.ADOPT, {"learning_rate": 0.03, "clip_lambda": None}, 50),
         (optim.Lamb, {"learning_rate": 0.03}, 50),
         (optim.Shampoo, {"learning_rate": 0.03}, 50),
         (optim.Kron, {"learning_rate": 0.03}, 50),
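
For context, a minimal usage sketch of the two options this patch introduces (`clip_lambda` and `decouple`), not part of the patch itself. The toy model, data, and hyperparameter values are illustrative assumptions; the optimizer arguments follow the `mlx_optimizers.ADOPT` signature changed above.

```python
# Usage sketch for the ADOPT options added in this patch.
# Assumes mlx and mlx-optimizers (>= 0.4.1) are installed; model/data are toy placeholders.
import mlx.core as mx
import mlx.nn as nn

import mlx_optimizers as optim

model = nn.Linear(4, 2)
mx.eval(model.parameters())


def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)


# Default: clipped normalized gradients with c_t = step ** 0.25 (arXiv v2 rule).
adopt = optim.ADOPT(learning_rate=1e-3)

# Unclipped update rule (the pre-0.4.1 behavior), as exercised in tests/test_neuralnet.py.
adopt_unclipped = optim.ADOPT(learning_rate=1e-3, clip_lambda=None)

# Decoupled (AdamW-style) weight decay instead of adding the decay to the gradient.
adopt_decoupled = optim.ADOPT(learning_rate=1e-3, weight_decay=1e-2, decouple=True)

x = mx.random.normal((8, 4))
y = mx.random.normal((8, 2))

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, x, y)
adopt.update(model, grads)  # first call only initializes v; parameter updates begin on the next step
mx.eval(model.parameters(), adopt.state)
```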