Skip to content

Commit

Permalink
Merge pull request #4370 from 8bitmp3:update-nnx-fori_loop-docstring
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 697012688
  • Loading branch information
Flax Authors committed Nov 15, 2024
2 parents 9147a7c + 84fa22e commit 91bf758
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions flax/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1493,7 +1493,7 @@ def fori_loop(lower: int, upper: int,
init_val: T,
*,
unroll: int | bool | None = None) -> T:
"""NNX transform of `jax.lax.fori_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html>`_.
"""A Flax NNX transformation of `jax.lax.fori_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html>`_.
Caution: for the NNX internal reference tracing mechanism to work, you cannot
change the variable reference structure of `init_val` inside `body_fun`.
Expand All @@ -1515,21 +1515,21 @@ def fori_loop(lower: int, upper: int,
Args:
lower: an integer representing the loop index lower bound (inclusive)
upper: an integer representing the loop index upper bound (exclusive)
body_fun: a function that takes an input of type `T` and outputs an `T`.
Note that both data and modules of `T` must have the same reference
lower: An integer representing the loop index lower bound (inclusive).
upper: An integer representing the loop index upper bound (exclusive).
body_fun: a function that takes an input of type ``T`` and outputs an ``T``.
Note that both data and modules of ``T`` must have the same reference
structure between inputs and outputs.
init_val: the initial input for body_fun. Must be of type `T`.
init_val: the initial input for body_fun. Must be of type ``T``.
unroll: An optional integer or boolean that determines how much to unroll
the loop. If an integer is provided, it determines how many unrolled
loop iterations to run within a single rolled iteration of the loop. If a
boolean is provided, it will determine if the loop is competely unrolled
(i.e. `unroll=True`) or left completely unrolled (i.e. `unroll=False`).
(i.e. ``unroll=True``) or left completely unrolled (i.e. ``unroll=False``).
This argument is only applicable if the loop bounds are statically known.
Returns:
Loop value from the final iteration, of type ``T``.
A loop value from the final iteration, of type ``T``.
"""

Expand Down

0 comments on commit 91bf758

Please sign in to comment.