Merge pull request #970 from scap3yvt/968-feature-add-new-optimizers
Added new optimizers
sarthakpati authored Nov 21, 2024
2 parents d2f8c2f + 882fd51 commit 966f84f
Showing 10 changed files with 600 additions and 27 deletions.
10 changes: 8 additions & 2 deletions GANDLF/optimizers/README.md
@@ -3,10 +3,16 @@
 ## Adding a new algorithm
 
 - For an optimizer defined in PyTorch [[ref](https://pytorch.org/docs/stable/optim.html#algorithms)], update the `GANDLF.optimizers.wrap_torch.py` submodule.
-- For a custom optimizer, create a new submodule called `GANDLF.optimizers.${awesome_optimizer}.py`. Ensure that it inherits from PyTorch's base optimizer class [[ref](https://pytorch.org/docs/stable/optim.html#base-class)].
+- For a custom optimizer, create a new submodule called `GANDLF.optimizers.${awesome_optimizer}.py`.
+- For a third-party optimizer (i.e., where the code is available from an external source/repository):
+  - Add the relevant code under the `GANDLF.optimizers.thirdparty` submodule.
+  - Add a wrapper which takes GaNDLF's `parameter` dictionary as input and creates a `torch.optim.Optimizer` object as output.
+  - Add the wrapper to `GANDLF.optimizers.thirdparty.__init__.py` so that it can be called from `GANDLF.optimizers.__init__.py`.
+  - See `GANDLF.optimizers.thirdparty.adopt.py` as an example.
 - If a new dependency needs to be used, update GaNDLF's [`setup.py`](https://github.com/mlcommons/GaNDLF/blob/master/setup.py) with the new requirement.
 - Define a new submodule under `GANDLF.optimizers` as `GANDLF.optimizers.wrap_${package_name}.py`.
 - Ensure that the new algorithm is wrapped in a function which returns an object of the PyTorch optimizer type. Use any of the optimizers in `GANDLF.optimizers.wrap_torch.py` as an example.
 - Add the algorithm's identifier to `GANDLF.optimizers.__init__.global_optimizer_dict` with an appropriate key.
 - Call the new algorithm from the config using the `optimizer` key.
-- [Update the tests!](https://mlcommons.github.io/GaNDLF/extending/#update-tests)
+- [If appropriate, please update the tests!](https://mlcommons.github.io/GaNDLF/extending/#update-tests)
+- All wrappers should return an object of type `torch.optim.optimizer.Optimizer`.
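
To make the wrapper contract above concrete, here is a minimal sketch of such a wrapper. `AwesomeOptimizer`, its import path, and the `model_parameters`/`learning_rate` keys are illustrative assumptions, not part of this commit:

```python
# Hypothetical GANDLF/optimizers/thirdparty/awesome.py -- a sketch only.
from torch.optim.optimizer import Optimizer


def awesome_wrapper(parameters: dict) -> Optimizer:
    # Assumed third-party package; swap in the real optimizer class.
    from awesome_package import AwesomeOptimizer

    # The wrapper receives GaNDLF's full parameter dictionary and returns
    # a torch.optim.Optimizer instance, as the README above requires.
    return AwesomeOptimizer(
        parameters["model_parameters"],
        lr=parameters.get("learning_rate", 1e-3),
    )
```

The wrapper would then be re-exported from `GANDLF.optimizers.thirdparty.__init__.py` and registered in `global_optimizer_dict`, which is exactly what the next two diffs do for the new optimizers.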
15 changes: 9 additions & 6 deletions GANDLF/optimizers/__init__.py
@@ -15,7 +15,7 @@

 from .wrap_monai import novograd_wrapper
 
-from .ademamix import ademamix_wrapper
+from .thirdparty import ademamix_wrapper, lion_wrapper, adopt_wrapper
 
 global_optimizer_dict = {
     "sgd": sgd,
@@ -32,6 +32,8 @@
     "novograd": novograd_wrapper,
     "nadam": nadam,
     "ademamix": ademamix_wrapper,
+    "lion": lion_wrapper,
+    "adopt": adopt_wrapper,
 }


@@ -49,9 +51,10 @@ def get_optimizer(params):
     # Retrieve the optimizer type from the input parameters
     optimizer_type = params["optimizer"]["type"]
 
+    assert (
+        optimizer_type in global_optimizer_dict
+    ), f"Optimizer type {optimizer_type} not found"
+
     # Create the optimizer instance using the specified type and input parameters
-    if optimizer_type in global_optimizer_dict:
-        optimizer_function = global_optimizer_dict[optimizer_type]
-        return optimizer_function(params)
-    else:
-        raise ValueError("Optimizer type %s not found" % optimizer_type)
+    optimizer_function = global_optimizer_dict[optimizer_type]
+    return optimizer_function(params)
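
A quick usage sketch of the refactored lookup. Only `params["optimizer"]["type"]` appears in the diff; the other keys are assumed for illustration, mirroring the wrappers in `wrap_torch.py`:

```python
# Illustrative only: resolve an optimizer via get_optimizer.
import torch
from GANDLF.optimizers import get_optimizer

model = torch.nn.Linear(16, 4)
params = {
    "model_parameters": model.parameters(),  # assumed key
    "learning_rate": 1e-3,  # assumed key
    "optimizer": {"type": "adopt"},  # any key registered in global_optimizer_dict
}

optimizer = get_optimizer(params)
# An unregistered type now fails fast with an AssertionError
# instead of the previous ValueError.
```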
5 changes: 5 additions & 0 deletions GANDLF/optimizers/thirdparty/__init__.py
@@ -0,0 +1,5 @@
+from .ademamix import ademamix_wrapper
+
+from .lion import lion_wrapper
+
+from .adopt import adopt_wrapper
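
These re-exports are what enable the single top-level import shown earlier, `from .thirdparty import ademamix_wrapper, lion_wrapper, adopt_wrapper`, so `GANDLF.optimizers.__init__.py` never needs to know which file each third-party optimizer lives in.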
GANDLF/optimizers/ademamix.py → GANDLF/optimizers/thirdparty/ademamix.py: file renamed without changes.