
Commit

figured out a good way to support lr scheduling and per-parameter settings
inikishev committed Jan 15, 2025
1 parent 48c071b commit bfbf3a2
Showing 39 changed files with 563 additions and 280 deletions.
75 changes: 53 additions & 22 deletions src/torchzero/core/module.py
@@ -1,6 +1,7 @@
from typing import Any, TypeAlias
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence, Iterable
import warnings

import torch

@@ -28,19 +29,19 @@ def closure(backward = True):
This closure will also work with all built-in pytorch optimizers including LBFGS, as well as most custom ones.
"""

class _WeakDict(dict): pass
# class _WeakDict(dict): pass

def _get_param_groups_to_pass_to_child(optimizer: torch.optim.Optimizer):
"""propagate only per-parameter settings that are not in optimizer.defaults"""
param_groups: list[dict[str, Any]] = []
for g in optimizer.param_groups.copy():
child_g = {"params": g["params"]}
for k,v in g.copy().items():
if k not in optimizer.defaults:
child_g[k] = v
param_groups.append(child_g)
# def _get_param_groups_to_pass_to_child(optimizer: torch.optim.Optimizer):
# """propagate only per-parameter settings that are not in optimizer.defaults"""
# param_groups: list[dict[str, Any]] = []
# for g in optimizer.param_groups.copy():
# child_g = {"params": g["params"]}
# for k,v in g.copy().items():
# if k not in optimizer.defaults:
# child_g[k] = v
# param_groups.append(child_g)

return param_groups
# return param_groups

class OptimizationState:
"""Holds optimization state. This is usually automatically created by :any:`torchzero.optim.Modular`."""
@@ -147,21 +148,41 @@ class OptimizerModule(TensorListOptimizer, ABC):
make_closure (bool, optional):
if True, :any:`_update` method functions as a closure,
otherwise it updates the ascent directly. Defaults to False.
Only has effect when overriding _update.
"""
IS_LR_MODULE = False
def __init__(self, defaults: dict[str, Any], make_closure = False): # pylint:disable = super-init-not-called
# there can only be 1 LR module, which is placed in the appropriate location among other modules.
# scheduling and per-parameter "lr" options will be routed to that module.
# otherwise, since many update rules like Adam have baked in lr, if multiple such modules are used,
# any lr modification gets applied multiple times.
# Some optimizers will automatically be fused if followed by an LR() module (only LR specifically is supported).
if not self.IS_LR_MODULE:
if 'lr' in defaults:
warnings.warn(
f'{self.__class__.__name__} got an "lr" default, but it is not an LR module.\
To support lr scheduling and per-parameter options, rename "lr" to "alpha" and set the default value to 1.\
If this is a learning rate module, set a class attribute `IS_LR_MODULE=True`.'
)

self._defaults = defaults
self.next_module: OptimizerModule | None = None
"""next module that takes this module's state and continues working on it."""
self.children: dict[Any, OptimizerModule] = {}
"""children modules."""
self._initialized = False
"""True if torch.optim.Optimizer.__init__ was called on this, meaning this optimizer has parameters."""
self._make_closure = make_closure
"""if True, :any:`_update` method functions as a closure, otherwise it updates the ascent directly"""

self._has_custom_params = False
"""Signifies that :any:`self.set_params` was called on this to set custom params.
When this is True and the parent calls :any:`_update_child_params_` with this module as the child,
nothing happens, as this module already has its parameters set."""

self._passed_params: list[torch.Tensor] | list[dict[str, Any]] | None = None
"""list of parameters or parameter groups that were passed to this module and will get passed to child modules."""

def __repr__(self):
if self._initialized: return super().__repr__()
return f"uninitialized {self.__class__.__name__}()"
@@ -170,20 +191,21 @@ def set_params(self, params: ParamsT):
"""
Set parameters to this module. Use this to set per-parameter group settings.
"""
self._initialize_(params)
self._initialize_(params, set_passed_params = False)
self._has_custom_params = True
return self

def _initialize_(self, params: ParamsT):
def _initialize_(self, params: ParamsT, set_passed_params: bool):
"""Initializes this optimizer and all children with the given parameters."""
if isinstance(params, torch.Tensor): raise ValueError("Params must be an iterable of tensors, not torch.Tensor")
params = list(params) # type:ignore
params_list = list(params)
if set_passed_params: self._passed_params = params_list.copy() # type:ignore

# super().__init__, which is torch.optim.Optimizer.__init__,
# calls self.add_param_group on each param group,
# which in turn calls _update_child_params_,
# which calls add_param_group on each child.
super().__init__(params, self._defaults)
super().__init__(params_list.copy(), self._defaults) # type:ignore
self._initialized = True

def _set_child_(self, name, child: "_Chainable"):
@@ -208,21 +230,27 @@ def _update_next_module_params_(self, next_module: "OptimizerModule"):
# Shouldn't forget that this method is overwritten by some modules
# So if I update it I need to keep that in mind

if self._passed_params is None:
raise RuntimeError(
f"{self.__class__.__name__} is not initialized, but _update_next_module_params_\
was called with next_module = {next_module.__class__.__name__}"
)

# if child is not initialized, torch.optim.Optimizer.__init__ is called on it by _initialize_ method
if not next_module._initialized:

# propagate only per-parameter settings that are not in self.defaults
next_module._initialize_(_get_param_groups_to_pass_to_child(self))
next_module._initialize_(self._passed_params, set_passed_params=True)

# otherwise, to avoid calling __init__ twice, we erase the param groups and re-add them
elif not next_module._has_custom_params:
next_module.param_groups = []
# it is important not to propagate all the settings
# for example if this module has `lr` setting, and the child has a different `lr` setting,
# we don't want to overwrite the child's `lr` setting
for group in _get_param_groups_to_pass_to_child(self):
for group in self._passed_params:
if isinstance(group, torch.Tensor): group = {"params": group}
next_module.add_param_group(group)

else:
# still pass per-parameter settings so that they propagate to further modules
next_module._passed_params = self._passed_params.copy()


def add_param_group(self, param_group: dict[str, Any]) -> None:
super().add_param_group(param_group)
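A toy sketch of the propagation change in the hunk above: the parent now remembers exactly what it was given (`_passed_params`) and hands that same list to the next module, instead of reconstructing per-parameter settings by diffing groups against `optimizer.defaults`. The class below is illustrative only, not the real `OptimizerModule`:

```python
# Toy model of the new _passed_params propagation (illustrative names only).
from typing import Any

class ToyModule:
    def __init__(self, defaults: dict[str, Any]):
        self.defaults = defaults
        self._passed_params: list | None = None

    def _initialize_(self, params, set_passed_params: bool):
        params = list(params)
        if set_passed_params:
            self._passed_params = params.copy()

    def _update_next_module_params_(self, next_module: "ToyModule"):
        if self._passed_params is None:
            raise RuntimeError("parent is not initialized")
        # pass the original groups through untouched, so per-parameter
        # settings like "lr" survive all the way down to the LR module
        next_module._initialize_(self._passed_params, set_passed_params=True)

parent, child = ToyModule({"alpha": 1}), ToyModule({"nplus": 1.2})
parent._initialize_([{"params": ["w"], "lr": 0.1}], set_passed_params=True)
parent._update_next_module_params_(child)
print(child._passed_params)  # [{'params': ['w'], 'lr': 0.1}]
```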
@@ -377,6 +405,9 @@ class _Chain(OptimizerModule):
def __init__(self, *modules: _Chainable):
super().__init__({})
flat_modules: list[OptimizerModule] = flatten(modules)
if any(not hasattr(i, "step") for i in flat_modules):
raise TypeError(f"One of the modules is not an OptimizerModule, got {[i.__class__.__name__ for i in flat_modules]}")

self._ascent_returner = _MaybeReturnAscent()
flat_modules.append(self._ascent_returner)

6 changes: 2 additions & 4 deletions src/torchzero/core/tensorlist_optimizer.py
@@ -64,13 +64,12 @@ def get_state_key[CLS: MutableSequence](self, key: str, init: _StateInit = torch

def get_state_keys[CLS: MutableSequence](
self,
keys: Sequence[str],
*keys: str,
inits: _StateInit | Sequence[_StateInit] = torch.zeros_like,
params=None,
cls: type[CLS] = TensorList,
) -> list[CLS]:
"""Returns a TensorList with the `key` states of all `params`. Creates the states if they don't exist."""
if isinstance(keys, str): raise TypeError('keys must be a sequence of strings')

values = [cls() for _ in range(len(keys))]
if params is None: params = self.get_params()
@@ -116,9 +115,8 @@ def get_all_group_keys[CLS: Any](self, cls: type[CLS] = NumberList) -> dict[str,

return all_values

def get_group_keys[CLS: MutableSequence](self, keys: Sequence[str], cls: type[CLS] = NumberList) -> list[CLS]:
def get_group_keys[CLS: MutableSequence](self, *keys: str, cls: type[CLS] = NumberList) -> list[CLS]:
"""Returns a TensorList with the param_groups `key` setting of each param."""
if isinstance(keys, str): raise TypeError('keys must be a sequence of strings')

all_values: list[CLS] = [cls() for _ in keys]
for group in self.param_groups:
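The signature change above (a `keys: Sequence[str]` argument becomes variadic `*keys: str`) means call sites pass key names directly instead of wrapping them in a list. A hedged before/after using placeholder functions rather than the real `TensorListOptimizer` methods:

```python
# Illustration of the variadic-keys signature change (placeholder functions only).
def get_group_keys_old(keys, cls=list):
    return [cls() for _ in keys]

def get_group_keys_new(*keys, cls=list):
    return [cls() for _ in keys]

# before: nplus, nminus = self.get_group_keys(['nplus', 'nminus'])
a, b = get_group_keys_old(['nplus', 'nminus'])
# after:  nplus, nminus = self.get_group_keys('nplus', 'nminus')
a, b = get_group_keys_new('nplus', 'nminus')
```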
12 changes: 6 additions & 6 deletions src/torchzero/modules/adaptive/adaptive.py
@@ -107,18 +107,18 @@ class ScaleLRBySignChange(OptimizerModule):
This is part of the RProp update rule.
Args:
lr (float): initial learning rate.
nplus (float): learning rate gets multiplied by `nplus` if ascent/gradient didn't change the sign
nminus (float): learning rate gets multiplied by `nminus` if ascent/gradient changed the sign
lb (float): lower bound for lr.
ub (float): upper bound for lr.
alpha (float): initial learning rate.
.. warning::
If `use_grad` is True and you use this after modules that estimate gradients, e.g. FDM,
they need to have `make_closure` set to True so that they write to `grad` attribute.
"""
def __init__(self, lr: float = 1, nplus: float = 1.2, nminus: float = 0.5, lb = 1e-6, ub = 50, use_grad=False):
defaults = dict(nplus = nplus, nminus = nminus, lr = lr, lb = lb, ub = ub)
def __init__(self, nplus: float = 1.2, nminus: float = 0.5, lb = 1e-6, ub = 50, alpha=1, use_grad=False):
defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub)
super().__init__(defaults)
self.current_step = 0
self.use_grad = use_grad
@@ -130,12 +130,12 @@ def _update(self, state, ascent):
if self.use_grad: cur = state.maybe_compute_grad_(params)
else: cur = ascent

nplus, nminus, lb, ub = self.get_group_keys(['nplus', 'nminus', 'lb', 'ub'])
prev, lrs = self.get_state_keys(['prev_ascent', 'lrs'], params=params)
nplus, nminus, lb, ub = self.get_group_keys('nplus', 'nminus', 'lb', 'ub')
prev, lrs = self.get_state_keys('prev_ascent', 'lrs', params=params)

# initialize on 1st step
if self.current_step == 0:
lrs.fill_(self.defaults['lr'])
lrs.fill_(self.get_group_key('alpha'))
ascent.mul_(lrs)
prev.copy_(ascent)
self.current_step += 1
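For context on what the module above tunes: per its docstring, each element's step size grows by `nplus` while the update keeps its sign and shrinks by `nminus` when the sign flips, clamped to `[lb, ub]`. A standalone sketch of that rule (it mirrors the docstring, not the exact torchzero implementation, and the TensorList plumbing is omitted):

```python
import torch

# Sign-change step-size rule as described in the docstring (illustrative only).
def scale_lrs_by_sign_change(lrs, cur, prev, nplus=1.2, nminus=0.5, lb=1e-6, ub=50.0):
    prod = cur * prev
    lrs = torch.where(prod > 0, lrs * nplus, lrs)   # sign kept -> grow step
    lrs = torch.where(prod < 0, lrs * nminus, lrs)  # sign flipped -> shrink step
    return lrs.clamp_(lb, ub)

lrs = torch.full((4,), 0.1)
prev = torch.tensor([1., -1., 1., -1.])
cur = torch.tensor([1., 1., -1., -1.])
print(scale_lrs_by_sign_change(lrs, cur, prev))
# tensor([0.1200, 0.0500, 0.0500, 0.1200])
```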
44 changes: 28 additions & 16 deletions src/torchzero/modules/experimental/experimental.py
@@ -18,16 +18,16 @@ class MinibatchRprop(OptimizerModule):
"""
def __init__(
self,
lr: float = 1,
nplus: float = 1.2,
nminus: float = 0.5,
lb: float | None = 1e-6,
ub: float | None = 50,
backtrack=True,
next_mode = 'continue',
increase_mul = 0.5,
alpha: float = 1,
):
defaults = dict(nplus = nplus, nminus = nminus, lr = lr, lb = lb, ub = ub, increase_mul=increase_mul)
defaults = dict(nplus = nplus, nminus = nminus, alpha = alpha, lb = lb, ub = ub, increase_mul=increase_mul)
super().__init__(defaults)
self.current_step = 0
self.backtrack = backtrack
@@ -40,17 +40,17 @@ def step(self, state):
if state.ascent is not None: raise ValueError("Minibatch Rprop must be the first module.")
params = self.get_params()

nplus, nminus, lb, ub = self.get_group_keys(['nplus', 'nminus', 'lb', 'ub'])
nplus, nminus, lb, ub = self.get_group_keys('nplus', 'nminus', 'lb', 'ub')
allowed, magnitudes = self.get_state_keys(
['allowed', 'magnitudes'],
'allowed', 'magnitudes',
inits = [_bool_ones_like, torch.zeros_like],
params=params
)

g1_sign = state.maybe_compute_grad_(params).sign() # no inplace to not modify grads
# initialize on 1st iteration
if self.current_step == 0:
magnitudes.fill_(self.defaults['lr']).clamp_(lb, ub)
magnitudes.fill_(self.get_group_key('alpha')).clamp_(lb, ub)
# ascent = magnitudes * g1_sign
# self.current_step += 1
# return ascent
@@ -135,7 +135,7 @@ class GradMin(OptimizerModule):
explanation: calculate grads wrt sum of grads + loss.
"""
def __init__(self, loss_term: float = 1, square=False, maximize_grad = False):
super().__init__(dict(add_loss=loss_term))
super().__init__(dict(loss_term=loss_term))
self.square = square
self.maximize_grad = maximize_grad

@@ -146,7 +146,7 @@ def step(self, state):
raise ValueError("GradMin doesn't accept ascent_direction")

params = self.get_params()
add_loss = self.get_group_key('add_loss')
loss_term = self.get_group_key('loss_term')

self.zero_grad()
with torch.enable_grad():
@@ -158,8 +158,8 @@ else:
else:
grads = grads.abs()

if self.maximize_grad: grads: TensorList = grads - (state.fx0 * add_loss) # type:ignore
else: grads = grads + (state.fx0 * add_loss)
if self.maximize_grad: grads: TensorList = grads - (state.fx0 * loss_term) # type:ignore
else: grads = grads + (state.fx0 * loss_term)
grad_mean = torch.sum(torch.stack(grads.sum())) / grads.total_numel()
grad_mean.backward(retain_graph=False)

@@ -203,7 +203,7 @@ def step(self, state):


def _reset_stats_hook(optimizer, state):
for module in optimizer.modules:
for module in optimizer.unrolled_modules:
module: OptimizerModule
module.reset_stats()

@@ -215,15 +215,26 @@ class CyclicSWA(OptimizerModule):
and next cycle starts.
It is easier to tune than PeriodicSWA and seems to work better too.
"""
def __init__(self, cswa_start: int, cycle_length: int, steps_between: int, init_lr: float = 0, peak_lr: float = 1):
super().__init__({})
Args:
cswa_start (int): number of steps before starting the first CSWA cycle.
cycle_length (int): length of each cycle in steps.
steps_between (int): number of steps between cycles.
init_lr (float, optional): initial and final learning rate in each cycle. Defaults to 0.
peak_lr (float, optional): peak learning rate of each cycle. Defaults to 1.
reset_stats (bool, optional):
if True, resets other modules' stats, such as momentum velocities, when setting model parameters to the SWA weights (default: True).
"""
def __init__(self, cswa_start: int, cycle_length: int, steps_between: int, init_lr: float = 0, peak_lr: float = 1, reset_stats: bool=True):
defaults = dict(init_lr = init_lr, peak_lr = peak_lr)
super().__init__(defaults)
self.cswa_start = cswa_start
self.cycle_length = cycle_length
self.init_lr = init_lr
self.peak_lr = peak_lr
self.steps_between = steps_between
self._reset_stats = reset_stats

self.cur = 0
self.cycle_cur = 0
@@ -241,12 +252,13 @@ def step(self, state):

# determine the lr
point = self.cycle_cur / self.cycle_length
init_lr, peak_lr = self.get_group_keys('init_lr', 'peak_lr')
if point < 0.5:
p2 = point*2
lr = self.init_lr * (1-p2) + self.peak_lr * p2
lr = init_lr * (1-p2) + peak_lr * p2
else:
p2 = (1 - point)*2
lr = self.init_lr * (1-p2) + self.peak_lr * p2
lr = init_lr * (1-p2) + peak_lr * p2

ascent *= lr
ret = self._update_params_or_step_with_next(state, params)
@@ -262,7 +274,7 @@ def step(self, state):
self.cycle_cur = -1

params.set_(swa)
state.add_post_step_hook(_reset_stats_hook)
if self._reset_stats: state.add_post_step_hook(_reset_stats_hook)

self.cycle_cur += 1

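CyclicSWA's per-step learning rate, diffed above, is a triangle from `init_lr` up to `peak_lr` and back over one cycle, now read from group settings (`get_group_keys`) rather than fixed attributes so it can be scheduled. A standalone sketch of just that schedule, mirroring the two branches in `step`:

```python
# Triangular cycle lr as computed in CyclicSWA.step (standalone, illustrative).
def cswa_lr(cycle_cur: int, cycle_length: int, init_lr: float, peak_lr: float) -> float:
    point = cycle_cur / cycle_length
    if point < 0.5:           # first half: ramp up
        p2 = point * 2
        return init_lr * (1 - p2) + peak_lr * p2
    p2 = (1 - point) * 2      # second half: ramp back down
    return init_lr * (1 - p2) + peak_lr * p2

print([round(cswa_lr(i, 10, 0.0, 1.0), 2) for i in range(11)])
# [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.8, 0.6, 0.4, 0.2, 0.0]
```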
3 changes: 2 additions & 1 deletion src/torchzero/modules/experimental/subspace.py
@@ -190,8 +190,9 @@ def _update_child_params_(self, child: "OptimizerModule"):
dtype = self._params[0].dtype
device = self._params[0].device
params = [torch.zeros(sum(proj.n for proj in self.projections), dtype = dtype, device = device, requires_grad=True)]
if child._has_custom_params: raise RuntimeError(f"Subspace child {child.__class__.__name__} can't have custom params.")
if not child._initialized:
child._initialize_(params)
child._initialize_(params, set_passed_params=False)
else:
child.param_groups = []
child.add_param_group({"params": params})
2 changes: 1 addition & 1 deletion src/torchzero/modules/line_search/__init__.py
@@ -8,7 +8,7 @@
from ..regularization import Normalize
from .grid_ls import (ArangeLS, BacktrackingLS, GridLS, LinspaceLS,
MultiplicativeLS)
from .quad_interp import QuadraticInterpolation2Point
# from .quad_interp import QuadraticInterpolation2Point
from .directional_newton import DirectionalNewton3Points, DirectionalNewton
from .scipy_minimize_scalar import ScipyMinimizeScalarLS
from .armijo import ArmijoLS