gh-pages #45

Merged
merged 8 commits on Nov 21, 2024
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

## [v0.4.0](https://github.com/stockeh/mlx-optimizers/releases/tag/v0.4.0) - 2024-11-20

### Added
- optim: MARS
- optim: common file for repeated ops (e.g., newton_schulz)
- examples: compare optimizers on mnist/cifar10

### Fixed
- optim: ADOPT correctly implemented (exp_avg_sq and g_0)

## [v0.3.0](https://github.com/stockeh/mlx-optimizers/releases/tag/v0.3.0) - 2024-11-18

### Added
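For reference, a minimal usage sketch of the newly added MARS optimizer. Parameter names follow the examples further down in this diff; the model, data, and hyperparameter values here are illustrative only.

import mlx.core as mx
import mlx.nn as nn
import mlx_optimizers as optim

# Toy regression model; MARS plugs in like any other MLX optimizer.
model = nn.Linear(16, 4)
optimizer = optim.MARS(learning_rate=3e-3, weight_decay=1e-4)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

x = mx.random.normal((8, 16))
y = mx.random.normal((8, 4))
loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
optimizer.update(model, grads)  # one MARS step
mx.eval(model.parameters(), optimizer.state)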
32 changes: 32 additions & 0 deletions docs/src/_autosummary/mlx_optimizers.MARS.rst
@@ -0,0 +1,32 @@
mlx\_optimizers.MARS
====================

.. currentmodule:: mlx_optimizers

.. autoclass:: MARS







.. image:: mlx_optimizers/../../_static/media/rosenbrock_MARS.png
   :align: center







.. rubric:: Methods

.. autosummary::

   ~MARS.__init__
   ~MARS.apply_single
   ~MARS.init_single
   ~MARS.set_last_grad


Binary file added docs/src/_static/media/compare-cifar10-blank.png
Binary file modified docs/src/_static/media/rosenbrock_ADOPT.png
Binary file modified docs/src/_static/media/rosenbrock_Lamb.png
Binary file added docs/src/_static/media/rosenbrock_MARS.png
Binary file modified docs/src/_static/media/rosenbrock_all.png
1 change: 1 addition & 0 deletions docs/src/optimizers.rst
@@ -10,6 +10,7 @@ Optimizers
ADOPT
DiffGrad
Muon
MARS
QHAdam
MADGRAD
Lamb
138 changes: 138 additions & 0 deletions examples/networks/compare.py
@@ -0,0 +1,138 @@
import argparse

import matplotlib.pyplot as plt
import mlx.core as mx
from datasets import cifar10, mnist
from manager import Manager
from mlx.optimizers import Adam, AdamW, cosine_decay, join_schedules, linear_schedule
from models import Network

import mlx_optimizers as optim

parser = argparse.ArgumentParser(add_help=True)
parser.add_argument(
    "--dataset", type=str, default="cifar10", choices=["mnist", "cifar10"], help="dataset to use"
)
parser.add_argument("-b", "--batch_size", type=int, default=128, help="batch size")
parser.add_argument("-e", "--epochs", type=int, default=50, help="number of epochs")
parser.add_argument("--seed", type=int, default=0, help="random seed")
parser.add_argument("--cpu", action="store_true", help="use cpu only")


def get_cosine_schedule(max_lr, min_lr, n_warmup, decay_steps):
    learning_rate = join_schedules(
        [linear_schedule(min_lr, max_lr, n_warmup), cosine_decay(max_lr, decay_steps, min_lr)],
        [n_warmup],
    )
    return learning_rate


def get_optimizers(args):
    total_steps = 50_000 // args.batch_size * args.epochs
    n_warmup = int(total_steps * 0.10)  # % of total steps
    decay_steps = total_steps - n_warmup
    weight_decay = 1e-4
    learning_rate = get_cosine_schedule(6e-4, 1e-6, n_warmup, decay_steps)
    optimizers = [
        (
            Adam,
            {
                "learning_rate": learning_rate,
            },
        ),
        (
            AdamW,
            {
                "learning_rate": learning_rate,
                "weight_decay": weight_decay,
            },
        ),
        (
            optim.ADOPT,
            {
                "learning_rate": learning_rate,
                "weight_decay": weight_decay,
            },
        ),
        (
            optim.MARS,
            {
                "learning_rate": get_cosine_schedule(3e-3, 1e-6, n_warmup, decay_steps),
                "weight_decay": weight_decay,
                "learning_rate_1d": get_cosine_schedule(3e-3, 1e-6, n_warmup, decay_steps),
                "weight_decay_1d": weight_decay,
            },
        ),
    ]

    return optimizers


def plot_results(results, optimizers, args):
    fig, ax = plt.subplots(figsize=(5.5, 3.5))
    colors = ["#74add1", "#1730bd", "#1a9850", "#001c01"]

    for i, acc in enumerate(results):
        ax.plot(range(1, len(acc) + 1), acc, label=optimizers[i][0].__name__, lw=2, color=colors[i])

    ax.set_title(f"{args.dataset.upper()} (val)", loc="left")
    ax.set_xlabel("Epoch", fontsize="medium")
    ax.set_ylabel("Accuracy (%)", fontsize="medium")

    ax.legend(ncols=2, columnspacing=0.8, fontsize="medium")
    ax.grid(alpha=0.2)

    ax.set_ylim(90 if args.dataset == "mnist" else 70)
    acc_min, acc_max = ax.get_ylim()
    ax.set_yticks(mx.linspace(acc_min, acc_max, 5, dtype=mx.int8))
    ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))

    fig.tight_layout()
    fig.savefig(
        f"../../docs/src/_static/media/compare-{args.dataset}-blank.png",
        dpi=300,
        bbox_inches="tight",
    )
    plt.show()


def main(args):
    mx.random.seed(args.seed)
    if args.dataset == "mnist":
        train_data, test_data = mnist(args.batch_size)
    elif args.dataset == "cifar10":
        train_data, test_data = cifar10(args.batch_size)
    else:
        raise NotImplementedError(f"{args.dataset=} is not implemented.")
    x_shape = next(train_data)["image"].shape
    train_data.reset()

    model_config = {
        "n_inputs": x_shape[1:],
        "conv_layers_list": [
            {"filters": 32, "kernel_size": 3, "repeat": 2, "batch_norm": True},
            {"filters": 64, "kernel_size": 3, "repeat": 2, "batch_norm": True},
            {"filters": 128, "kernel_size": 3, "repeat": 2, "batch_norm": True},
        ],
        "n_hiddens_list": [512],
        "n_outputs": 10,
        "dropout": 0.2,
    }

    optimizers = get_optimizers(args)

    results = []
    for optimizer, optimizer_kwargs in optimizers:
        mx.random.seed(args.seed)
        manager = Manager(Network(**model_config), optimizer(**optimizer_kwargs))  # type: ignore
        manager.train(train_data, val=test_data, epochs=args.epochs)
        results.append(100 * mx.array(manager.val_acc_trace))

    plot_results(results, optimizers, args)


if __name__ == "__main__":
    args = parser.parse_args()
    if args.cpu:
        mx.set_default_device(mx.Device(mx.cpu))
    main(args)
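The warmup-plus-cosine schedule built by get_cosine_schedule above can be inspected on its own; a small sketch with arbitrary step counts chosen for illustration:

import mlx.core as mx
from mlx.optimizers import cosine_decay, join_schedules, linear_schedule

# 100 warmup steps from 1e-6 to 6e-4, then cosine decay back down to 1e-6 over 900 steps.
schedule = join_schedules(
    [linear_schedule(1e-6, 6e-4, 100), cosine_decay(6e-4, 900, 1e-6)],
    [100],
)
for step in (0, 50, 100, 500, 999):
    print(step, schedule(mx.array(step)).item())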
4 changes: 2 additions & 2 deletions examples/networks/datasets.py
@@ -32,8 +32,8 @@ def normalize(x): # normalize to [0,1]


def cifar10(batch_size, img_size=(32, 32), root=None):
mean = np.array([0.485, 0.456, 0.406]).reshape((1, 1, 3))
std = np.array([0.229, 0.224, 0.225]).reshape((1, 1, 3))
mean = np.array([0.4914, 0.4822, 0.4465]).reshape((1, 1, 3))
std = np.array([0.2470, 0.2435, 0.2616]).reshape((1, 1, 3))

def normalize(x): # z-score normalize
x = x.astype("float32") / 255.0
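The new constants above are per-channel CIFAR-10 statistics. A sketch of the z-score normalization they feed into; the remainder of normalize is truncated in this view, so the final scaling step shown here is the usual (x - mean) / std form and is an assumption:

import numpy as np

mean = np.array([0.4914, 0.4822, 0.4465]).reshape((1, 1, 3))
std = np.array([0.2470, 0.2435, 0.2616]).reshape((1, 1, 3))

x = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)  # dummy HWC image
x = x.astype("float32") / 255.0  # scale to [0, 1]
x = (x - mean) / std  # per-channel z-score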
3 changes: 2 additions & 1 deletion examples/networks/manager.py
@@ -4,7 +4,7 @@
import mlx.core as mx
import mlx.nn as nn
from datasets import cifar10, mnist
from mlx.optimizers import Optimizer
from mlx.optimizers import Optimizer, clip_grad_norm
from models import Network
from tqdm import tqdm

@@ -44,6 +44,7 @@ def train(self, train, val=None, epochs: int = 10):
def step(X, T):
train_step_fn = nn.value_and_grad(self.model, self.eval_fn)
(loss, correct), grads = train_step_fn(X, T)
grads, _ = clip_grad_norm(grads, max_norm=1.0)
self.optimizer.update(self.model, grads)
return loss, correct

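The clip_grad_norm call added above rescales the whole gradient tree when its global norm exceeds max_norm and also returns the pre-clipping norm; a standalone sketch with dummy gradients:

import mlx.core as mx
from mlx.optimizers import clip_grad_norm

grads = {"w": mx.ones((3, 3)), "b": mx.ones((3,))}
clipped, total_norm = clip_grad_norm(grads, max_norm=1.0)
# Gradients are scaled by max_norm / total_norm only when total_norm > max_norm.
print(total_norm)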
58 changes: 47 additions & 11 deletions examples/networks/models.py
@@ -117,13 +117,36 @@ def __call__(self, x: mx.array) -> mx.array:


class Network(Base):
"""Fully Connected / Convolutional Neural Network

Args:
n_inputs (Union[List[int], Tuple[int], mx.array]): Input shape
n_outputs (int): Number of output classes
conv_layers_list (List[dict], optional): List of convolutional layers. Defaults to [].
n_hiddens_list (Union[List, int], optional): List of hidden units. Defaults to 0.
activation_f (str, optional): Activation function. Defaults to "relu".
dropout (float, optional): Dropout rate. Defaults to 0.0.

conv_layers_list dict keys:
filters: int
kernel_size: int
stride: int
dilation: int
padding: int
bias: bool
batch_norm: bool
repeat: int

"""

def __init__(
self,
n_inputs: Union[List[int], Tuple[int], mx.array],
n_outputs: int,
conv_layers_list: List[dict] = [],
n_hiddens_list: Union[List, int] = 0,
activation_f: str = "relu",
dropout: float = 0.0,
):
super().__init__()

@@ -155,19 +178,30 @@ def __init__(

self.conv.extend(
[
nn.Conv2d(
n_channels,
conv_layer["filters"],
conv_layer["kernel_size"],
stride=conv_layer.get("stride", 1),
padding=padding,
dilation=conv_layer.get("dilation", 1),
bias=conv_layer.get("bias", True),
),
activation(),
nn.MaxPool2d(2, stride=2),
layer
for i in range(conv_layer.get("repeat", 1))
for layer in [
nn.Conv2d(
n_channels if i == 0 else conv_layer["filters"],
conv_layer["filters"],
conv_layer["kernel_size"],
stride=conv_layer.get("stride", 1),
padding=padding,
dilation=conv_layer.get("dilation", 1),
bias=conv_layer.get("bias", True),
),
activation(),
]
+ (
[nn.BatchNorm(conv_layer["filters"])]
if conv_layer.get("batch_norm")
else []
)
]
+ [nn.MaxPool2d(2, stride=2)]
)
if dropout > 0:
self.conv.append(nn.Dropout(dropout))
ni = mx.concatenate([ni[:-1] // 2, mx.array([conv_layer["filters"]])])

ni = int(mx.prod(ni))
@@ -176,6 +210,8 @@ def __init__(
for _, n_units in enumerate(n_hiddens_list):
self.fcn.append(nn.Linear(ni, n_units))
self.fcn.append(activation())
if dropout > 0:
self.fcn.append(nn.Dropout(dropout))
ni = n_units
self.output = nn.Linear(ni, n_outputs)

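A construction sketch for the repeat and batch_norm options documented in the new docstring, mirroring the model_config used in compare.py; the input shape here (NHWC CIFAR-10) is an assumption:

from models import Network  # assumes running from examples/networks/

# Each conv block stacks two 3x3 convolutions with BatchNorm, then 2x2 max pooling.
model = Network(
    n_inputs=(32, 32, 3),
    n_outputs=10,
    conv_layers_list=[
        {"filters": 32, "kernel_size": 3, "repeat": 2, "batch_norm": True},
        {"filters": 64, "kernel_size": 3, "repeat": 2, "batch_norm": True},
    ],
    n_hiddens_list=[512],
    dropout=0.2,
)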
1 change: 1 addition & 0 deletions examples/rosenbrock.py
@@ -122,6 +122,7 @@ def execute_experiments(optimizers, objective, func, plot_func, initial_state):
(optim.Muon, 0, 0.2, {"alternate_optimizer": AdamW(learning_rate=0.0842)}), # fixed lr
(optim.Shampoo, 0, 2, {}),
(optim.Kron, 0, 0.5, {}),
(optim.MARS, 0, 0.8, {"optimize_1d": True, "mars_type": "mars-adamw"}),
]
execute_experiments(
optimizers, objective_rosenbrock, rosenbrock, plot_rosenbrock, ROSENBROCK_INITIAL
2 changes: 2 additions & 0 deletions mlx_optimizers/__init__.py
@@ -7,6 +7,7 @@
from .kron import Kron
from .lamb import Lamb
from .madgrad import MADGRAD
from .mars import MARS
from .muon import Muon
from .qhadam import QHAdam
from .shampoo import Shampoo
@@ -17,6 +18,7 @@
"Kron",
"Lamb",
"MADGRAD",
"MARS",
"Muon",
"QHAdam",
"Shampoo",