optim: add shampoo + tests (#32) #34

Merged · 1 commit · Nov 11, 2024
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added
- optim: Shampoo (general tensor case)

### Fixed
- optim: docstrings for existing methods

Binary file added docs/src/_static/media/rosenbrock_Shampoo.png
Binary file modified docs/src/_static/media/rosenbrock_all.png
1 change: 1 addition & 0 deletions examples/rosenbrock.py
@@ -119,6 +119,7 @@ def execute_experiments(optimizers, objective, func, plot_func, initial_state):
(optim.ADOPT, 0, 0.25, {}),
(optim.Lamb, 0, 0.25, {}),
(optim.Muon, 0, 0.2, {"alternate_optimizer": AdamW(learning_rate=0.0842)}), # fixed lr
(optim.Shampoo, 0, 2, {}),
]
execute_experiments(
optimizers, objective_rosenbrock, rosenbrock, plot_rosenbrock, ROSENBROCK_INITIAL
12 changes: 11 additions & 1 deletion mlx_optimizers/__init__.py
@@ -8,5 +8,15 @@
from .madgrad import MADGRAD
from .muon import Muon
from .qhadam import QHAdam
from .shampoo import Shampoo

__all__ = ["DiffGrad", "Muon", "QHAdam", "MADGRAD", "ADOPT", "Lamb" "__version__"]
__all__ = [
"ADOPT",
"DiffGrad",
"Lamb",
"MADGRAD",
"Muon",
"QHAdam",
"Shampoo",
"__version__",
]
4 changes: 2 additions & 2 deletions mlx_optimizers/muon.py
@@ -5,8 +5,8 @@


def zeropower_via_svd(G, steps=None) -> mx.array:
U, S, V = mx.linalg.svd(G, stream=mx.cpu) # type: ignore
return U @ V.T
U, S, Vt = mx.linalg.svd(G, stream=mx.cpu) # type: ignore
return U @ Vt


@mx.compile
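The rename from ``V`` to ``Vt`` matches what ``mx.linalg.svd`` actually returns: the right singular vectors come back already transposed, so the orthogonalized factor is ``U @ Vt``, not ``U @ V.T``. A minimal sanity-check sketch (square input for simplicity; as in the code above, SVD is dispatched to the CPU stream):

```python
import mlx.core as mx

# mx.linalg.svd returns (U, S, Vt); the right factor is already transposed.
G = mx.random.normal((4, 4))
U, S, Vt = mx.linalg.svd(G, stream=mx.cpu)

O = U @ Vt  # orthogonal factor of G, as used by zeropower_via_svd
print(mx.allclose(O.T @ O, mx.eye(4), atol=1e-4))  # expected: True
```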
109 changes: 109 additions & 0 deletions mlx_optimizers/shampoo.py
@@ -0,0 +1,109 @@
from typing import Callable, Union

import mlx.core as mx
from mlx.optimizers import Optimizer


def _matrix_power(matrix: mx.array, power: float, eps=1e-16) -> mx.array:
u, s, vt = mx.linalg.svd(matrix, stream=mx.cpu) # type: ignore
    # eps: needed to avoid a runtime command buffer execution error in MLX
return u @ mx.power(s + eps, power).diag() @ vt


class Shampoo(Optimizer):
r"""Preconditioned Stochastic Tensor Optimization (general tensor case) [1].

.. math::

W_1 &= 0_{n_1 \times \dots \times n_k}; \forall i \in [k]: H_0^i = \epsilon I_{n_i}\\
        H_t^i &= H_{t-1}^i + G_t^{(i)} (G_t^{(i)})^\top\\
\tilde{G}_t &= \tilde{G}_t \times_i (H_t^i)^{-1/2k}\\
W_{t+1} &= W_t - \eta \tilde{G}_t

[1] Gupta, Vineet, Tomer Koren, and Yoram Singer, 2018. Shampoo: Preconditioned
stochastic tensor optimization. ICML 2018.
https://arxiv.org/abs/1802.09568


Args:
learning_rate (float or callable): learning rate :math:`\eta`.
momentum (float, optional): momentum factor. Default: ``0.00``
weight_decay (float, optional): weight decay factor. Default: ``0.00``
update_freq (int, optional): frequency of updating the preconditioner. Default: ``1``
        eps (float, optional): value :math:`\epsilon` used to initialize the
            preconditioners :math:`H_0^i = \epsilon I` for numerical stability. Default: ``1e-4``

..
"""

def __init__(
self,
learning_rate: Union[float, Callable[[mx.array], mx.array]],
momentum: float = 0.0,
weight_decay: float = 0.0,
update_freq: int = 1,
eps: float = 1e-4,
):
super().__init__()

self._maybe_schedule("learning_rate", learning_rate)
self.momentum = momentum
self.weight_decay = weight_decay
self.update_freq = update_freq
self.eps = eps

def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
if self.momentum != 0:
state["buf"] = mx.zeros_like(parameter)
for i, dim in enumerate(parameter.shape):
state[f"precond_{i}"] = self.eps * mx.eye(dim)
state[f"inv_precond_{i}"] = mx.zeros((dim, dim))
state["dim_inds"] = list(range(parameter.ndim))
state["update_step"] = 0

def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs a single optimization step, updating :math:`m` and :math:`v`"""
lr = self.learning_rate.astype(gradient.dtype)
momentum = self.momentum

d = state["dim_inds"]
order = gradient.ndim
original_size = gradient.shape
if momentum != 0:
if state["update_step"] == 0:
state["buf"] = gradient
gradient = (1 - momentum) * gradient + momentum * state["buf"]

if self.weight_decay != 0:
gradient = gradient + self.weight_decay * parameter

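        # Precondition mode by mode: for each tensor dimension i, matricize the
        # gradient along that mode, accumulate H_i += G_(i) G_(i)^T, and apply
        # the inverse-root preconditioner computed from H_i.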
for i, dim in enumerate(gradient.shape):
precond = state[f"precond_{i}"]
inv_precond = state[f"inv_precond_{i}"]

if i != 0:
gradient = gradient.transpose([d[i]] + d[1:i] + [d[0]] + d[i + 1 :])
transpose_size = gradient.shape
gradient = gradient.reshape(dim, -1)

gradient_t = gradient.T
precond = precond + gradient @ gradient_t
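            # Recompute the inverse root only every `update_freq` steps to amortize the SVD cost.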
if state["update_step"] % self.update_freq == 0:
inv_precond = _matrix_power(precond, -1 / order)

if i == order - 1: # finally
gradient = gradient_t @ inv_precond
gradient = gradient.reshape(original_size)
state[f"precond_{i}"] = precond
state[f"inv_precond_{i}"] = inv_precond
else:
gradient = inv_precond @ gradient
gradient = gradient.reshape(transpose_size)

if momentum != 0:
state["buf"] = gradient
state["update_step"] += 1

gradient = gradient.reshape(original_size)
return parameter - lr * gradient
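For orientation, a minimal usage sketch of ``Shampoo`` against a toy MLX model; the model, data, and shapes here are illustrative, and ``learning_rate=0.03`` mirrors the value used in ``tests/test_neuralnet.py`` below:

```python
import mlx.core as mx
import mlx.nn as nn
import mlx_optimizers as optim

# Toy regression model and data (illustrative only).
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
opt = optim.Shampoo(learning_rate=0.03)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

x, y = mx.random.normal((32, 8)), mx.random.normal((32, 1))
loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
opt.update(model, grads)  # one Shampoo step over all parameters
mx.eval(model.parameters(), opt.state)
```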
2 changes: 1 addition & 1 deletion tests/test_basic.py
@@ -46,7 +46,7 @@ def ids(v):
},
800,
),
# TODO: Lamb tests
# TODO: Lamb & Shampoo tests
]


6 changes: 2 additions & 4 deletions tests/test_neuralnet.py
@@ -63,15 +63,13 @@ def ids(v):
(optim.DiffGrad, {"learning_rate": 0.01}, 100),
(
optim.Muon,
{
"learning_rate": 0.01,
"alternate_optimizer": AdamW(learning_rate=0.001),
},
{"learning_rate": 0.01, "alternate_optimizer": AdamW(learning_rate=0.001)},
100,
),
(optim.MADGRAD, {"learning_rate": 0.01}, 50),
(optim.ADOPT, {"learning_rate": 0.01}, 50),
(optim.Lamb, {"learning_rate": 0.01}, 50),
(optim.Shampoo, {"learning_rate": 0.03}, 50),
]

