diff --git a/brainpy/_src/optimizers/__init__.py b/brainpy/_src/optimizers/__init__.py
index ed3b22c6b..21cbe7c6e 100644
--- a/brainpy/_src/optimizers/__init__.py
+++ b/brainpy/_src/optimizers/__init__.py
@@ -1,4 +1,4 @@
 # -*- coding: utf-8 -*-
 
-from .optimizer import *
-from .scheduler import *
+from .sgd_optimizer import *
+from .sgd_scheduler import *
diff --git a/brainpy/_src/optimizers/nevergrad_optimizer.py b/brainpy/_src/optimizers/nevergrad_optimizer.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/brainpy/_src/optimizers/optimizer.py b/brainpy/_src/optimizers/sgd_optimizer.py
similarity index 98%
rename from brainpy/_src/optimizers/optimizer.py
rename to brainpy/_src/optimizers/sgd_optimizer.py
index c2aec25a0..536b97195 100644
--- a/brainpy/_src/optimizers/optimizer.py
+++ b/brainpy/_src/optimizers/sgd_optimizer.py
@@ -5,13 +5,11 @@
 import jax.numpy as jnp
 from jax.lax import cond
 
-import brainpy as bp
 import brainpy.math as bm
 from brainpy import check
-from brainpy._src.math.object_transform.base import BrainPyObject, ArrayCollector
 from brainpy.errors import MathError
 
-from .scheduler import make_schedule, Scheduler
+from .sgd_scheduler import make_schedule, Scheduler
 
 __all__ = [
   'Optimizer',
@@ -28,7 +26,7 @@
 ]
 
 
-class Optimizer(BrainPyObject):
+class Optimizer(bm.BrainPyObject):
   """Base Optimizer Class.
 
   Parameters
@@ -40,7 +38,7 @@ class Optimizer(BrainPyObject):
   lr: Scheduler  # learning rate
   '''Learning rate'''
 
-  vars_to_train: ArrayCollector  # variables to train
+  vars_to_train: bm.VarDict  # variables to train
   '''Variables to train.'''
 
   def __init__(
@@ -49,9 +47,9 @@ def __init__(
       train_vars: Union[Sequence[bm.Variable], Dict[str, bm.Variable]] = None,
      name: Optional[str] = None
   ):
-    super(Optimizer, self).__init__(name=name)
+    super().__init__(name=name)
     self.lr: Scheduler = make_schedule(lr)
-    self.vars_to_train = ArrayCollector()
+    self.vars_to_train = bm.var_dict()
     self.register_train_vars(train_vars)
 
   def register_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None):
@@ -72,8 +70,15 @@ def __repr__(self):
   def update(self, grads: dict):
     raise NotImplementedError
 
+  def zero_grad(self):
+    """
+    Zero the gradients of all trainable variables.
+    """
+    for p in self.vars_to_train.values():
+      p.value = jnp.zeros_like(p.value)
 
-class CommonOpt(Optimizer):
+
+class _CommonOpt(Optimizer):
   def __init__(
       self,
       lr: Union[float, Scheduler, bm.Variable],
@@ -81,14 +86,11 @@ def __init__(
       weight_decay: Optional[float] = None,
       name: Optional[str] = None
   ):
-    super(Optimizer, self).__init__(name=name)
-    self.lr: Scheduler = make_schedule(lr)
-    self.vars_to_train = ArrayCollector()
-    self.register_train_vars(train_vars)
+    super().__init__(name=name, lr=lr, train_vars=train_vars)
     self.weight_decay = check.is_float(weight_decay, min_bound=0., max_bound=1., allow_none=True)
 
 
-class SGD(CommonOpt):
+class SGD(_CommonOpt):
   r"""Stochastic gradient descent optimizer.
 
   SGD performs a parameter update for training examples :math:`x` and label
@@ -138,7 +140,7 @@ def update(self, grads: dict):
     self.lr.step_call()
 
 
-class Momentum(CommonOpt):
+class Momentum(_CommonOpt):
   r"""Momentum optimizer.
 
   Momentum [1]_ is a method that helps accelerate SGD in the relevant direction
@@ -209,7 +211,7 @@ def update(self, grads: dict):
     self.lr.step_call()
 
 
-class MomentumNesterov(CommonOpt):
+class MomentumNesterov(_CommonOpt):
   r"""Nesterov accelerated gradient optimizer [2]_.
 
   .. math::
@@ -273,7 +275,7 @@ def update(self, grads: dict):
     self.lr.step_call()
 
 
-class Adagrad(CommonOpt):
+class Adagrad(_CommonOpt):
   r"""Optimizer that implements the Adagrad algorithm.
 
   Adagrad [3]_ is an optimizer with parameter-specific learning rates, which are
@@ -345,7 +347,7 @@ def __repr__(self):
     return f"{self.__class__.__name__}(lr={self.lr}, epsilon={self.epsilon})"
 
 
-class Adadelta(CommonOpt):
+class Adadelta(_CommonOpt):
   r"""Optimizer that implements the Adadelta algorithm.
 
   Adadelta [4]_ optimization is a stochastic gradient descent method that is based
@@ -437,7 +439,7 @@ def __repr__(self):
             f"epsilon={self.epsilon}, rho={self.rho})")
 
 
-class RMSProp(CommonOpt):
+class RMSProp(_CommonOpt):
   r"""Optimizer that implements the RMSprop algorithm.
 
   RMSprop [5]_ and Adadelta have both been developed independently around the same time
@@ -513,7 +515,7 @@ def __repr__(self):
             f"epsilon={self.epsilon}, rho={self.rho})")
 
 
-class Adam(CommonOpt):
+class Adam(_CommonOpt):
   """Optimizer that implements the Adam algorithm.
 
   Adam [6]_ - a stochastic gradient descent method (SGD) that computes
@@ -598,7 +600,7 @@ def update(self, grads: dict):
     self.lr.step_call()
 
 
-class LARS(CommonOpt):
+class LARS(_CommonOpt):
   r"""Layer-wise adaptive rate scaling (LARS) optimizer [1]_.
 
   Layer-wise Adaptive Rate Scaling, or LARS, is a large batch
@@ -678,7 +680,7 @@ def update(self, grads: dict):
     self.lr.step_call()
 
 
-class Adan(CommonOpt):
+class Adan(_CommonOpt):
   r"""Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models [1]_.
 
   .. math::
@@ -817,7 +819,7 @@ def update(self, grads: dict):
     self.lr.step_call()
 
 
-class AdamW(CommonOpt):
+class AdamW(_CommonOpt):
   r"""Adam with weight decay regularization [1]_.
 
   AdamW uses weight decay to regularize learning towards small weights, as
@@ -977,7 +979,7 @@ def update(self, grads: dict):
     self.lr.step_call()
 
 
-class SM3(CommonOpt):
+class SM3(_CommonOpt):
   """SM3 algorithm [1]_.
 
   The 'Square-root of Minima of Sums of Maxima of Squared-gradients Method'
diff --git a/brainpy/_src/optimizers/scheduler.py b/brainpy/_src/optimizers/sgd_scheduler.py
similarity index 99%
rename from brainpy/_src/optimizers/scheduler.py
rename to brainpy/_src/optimizers/sgd_scheduler.py
index b27398dae..d2f47da8e 100644
--- a/brainpy/_src/optimizers/scheduler.py
+++ b/brainpy/_src/optimizers/sgd_scheduler.py
@@ -1,6 +1,5 @@
 # -*- coding: utf-8 -*-
 import warnings
-from functools import partial
 from typing import Sequence, Union
 
 import jax
@@ -8,7 +7,6 @@
 
 import brainpy.math as bm
 from brainpy import check
-from brainpy._src.math.object_transform.base import BrainPyObject
 from brainpy.errors import MathError
 
 
@@ -25,7 +23,7 @@ def make_schedule(scalar_or_schedule):
     raise TypeError(type(scalar_or_schedule))
 
 
-class Scheduler(BrainPyObject):
+class Scheduler(bm.BrainPyObject):
   """The learning rate scheduler."""
 
   def __init__(self, lr: Union[float, bm.Variable], last_epoch: int = -1):
diff --git a/brainpy/_src/optimizers/skopt_bayesian.py b/brainpy/_src/optimizers/skopt_bayesian.py
index a6f32f0c6..4482d0624 100644
--- a/brainpy/_src/optimizers/skopt_bayesian.py
+++ b/brainpy/_src/optimizers/skopt_bayesian.py
@@ -4,10 +4,10 @@
 
 from .base import Optimizer
 
-__all__ = ['SkBayesianOptimizer']
+__all__ = ['SkoptBayesOptimizer']
 
 
-class SkBayesianOptimizer(Optimizer):
+class SkoptBayesOptimizer(Optimizer):
   """
   SkoptOptimizer instance creates all the tools necessary for the user to use it
   with scikit-optimize library.
diff --git a/brainpy/_src/optimizers/tests/test_scheduler.py b/brainpy/_src/optimizers/tests/test_scheduler.py
index dbdda0eda..8c53f33dd 100644
--- a/brainpy/_src/optimizers/tests/test_scheduler.py
+++ b/brainpy/_src/optimizers/tests/test_scheduler.py
@@ -7,7 +7,7 @@
 from absl.testing import parameterized
 
 import brainpy.math as bm
-from brainpy._src.optimizers import scheduler
+from brainpy._src.optimizers import sgd_scheduler
 
 show = False
 
diff --git a/brainpy/optim.py b/brainpy/optim.py
index de66e3700..66419ddd4 100644
--- a/brainpy/optim.py
+++ b/brainpy/optim.py
@@ -5,10 +5,10 @@
 # ----------
 #
 
-from brainpy._src.optimizers.optimizer import (
+from brainpy._src.optimizers.sgd_optimizer import (
   Optimizer as Optimizer,
 )
-from brainpy._src.optimizers.optimizer import (
+from brainpy._src.optimizers.sgd_optimizer import (
   SGD as SGD,
   Momentum as Momentum,
   MomentumNesterov as MomentumNesterov,
@@ -26,11 +26,11 @@
 # ----------
 #
 
-from brainpy._src.optimizers.scheduler import (
+from brainpy._src.optimizers.sgd_scheduler import (
   make_schedule as make_schedule,
   Scheduler as Scheduler,
 )
-from brainpy._src.optimizers.scheduler import (
+from brainpy._src.optimizers.sgd_scheduler import (
   Constant as Constant,
   ExponentialDecay as ExponentialDecay,
   InverseTimeDecay as InverseTimeDecay,
@@ -41,7 +41,7 @@
   InverseTimeDecayLR as InverseTimeDecayLR,
   ExponentialDecayLR as ExponentialDecayLR
 )
-from brainpy._src.optimizers.scheduler import (
+from brainpy._src.optimizers.sgd_scheduler import (
   StepLR as StepLR,
   MultiStepLR as MultiStepLR,
   ExponentialLR as ExponentialLR,
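
Usage sketch (not part of the patch). The public `brainpy.optim` names are unchanged by this refactor, so user code keeps importing from `brainpy.optim` while the implementation moves to `sgd_optimizer.py` / `sgd_scheduler.py`. The snippet below shows an SGD step plus the new `Optimizer.zero_grad()` helper, which, as implemented above, resets every registered training variable to zeros. The parameter shape, learning rate, and hand-built gradient dict are illustrative assumptions, not values taken from the patch.

import jax.numpy as jnp
import brainpy.math as bm
import brainpy.optim as optim

# One trainable parameter, registered under the key 'w'.
w = bm.Variable(jnp.ones(3))
opt = optim.SGD(lr=0.1, train_vars={'w': w})

# A hand-built gradient dict keyed like `train_vars`; in real training this
# would come from an autodiff transform over a loss function.
grads = {'w': jnp.full(3, 0.5)}
opt.update(grads)   # w <- w - lr * grads['w']
print(w.value)      # approx. [0.95 0.95 0.95]

# New in this patch: reset every registered training variable to zeros.
opt.zero_grad()
print(w.value)      # [0. 0. 0.]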