Skip to content

Commit

Permalink
Random seed doc (#20575)
Browse files Browse the repository at this point in the history
* Fixed error in doc of random number generators concerning seed argument.

* Update class documentation for SeedGenerator

Clarify the facts that 

- a global SeedGenerator is used by all random number generating functions in keras,
- a SeedGenerator is required for jit compilation with the JAX backend.

* Minor reformulation.

* Refined remark on JAX and tracing.

* Fixed column length.

* Fixed line length of documentation.

* Reformatted with black.

* Reformatted with black.

* Still some lines too long?

* Another long column that was intrduced by black.
  • Loading branch information
roebel authored Dec 2, 2024
1 parent b02085b commit bcdb8e4
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 80 deletions.
189 changes: 117 additions & 72 deletions keras/src/random/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,19 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
supported. If not specified, `keras.config.floatx()` is used,
which defaults to `float32` unless you configured it otherwise (via
`keras.config.set_floatx(float_dtype)`).
seed: A Python integer or instance of
`keras.random.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras.random.SeedGenerator`.
seed: Optional Python integer or instance of
`keras.random.SeedGenerator`.
By default, the `seed` argument is `None`, and an internal global
`keras.random.SeedGenerator` is used. The `seed` argument can be
used to ensure deterministic (repeatable) random number generation.
Note that passing an integer as the `seed` value will produce the
same random values for each call. To generate different random
values for repeated calls, an instance of
`keras.random.SeedGenerator` must be provided as the `seed` value.
Remark concerning the JAX backend: When tracing functions with the
JAX backend the global `keras.random.SeedGenerator` is not
supported. Therefore, during tracing the default value seed=None
will produce an error, and a `seed` argument must be provided.
"""
return backend.random.normal(
shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
Expand Down Expand Up @@ -51,14 +56,19 @@ def categorical(logits, num_samples, dtype="int32", seed=None):
row of the input. This will be the second dimension of the output
tensor's shape.
dtype: Optional dtype of the output tensor.
seed: A Python integer or instance of
`keras.random.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras.random.SeedGenerator`.
seed: Optional Python integer or instance of
`keras.random.SeedGenerator`.
By default, the `seed` argument is `None`, and an internal global
`keras.random.SeedGenerator` is used. The `seed` argument can be
used to ensure deterministic (repeatable) random number generation.
Note that passing an integer as the `seed` value will produce the
same random values for each call. To generate different random
values for repeated calls, an instance of
`keras.random.SeedGenerator` must be provided as the `seed` value.
Remark concerning the JAX backend: When tracing functions with the
JAX backend the global `keras.random.SeedGenerator` is not
supported. Therefore, during tracing the default value seed=None
will produce an error, and a `seed` argument must be provided.
Returns:
A 2-D tensor with (batch_size, num_samples).
Expand Down Expand Up @@ -94,14 +104,19 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
supported. If not specified, `keras.config.floatx()` is used,
which defaults to `float32` unless you configured it otherwise (via
`keras.config.set_floatx(float_dtype)`)
seed: A Python integer or instance of
`keras.random.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras.random.SeedGenerator`.
seed: Optional Python integer or instance of
`keras.random.SeedGenerator`.
By default, the `seed` argument is `None`, and an internal global
`keras.random.SeedGenerator` is used. The `seed` argument can be
used to ensure deterministic (repeatable) random number generation.
Note that passing an integer as the `seed` value will produce the
same random values for each call. To generate different random
values for repeated calls, an instance of
`keras.random.SeedGenerator` must be provided as the `seed` value.
Remark concerning the JAX backend: When tracing functions with the
JAX backend the global `keras.random.SeedGenerator` is not
supported. Therefore, during tracing the default value seed=None
will produce an error, and a `seed` argument must be provided.
"""
if dtype and not backend.is_float_dtype(dtype):
raise ValueError(
Expand Down Expand Up @@ -133,14 +148,19 @@ def randint(shape, minval, maxval, dtype="int32", seed=None):
supported. If not specified, `keras.config.floatx()` is used,
which defaults to `float32` unless you configured it otherwise (via
`keras.config.set_floatx(float_dtype)`)
seed: A Python integer or instance of
`keras.random.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras.random.SeedGenerator`.
seed: Optional Python integer or instance of
`keras.random.SeedGenerator`.
By default, the `seed` argument is `None`, and an internal global
`keras.random.SeedGenerator` is used. The `seed` argument can be
used to ensure deterministic (repeatable) random number generation.
Note that passing an integer as the `seed` value will produce the
same random values for each call. To generate different random
values for repeated calls, an instance of
`keras.random.SeedGenerator` must be provided as the `seed` value.
Remark concerning the JAX backend: When tracing functions with the
JAX backend the global `keras.random.SeedGenerator` is not
supported. Therefore, during tracing the default value seed=None
will produce an error, and a `seed` argument must be provided.
"""
if dtype and not backend.is_int_dtype(dtype):
raise ValueError(
Expand Down Expand Up @@ -169,14 +189,19 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
supported. If not specified, `keras.config.floatx()` is used,
which defaults to `float32` unless you configured it otherwise (via
`keras.config.set_floatx(float_dtype)`)
seed: A Python integer or instance of
`keras.random.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras.random.SeedGenerator`.
seed: Optional Python integer or instance of
`keras.random.SeedGenerator`.
By default, the `seed` argument is `None`, and an internal global
`keras.random.SeedGenerator` is used. The `seed` argument can be
used to ensure deterministic (repeatable) random number generation.
Note that passing an integer as the `seed` value will produce the
same random values for each call. To generate different random
values for repeated calls, an instance of
`keras.random.SeedGenerator` must be provided as the `seed` value.
Remark concerning the JAX backend: When tracing functions with the
JAX backend the global `keras.random.SeedGenerator` is not
supported. Therefore, during tracing the default value seed=None
will produce an error, and a `seed` argument must be provided.
"""
return backend.random.truncated_normal(
shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed
Expand All @@ -198,14 +223,19 @@ def shuffle(x, axis=0, seed=None):
x: The tensor to be shuffled.
axis: An integer specifying the axis along which to shuffle. Defaults to
`0`.
seed: A Python integer or instance of
`keras.random.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras.random.SeedGenerator`.
seed: Optional Python integer or instance of
`keras.random.SeedGenerator`.
By default, the `seed` argument is `None`, and an internal global
`keras.random.SeedGenerator` is used. The `seed` argument can be
used to ensure deterministic (repeatable) random number generation.
Note that passing an integer as the `seed` value will produce the
same random values for each call. To generate different random
values for repeated calls, an instance of
`keras.random.SeedGenerator` must be provided as the `seed` value.
Remark concerning the JAX backend: When tracing functions with the
JAX backend the global `keras.random.SeedGenerator` is not
supported. Therefore, during tracing the default value seed=None
will produce an error, and a `seed` argument must be provided.
"""
return backend.random.shuffle(x, axis=axis, seed=seed)

Expand All @@ -221,14 +251,19 @@ def gamma(shape, alpha, dtype=None, seed=None):
supported. If not specified, `keras.config.floatx()` is used,
which defaults to `float32` unless you configured it otherwise (via
`keras.config.set_floatx(float_dtype)`).
seed: A Python integer or instance of
`keras.random.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras.random.SeedGenerator`.
seed: Optional Python integer or instance of
`keras.random.SeedGenerator`.
By default, the `seed` argument is `None`, and an internal global
`keras.random.SeedGenerator` is used. The `seed` argument can be
used to ensure deterministic (repeatable) random number generation.
Note that passing an integer as the `seed` value will produce the
same random values for each call. To generate different random
values for repeated calls, an instance of
`keras.random.SeedGenerator` must be provided as the `seed` value.
Remark concerning the JAX backend: When tracing functions with the
JAX backend the global `keras.random.SeedGenerator` is not
supported. Therefore, during tracing the default value seed=None
will produce an error, and a `seed` argument must be provided.
"""
return backend.random.gamma(shape, alpha=alpha, dtype=dtype, seed=seed)

Expand All @@ -251,14 +286,19 @@ def binomial(shape, counts, probabilities, dtype=None, seed=None):
supported. If not specified, `keras.config.floatx()` is used,
which defaults to `float32` unless you configured it otherwise (via
`keras.config.set_floatx(float_dtype)`).
seed: A Python integer or instance of
`keras.random.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras.random.SeedGenerator`.
seed: Optional Python integer or instance of
`keras.random.SeedGenerator`.
By default, the `seed` argument is `None`, and an internal global
`keras.random.SeedGenerator` is used. The `seed` argument can be
used to ensure deterministic (repeatable) random number generation.
Note that passing an integer as the `seed` value will produce the
same random values for each call. To generate different random
values for repeated calls, an instance of
`keras.random.SeedGenerator` must be provided as the `seed` value.
Remark concerning the JAX backend: When tracing functions with the
JAX backend the global `keras.random.SeedGenerator` is not
supported. Therefore, during tracing the default value seed=None
will produce an error, and a `seed` argument must be provided.
"""
return backend.random.binomial(
shape,
Expand Down Expand Up @@ -286,14 +326,19 @@ def beta(shape, alpha, beta, dtype=None, seed=None):
supported. If not specified, `keras.config.floatx()` is used,
which defaults to `float32` unless you configured it otherwise (via
`keras.config.set_floatx(float_dtype)`).
seed: A Python integer or instance of
`keras.random.SeedGenerator`.
Used to make the behavior of the initializer
deterministic. Note that an initializer seeded with an integer
or None (unseeded) will produce the same random values
across multiple calls. To get different random values
across multiple calls, use as seed an instance
of `keras.random.SeedGenerator`.
seed: Optional Python integer or instance of
`keras.random.SeedGenerator`.
By default, the `seed` argument is `None`, and an internal global
`keras.random.SeedGenerator` is used. The `seed` argument can be
used to ensure deterministic (repeatable) random number generation.
Note that passing an integer as the `seed` value will produce the
same random values for each call. To generate different random
values for repeated calls, an instance of
`keras.random.SeedGenerator` must be provided as the `seed` value.
Remark concerning the JAX backend: When tracing functions with the
JAX backend the global `keras.random.SeedGenerator` is not
supported. Therefore, during tracing the default value seed=None
will produce an error, and a `seed` argument must be provided.
"""
return backend.random.beta(
shape=shape, alpha=alpha, beta=beta, dtype=dtype, seed=seed
Expand Down
26 changes: 18 additions & 8 deletions keras/src/random/seed_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,24 @@

@keras_export("keras.random.SeedGenerator")
class SeedGenerator:
"""Generates variable seeds upon each call to a RNG-using function.
In Keras, all RNG-using methods (such as `keras.random.normal()`)
are stateless, meaning that if you pass an integer seed to them
(such as `seed=42`), they will return the same values at each call.
In order to get different values at each call, you must use a
`SeedGenerator` instead as the seed argument. The `SeedGenerator`
object is stateful.
"""Generates variable seeds upon each call to a function generating
random numbers.
In Keras, all random number generators (such as
`keras.random.normal()`) are stateless, meaning that if you pass an
integer seed to them (such as `seed=42`), they will return the same
values for repeated calls. To get different values for each
call, a `SeedGenerator` providing the state of the random generator
has to be used.
Note that all the random number generators have a default seed of None,
which implies that an internal global SeedGenerator is used.
If you need to decouple the RNG from the global state you can provide
a local `StateGenerator` with either a deterministic or random initial
state.
Remark concerning the JAX backen: Note that the use of a local
`StateGenerator` as seed argument is required for JIT compilation of
RNG with the JAX backend, because the use of global state is not
supported.
Example:
Expand Down

0 comments on commit bcdb8e4

Please sign in to comment.