Commit 8be0142
Merge pull request #48 from stockeh/main
docs update for 0.4.1
stockeh authored Nov 22, 2024
2 parents 63ed71e + c92923f commit 8be0142
Showing 9 changed files with 39 additions and 21 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## Unreleased
 
+## [v0.4.1](https://github.com/stockeh/mlx-optimizers/releases/tag/v0.4.1) - 2024-11-22
+
+### Fixed
+- optim: ADOPT update to use clipping (arXiv v2)
+
 ## [v0.4.0](https://github.com/stockeh/mlx-optimizers/releases/tag/v0.4.0) - 2024-11-20
 
 ### Added
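
The entry above refers to the clipped update from version 2 of the ADOPT paper (arXiv v2). Schematically, in the same notation as the updated docstring in mlx_optimizers/adopt.py below, where c_t is the clipping bound (step**0.25 by default):

```latex
m_t = \beta_1 m_{t-1} + (1 - \beta_1)\,\mathrm{clip}\!\left(\frac{g_t}{\max(\sqrt{v_{t-1}},\,\epsilon)},\; c_t\right),
\qquad
\theta_t = \theta_{t-1} - \eta\, m_t
```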
Binary file modified docs/src/_static/media/compare-cifar10-blank.png
Binary file modified docs/src/_static/media/rosenbrock_ADOPT.png
3 changes: 1 addition & 2 deletions examples/networks/compare.py
@@ -59,8 +59,7 @@ def get_optimizers(args):
         {
             "learning_rate": get_cosine_schedule(3e-3, 1e-6, n_warmup, decay_steps),
             "weight_decay": weight_decay,
-            "learning_rate_1d": get_cosine_schedule(3e-3, 1e-6, n_warmup, decay_steps),
-            "weight_decay_1d": weight_decay,
+            "optimize_1d": True,
         },
     ),
 ]
2 changes: 1 addition & 1 deletion examples/rosenbrock.py
@@ -117,7 +117,7 @@ def execute_experiments(optimizers, objective, func, plot_func, initial_state):
     (optim.QHAdam, 0, 0.5, {}),
     (optim.DiffGrad, 0, 0.5, {}),
     (optim.MADGRAD, 0, 0.5, {}),
-    (optim.ADOPT, 0, 0.25, {}),
+    (optim.ADOPT, 0, 0.5, {}),
     (optim.Lamb, 0, 0.25, {}),
     (optim.Muon, 0, 0.2, {"alternate_optimizer": AdamW(learning_rate=0.0842)}), # fixed lr
     (optim.Shampoo, 0, 2, {}),
45 changes: 29 additions & 16 deletions mlx_optimizers/adopt.py
@@ -1,4 +1,4 @@
-from typing import Callable, List, Union
+from typing import Callable, List, Optional, Union
 
 import mlx.core as mx
 from mlx.optimizers import Optimizer
@@ -9,10 +9,10 @@ class ADOPT(Optimizer):
 
     .. math::
 
-        v_0 &= g_0^2, m_1 = g_1 / \max{\sqrt{v_0}, \epsilon} \\
-        \theta_{t} &= \theta_{t-1} - \eta m_{t-1} \\
-        m_{t+1} &= \beta_1 m_{t} + (1 - \beta_1) (g_{t+1} / \max{\sqrt{v_t}, \epsilon})
+        m_0 &= \mathbf{0}, \quad v_0 = g_0^2 \\
+        m_t &= \beta_1 m_{t-1} + (1 - \beta_1) \text{clip}\left( \frac{g_t}{\max(\sqrt{v_{t-1}}, \epsilon)}, c_t \right) \\
+        \theta_{t} &= \theta_{t-1} - \eta m_t \\
         v_{t} &= \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
 
     [1] Taniguchi, Shohei, et al., 2024. ADOPT: Modified Adam Can
     Converge with Any :math:`\beta_2` with the Optimal Rate. NeurIPS 2024.

@@ -25,6 +25,9 @@ class ADOPT(Optimizer):
         :math:`(\beta_1, \beta_2)` used for computing running averages of the
         gradient and its square. Default: ``(0.9, 0.9999)``
     weight_decay (float, optional): The weight decay. Default: ``0.0``
+    decouple (bool, optional): AdamW if ``True``. Default: ``False``
+    clip_lambda (callable, optional): The clipping function :math:`c_t` for the
+        gradient. Set to ``None`` for previous behavior. Default: ``step**0.25``
     eps (float, optional): The term :math:`\epsilon` added to the
         denominator to improve numerical stability. Default: ``1e-6``
@@ -36,44 +39,54 @@ def __init__(
         learning_rate: Union[float, Callable[[mx.array], mx.array]],
         betas: List[float] = [0.9, 0.9999],
         weight_decay: float = 0.0,
+        decouple: bool = False,
+        clip_lambda: Optional[Callable[[mx.array], mx.array]] = lambda step: mx.power(step, 0.25),
         eps: float = 1e-6,
     ):
         super().__init__()
 
         self._maybe_schedule("learning_rate", learning_rate)
         self.betas = betas
         self.weight_decay = weight_decay
+        self.decouple = decouple
+        self.clip_lambda = clip_lambda
         self.eps = eps
 
     def init_single(self, parameter: mx.array, state: dict):
         """Initialize optimizer state"""
-        state["m"] = []
+        state["m"] = mx.zeros_like(parameter)
         state["c"] = 0
 
     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         """Performs a single optimization step, updating :math:`m` and :math:`v`"""
-        state["c"] += 1
-
-        if self.weight_decay != 0:
-            gradient = gradient + self.weight_decay * parameter
+        weight_decay = self.weight_decay
+        decouple = self.decouple
 
-        if state["c"] == 1:
+        if weight_decay != 0 and not decouple:
+            gradient = gradient + weight_decay * parameter
+
+        if state["c"] == 0:
             state["v"] = mx.square(gradient)
+            state["c"] += 1
             return parameter
 
         lr = self.learning_rate.astype(gradient.dtype)
+
+        if weight_decay != 0 and decouple:
+            parameter = parameter - lr * weight_decay * parameter
+
         b1, b2 = self.betas
-        eps = self.eps
 
         m = state["m"]
         v = state["v"]
-        denom = mx.maximum(mx.sqrt(v), eps)
 
-        if state["c"] == 2:
-            m = gradient / denom
-        else:
-            m = b1 * m + (1 - b1) * gradient / denom
+        denom = mx.maximum(mx.sqrt(v), self.eps)
+        normed_grad = gradient / denom
+        if self.clip_lambda is not None:
+            clip = self.clip_lambda(self.step - 1)
+            normed_grad = mx.clip(normed_grad, -clip, clip)
+
+        m = b1 * m + (1 - b1) * normed_grad
         parameter = parameter - lr * m
 
         state["v"] = b2 * v + (1 - b2) * mx.square(gradient)
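
For reference, a minimal usage sketch of the updated optimizer, assuming the package is imported as `import mlx_optimizers as optim` and exposes `ADOPT` at the top level (as the tests below do); the toy model, data, and hyperparameters are illustrative only:

```python
import mlx.core as mx
import mlx.nn as nn
import mlx_optimizers as optim  # assumed to expose ADOPT at the top level

# Toy regression model and data (illustrative only).
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
mx.eval(model.parameters())

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

opt = optim.ADOPT(
    learning_rate=3e-3,
    betas=[0.9, 0.9999],
    weight_decay=1e-2,
    decouple=True,                                  # decoupled (AdamW-style) weight decay
    clip_lambda=lambda step: mx.power(step, 0.25),  # default clip bound c_t = t**0.25
)

x = mx.random.normal((32, 8))
y = mx.random.normal((32, 1))
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

for _ in range(10):
    loss, grads = loss_and_grad_fn(model, x, y)
    opt.update(model, grads)                        # apply_single runs per parameter
    mx.eval(model.parameters(), opt.state)
```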
2 changes: 1 addition & 1 deletion mlx_optimizers/version.py
@@ -2,7 +2,7 @@
 _MINOR = "4"
 # On main and in a nightly release the patch should be one ahead of the last
 # released build.
-_PATCH = "0"
+_PATCH = "1"
 # This is mainly for nightly builds which have the suffix ".dev$DATE". See
 # https://semver.org/#is-v123-a-semantic-version for the semantics.
 _SUFFIX = ""
2 changes: 1 addition & 1 deletion tests/test_basic.py
@@ -36,7 +36,7 @@ def beale(xy):
     (optim.QHAdam, {"learning_rate": 0.25}, 150),
     (optim.DiffGrad, {"learning_rate": 0.3}, 150),
     (optim.MADGRAD, {"learning_rate": 0.03}, 150),
-    (optim.ADOPT, {"learning_rate": 0.17}, 150),
+    (optim.ADOPT, {"learning_rate": 0.5}, 250),
     (
         optim.Muon, # using alternate for ndim < 2
         {
1 change: 1 addition & 0 deletions tests/test_neuralnet.py
@@ -49,6 +49,7 @@ def generate_moons(n_samples: int = 100, noise: float = 0.2):
     ),
     (optim.MADGRAD, {"learning_rate": 0.01}, 50),
     (optim.ADOPT, {"learning_rate": 0.03}, 50),
+    (optim.ADOPT, {"learning_rate": 0.03, "clip_lambda": None}, 50),
     (optim.Lamb, {"learning_rate": 0.03}, 50),
     (optim.Shampoo, {"learning_rate": 0.03}, 50),
     (optim.Kron, {"learning_rate": 0.03}, 50),
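
The added test case above disables clipping. A sketch of that configuration outside the test harness, with the optimizer import assumed as in the tests:

```python
# clip_lambda=None restores the pre-0.4.1 (unclipped) ADOPT update,
# per the updated docstring in mlx_optimizers/adopt.py.
opt = optim.ADOPT(learning_rate=0.03, clip_lambda=None)
```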
