standardize SGD optimizers

chaoming0625 committed Mar 10, 2024
1 parent bde2c1e commit b1c0997
Showing 7 changed files with 36 additions and 36 deletions.
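Note on scope: this commit moves the SGD optimizer and scheduler code to the private modules sgd_optimizer.py and sgd_scheduler.py and re-points the public re-exports in brainpy/optim.py, so only internal import paths change. A minimal sketch of the effect (the Adam learning-rate value below is illustrative, not taken from the diff):

    # Private paths after this commit (internal API; shown in the diffs below):
    from brainpy._src.optimizers.sgd_optimizer import SGD, Adam
    from brainpy._src.optimizers.sgd_scheduler import make_schedule, Scheduler

    # Public entry points re-exported in brainpy.optim keep working unchanged:
    import brainpy as bp
    opt = bp.optim.Adam(lr=1e-3)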
4 changes: 2 additions & 2 deletions brainpy/_src/optimizers/__init__.py
@@ -1,4 +1,4 @@
# -*- coding: utf-8 -*-

from .optimizer import *
from .scheduler import *
from .sgd_optimizer import *
from .sgd_scheduler import *
Empty file.
brainpy/_src/optimizers/sgd_optimizer.py
@@ -5,13 +5,11 @@

import jax.numpy as jnp
from jax.lax import cond
import brainpy as bp

import brainpy.math as bm
from brainpy import check
from brainpy._src.math.object_transform.base import BrainPyObject, ArrayCollector
from brainpy.errors import MathError
from .scheduler import make_schedule, Scheduler
from .sgd_scheduler import make_schedule, Scheduler

__all__ = [
'Optimizer',
@@ -28,7 +26,7 @@
]


class Optimizer(BrainPyObject):
class Optimizer(bm.BrainPyObject):
"""Base Optimizer Class.
Parameters
@@ -40,7 +38,7 @@ class Optimizer(BrainPyObject):
lr: Scheduler # learning rate
'''Learning rate'''

vars_to_train: ArrayCollector # variables to train
vars_to_train: bm.VarDict # variables to train
'''Variables to train.'''

def __init__(
@@ -49,9 +47,9 @@ def __init__(
train_vars: Union[Sequence[bm.Variable], Dict[str, bm.Variable]] = None,
name: Optional[str] = None
):
super(Optimizer, self).__init__(name=name)
super().__init__(name=name)
self.lr: Scheduler = make_schedule(lr)
self.vars_to_train = ArrayCollector()
self.vars_to_train = bm.var_dict()
self.register_train_vars(train_vars)

def register_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None):
@@ -72,23 +70,27 @@ def __repr__(self):
def update(self, grads: dict):
raise NotImplementedError

def zero_grad(self):
"""
Zero the gradients of all trainable variables.
"""
for p in self.vars_to_train.values():
p.value = jnp.zeros_like(p.value)

class CommonOpt(Optimizer):

class _CommonOpt(Optimizer):
def __init__(
self,
lr: Union[float, Scheduler, bm.Variable],
train_vars: Union[Sequence[bm.Variable], Dict[str, bm.Variable]] = None,
weight_decay: Optional[float] = None,
name: Optional[str] = None
):
super(Optimizer, self).__init__(name=name)
self.lr: Scheduler = make_schedule(lr)
self.vars_to_train = ArrayCollector()
self.register_train_vars(train_vars)
super().__init__(name=name, lr=lr, train_vars=train_vars)
self.weight_decay = check.is_float(weight_decay, min_bound=0., max_bound=1., allow_none=True)


class SGD(CommonOpt):
class SGD(_CommonOpt):
r"""Stochastic gradient descent optimizer.
SGD performs a parameter update for training examples :math:`x` and label
@@ -138,7 +140,7 @@ def update(self, grads: dict):
self.lr.step_call()


class Momentum(CommonOpt):
class Momentum(_CommonOpt):
r"""Momentum optimizer.
Momentum [1]_ is a method that helps accelerate SGD in the relevant direction
@@ -209,7 +211,7 @@ def update(self, grads: dict):
self.lr.step_call()


class MomentumNesterov(CommonOpt):
class MomentumNesterov(_CommonOpt):
r"""Nesterov accelerated gradient optimizer [2]_.
.. math::
@@ -273,7 +275,7 @@ def update(self, grads: dict):
self.lr.step_call()


class Adagrad(CommonOpt):
class Adagrad(_CommonOpt):
r"""Optimizer that implements the Adagrad algorithm.
Adagrad [3]_ is an optimizer with parameter-specific learning rates, which are
@@ -345,7 +347,7 @@ def __repr__(self):
return f"{self.__class__.__name__}(lr={self.lr}, epsilon={self.epsilon})"


class Adadelta(CommonOpt):
class Adadelta(_CommonOpt):
r"""Optimizer that implements the Adadelta algorithm.
Adadelta [4]_ optimization is a stochastic gradient descent method that is based
@@ -437,7 +439,7 @@ def __repr__(self):
f"epsilon={self.epsilon}, rho={self.rho})")


class RMSProp(CommonOpt):
class RMSProp(_CommonOpt):
r"""Optimizer that implements the RMSprop algorithm.
RMSprop [5]_ and Adadelta have both been developed independently around the same time
@@ -513,7 +515,7 @@ def __repr__(self):
f"epsilon={self.epsilon}, rho={self.rho})")


class Adam(CommonOpt):
class Adam(_CommonOpt):
"""Optimizer that implements the Adam algorithm.
Adam [6]_ - a stochastic gradient descent method (SGD) that computes
@@ -598,7 +600,7 @@ def update(self, grads: dict):
self.lr.step_call()


class LARS(CommonOpt):
class LARS(_CommonOpt):
r"""Layer-wise adaptive rate scaling (LARS) optimizer [1]_.
Layer-wise Adaptive Rate Scaling, or LARS, is a large batch
@@ -678,7 +680,7 @@ def update(self, grads: dict):
self.lr.step_call()


class Adan(CommonOpt):
class Adan(_CommonOpt):
r"""Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models [1]_.
.. math::
@@ -817,7 +819,7 @@ def update(self, grads: dict):
self.lr.step_call()


class AdamW(CommonOpt):
class AdamW(_CommonOpt):
r"""Adam with weight decay regularization [1]_.
AdamW uses weight decay to regularize learning towards small weights, as
@@ -977,7 +979,7 @@ def update(self, grads: dict):
self.lr.step_call()


class SM3(CommonOpt):
class SM3(_CommonOpt):
"""SM3 algorithm [1]_.
The 'Square-root of Minima of Sums of Maxima of Squared-gradients Method'
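For orientation, the refactor above keeps the optimizer call pattern itself unchanged: construct with a learning rate and the trainable variables, then feed a gradient dict to update(). Note that the new zero_grad() helper shown above resets the registered training variables themselves to zeros. A hedged usage sketch (the variable shape, the lr value, and the momentum keyword are illustrative assumptions, not taken from this diff):

    import brainpy as bp
    import brainpy.math as bm

    w = bm.Variable(bm.random.randn(3))
    opt = bp.optim.Momentum(lr=0.01, train_vars={'w': w}, momentum=0.9)

    f_grad = bm.grad(lambda: bm.sum(w ** 2), grad_vars={'w': w})
    grads = f_grad()    # gradient dict keyed like train_vars
    opt.update(grads)   # one step; also advances the lr scheduler via lr.step_call()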
brainpy/_src/optimizers/sgd_scheduler.py
@@ -1,14 +1,12 @@
# -*- coding: utf-8 -*-
import warnings
from functools import partial
from typing import Sequence, Union

import jax
import jax.numpy as jnp

import brainpy.math as bm
from brainpy import check
from brainpy._src.math.object_transform.base import BrainPyObject
from brainpy.errors import MathError


@@ -25,7 +23,7 @@ def make_schedule(scalar_or_schedule):
raise TypeError(type(scalar_or_schedule))


class Scheduler(BrainPyObject):
class Scheduler(bm.BrainPyObject):
"""The learning rate scheduler."""

def __init__(self, lr: Union[float, bm.Variable], last_epoch: int = -1):
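The make_schedule helper re-exported from the renamed module keeps its behavior: a plain scalar is wrapped into a constant Scheduler, an existing Scheduler instance is accepted as-is, and anything else raises TypeError, as in the hunk above. A small sketch under that assumption:

    from brainpy._src.optimizers.sgd_scheduler import make_schedule, Scheduler

    sched = make_schedule(0.1)           # scalar -> constant Scheduler
    assert isinstance(sched, Scheduler)

    sched2 = make_schedule(sched)        # Scheduler instances pass through
    assert isinstance(sched2, Scheduler)

    # make_schedule("0.1")               # non-numeric, non-Scheduler input raises TypeError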
4 changes: 2 additions & 2 deletions brainpy/_src/optimizers/skopt_bayesian.py
@@ -4,10 +4,10 @@

from .base import Optimizer

__all__ = ['SkBayesianOptimizer']
__all__ = ['SkoptBayesOptimizer']


class SkBayesianOptimizer(Optimizer):
class SkoptBayesOptimizer(Optimizer):
"""
SkoptOptimizer instance creates all the tools necessary for the user
to use it with scikit-optimize library.
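Only the class name changes in this file (it still requires the scikit-optimize library at use time); downstream imports need the new spelling:

    # old: from brainpy._src.optimizers.skopt_bayesian import SkBayesianOptimizer
    from brainpy._src.optimizers.skopt_bayesian import SkoptBayesOptimizer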
2 changes: 1 addition & 1 deletion brainpy/_src/optimizers/tests/test_scheduler.py
@@ -7,7 +7,7 @@
from absl.testing import parameterized

import brainpy.math as bm
from brainpy._src.optimizers import scheduler
from brainpy._src.optimizers import sgd_scheduler

show = False

10 changes: 5 additions & 5 deletions brainpy/optim.py
@@ -5,10 +5,10 @@
# ---------- #


from brainpy._src.optimizers.optimizer import (
from brainpy._src.optimizers.sgd_optimizer import (
Optimizer as Optimizer,
)
from brainpy._src.optimizers.optimizer import (
from brainpy._src.optimizers.sgd_optimizer import (
SGD as SGD,
Momentum as Momentum,
MomentumNesterov as MomentumNesterov,
@@ -26,11 +26,11 @@
# ---------- #


from brainpy._src.optimizers.scheduler import (
from brainpy._src.optimizers.sgd_scheduler import (
make_schedule as make_schedule,
Scheduler as Scheduler,
)
from brainpy._src.optimizers.scheduler import (
from brainpy._src.optimizers.sgd_scheduler import (
Constant as Constant,
ExponentialDecay as ExponentialDecay,
InverseTimeDecay as InverseTimeDecay,
@@ -41,7 +41,7 @@
InverseTimeDecayLR as InverseTimeDecayLR,
ExponentialDecayLR as ExponentialDecayLR
)
from brainpy._src.optimizers.scheduler import (
from brainpy._src.optimizers.sgd_scheduler import (
StepLR as StepLR,
MultiStepLR as MultiStepLR,
ExponentialLR as ExponentialLR,
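Since brainpy/optim.py re-points its imports at the renamed modules while keeping every public alias, user-facing code keeps working, for example (values are illustrative):

    import brainpy as bp

    lr = bp.optim.Constant(0.05)          # Constant is still re-exported (see the hunk above)
    opt = bp.optim.SGD(lr=lr)             # train_vars may be registered later via register_train_vars
    sched = bp.optim.make_schedule(0.05)  # make_schedule is still re-exported as well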
