From 8a478c2c583c2e08737688ca1a15fcbeb4ad8f7e Mon Sep 17 00:00:00 2001 From: Amos You Date: Thu, 15 Feb 2024 01:32:11 -0800 Subject: [PATCH] additional linting --- optax/_src/wrappers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/optax/_src/wrappers.py b/optax/_src/wrappers.py index 98ad33eb8..28d0cd148 100644 --- a/optax/_src/wrappers.py +++ b/optax/_src/wrappers.py @@ -571,7 +571,8 @@ class MaybeUpdateState(NamedTuple): """Maintains inner transform state and adds a step counter. Attributes: - inner_state (:class:`optax.OptState`): The state of the inner `GradientTransformation`. + inner_state (:class:`optax.OptState`): The state of the inner + `GradientTransformation`. step (``Union[jax.Array, int]``): The current step counter. """ inner_state: base.OptState