diff --git a/optax/_src/lookahead.py b/optax/_src/lookahead.py
index 98d6103b5..5ef299424 100644
--- a/optax/_src/lookahead.py
+++ b/optax/_src/lookahead.py
@@ -30,8 +30,8 @@ class LookaheadState(NamedTuple):
   Attributes:
     fast_state (:class:`optax.OptState`): Optimizer state of the fast optimizer.
-    steps_since_sync (``Union[jax.Array, int]``): Number of fast optimizer steps taken since slow and fast
-      parameters were synchronized.
+    steps_since_sync (``Union[jax.Array, int]``): Number of fast optimizer steps
+      taken since slow and fast parameters were synchronized.
   """
 
   fast_state: base.OptState
   steps_since_sync: Union[jax.Array, int]
diff --git a/optax/_src/wrappers.py b/optax/_src/wrappers.py
index b901a5228..28d0cd148 100644
--- a/optax/_src/wrappers.py
+++ b/optax/_src/wrappers.py
@@ -90,18 +90,17 @@ class ApplyIfFiniteState(NamedTuple):
   """State of the `GradientTransformation` returned by `apply_if_finite`.
 
   Attributes:
-    notfinite_count (``Union[jax.Array, int]``): Number of consecutive gradient updates containing an Inf or
-      a NaN. This number is reset to 0 whenever a gradient update without an Inf
-      or a NaN is done.
-    last_finite (``Union[jax.Array, int]``): Whether or not the last gradient update contained an Inf or a
-      NaN.
-    total_notfinite (``Union[jax.Array, int]``): Total number of gradient updates containing an Inf or
-      a NaN since this optimizer was initialised. This number is never reset.
-    inner_state (:class:`optax.OptState`): The state of the inner `GradientTransformation`.
+    notfinite_count (``Union[jax.Array, int]``): Number of consecutive gradient
+      updates containing an Inf or a NaN. This number is reset to 0 whenever a
+      gradient update without an Inf or a NaN is done.
+    last_finite (``Union[jax.Array, int]``): Whether or not the last gradient
+      update contained an Inf or a NaN.
+    total_notfinite (``Union[jax.Array, int]``): Total number of gradient
+      updates containing an Inf or a NaN since this optimizer was initialised.
+      This number is never reset.
+    inner_state (:class:`optax.OptState`): The state of the inner
+      `GradientTransformation`.
   """
-  # TODO(optax-dev): notfinite_count, last_finite and inner_state used to be
-  # annotated as `jnp.array` but that is not a valid annotation (it's a function
-  # and secretely resolved to `Any`. We should add back typing.
   notfinite_count: Union[jax.Array, int]
   last_finite: Union[jax.Array, int]
   total_notfinite: Union[jax.Array, int]
@@ -176,16 +175,17 @@ class MultiStepsState(NamedTuple):
   """State of the `GradientTransformation` returned by `MultiSteps`.
 
   Attributes:
-    mini_step (``Union[jax.Array, int]``): current mini-step counter. At an update, this either increases by
-      1 or is reset to 0.
-    gradient_step (``Union[jax.Array, int]``): gradient step counter. This only increases after enough
-      mini-steps have been accumulated.
-    inner_opt_state (:class:`optax.OptState`): the state of the wrapped optimiser.
+    mini_step (``Union[jax.Array, int]``): current mini-step counter. At an
+      update, this either increases by 1 or is reset to 0.
+    gradient_step (``Union[jax.Array, int]``): gradient step counter. This only
+      increases after enough mini-steps have been accumulated.
+    inner_opt_state (:class:`optax.OptState`): the state of the wrapped
+      optimiser.
     acc_grads (``jax.Array``): accumulated gradients over multiple mini-steps.
-    skip_state (``chex.ArrayTree``): an arbitrarily nested tree of arrays. This is only
-      relevant when passing a `should_skip_update_fn` to `MultiSteps`. This
-      structure will then contain values for debugging and or monitoring. The
-      actual structure will vary depending on the choice of
+    skip_state (``chex.ArrayTree``): an arbitrarily nested tree of arrays. This
+      is only relevant when passing a `should_skip_update_fn` to `MultiSteps`.
+      This structure will then contain values for debugging and or monitoring.
+      The actual structure will vary depending on the choice of
       `ShouldSkipUpdateFunction`.
   """
   mini_step: Union[jax.Array, int]
@@ -451,7 +451,8 @@ class MaskedState(NamedTuple):
   """Maintains inner transform state for masked transformations.
 
   Attributes:
-    inner_state (:class:`optax.OptState`): The state of the inner `GradientTransformation`.
+    inner_state (:class:`optax.OptState`): The state of the inner
+      `GradientTransformation`.
   """
 
   inner_state: base.OptState
@@ -570,7 +571,8 @@ class MaybeUpdateState(NamedTuple):
   """Maintains inner transform state and adds a step counter.
 
   Attributes:
-    inner_state (:class:`optax.OptState`): The state of the inner
+      `GradientTransformation`.
     step (``Union[jax.Array, int]``): The current step counter.
   """
   inner_state: base.OptState
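For reference, a minimal sketch (not part of the diff) of where the reformatted MultiStepsState and ApplyIfFiniteState attributes surface at runtime; the toy parameters and gradients below are hypothetical:

    import jax.numpy as jnp
    import optax

    params = {'w': jnp.zeros(3)}
    grads = {'w': jnp.ones(3)}

    # MultiSteps accumulates gradients for `every_k_schedule` mini-steps before
    # emitting a real update; its state is the MultiStepsState documented above.
    opt = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=2)
    state = opt.init(params)
    print(state.mini_step, state.gradient_step)  # 0 0

    updates, state = opt.update(grads, state, params)
    print(state.mini_step, state.gradient_step)  # 1 0 (gradients accumulated only)

    updates, state = opt.update(grads, state, params)
    print(state.mini_step, state.gradient_step)  # 0 1 (accumulated update emitted)

    # apply_if_finite counts non-finite updates in the ApplyIfFiniteState
    # documented above.
    safe_opt = optax.apply_if_finite(optax.sgd(1e-2), max_consecutive_errors=3)
    safe_state = safe_opt.init(params)
    bad_grads = {'w': jnp.array([jnp.inf, 0.0, 0.0])}
    _, safe_state = safe_opt.update(bad_grads, safe_state, params)
    print(safe_state.notfinite_count, safe_state.last_finite)  # 1 False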