From bcdb8e452a18c0c6038db20802cdb83e9046ca58 Mon Sep 17 00:00:00 2001 From: roebel Date: Mon, 2 Dec 2024 19:11:26 +0100 Subject: [PATCH] Random seed doc (#20575) * Fixed error in doc of random number generators concerning seed argument. * Update class documentation for SeedGenerator Clarify the facts that - a global SeedGenerator is used by all random number generating functions in keras, - a SeedGenerator is required for jit compilation with the JAX backend. * Minor reformulation. * Refined remark on JAX and tracing. * Fixed column length. * Fixed line length of documentation. * Reformatted with black. * Reformatted with black. * Still some lines too long? * Another long column that was intrduced by black. --- keras/src/random/random.py | 189 ++++++++++++++++++----------- keras/src/random/seed_generator.py | 26 ++-- 2 files changed, 135 insertions(+), 80 deletions(-) diff --git a/keras/src/random/random.py b/keras/src/random/random.py index 705411bcb72..33d5d593dbf 100644 --- a/keras/src/random/random.py +++ b/keras/src/random/random.py @@ -15,14 +15,19 @@ def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): supported. If not specified, `keras.config.floatx()` is used, which defaults to `float32` unless you configured it otherwise (via `keras.config.set_floatx(float_dtype)`). - seed: A Python integer or instance of - `keras.random.SeedGenerator`. - Used to make the behavior of the initializer - deterministic. Note that an initializer seeded with an integer - or None (unseeded) will produce the same random values - across multiple calls. To get different random values - across multiple calls, use as seed an instance - of `keras.random.SeedGenerator`. + seed: Optional Python integer or instance of + `keras.random.SeedGenerator`. + By default, the `seed` argument is `None`, and an internal global + `keras.random.SeedGenerator` is used. The `seed` argument can be + used to ensure deterministic (repeatable) random number generation. 
+ Note that passing an integer as the `seed` value will produce the + same random values for each call. To generate different random + values for repeated calls, an instance of + `keras.random.SeedGenerator` must be provided as the `seed` value. + Remark concerning the JAX backend: When tracing functions with the + JAX backend the global `keras.random.SeedGenerator` is not + supported. Therefore, during tracing the default value seed=None + will produce an error, and a `seed` argument must be provided. """ return backend.random.normal( shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed @@ -51,14 +56,19 @@ def categorical(logits, num_samples, dtype="int32", seed=None): row of the input. This will be the second dimension of the output tensor's shape. dtype: Optional dtype of the output tensor. - seed: A Python integer or instance of - `keras.random.SeedGenerator`. - Used to make the behavior of the initializer - deterministic. Note that an initializer seeded with an integer - or None (unseeded) will produce the same random values - across multiple calls. To get different random values - across multiple calls, use as seed an instance - of `keras.random.SeedGenerator`. + seed: Optional Python integer or instance of + `keras.random.SeedGenerator`. + By default, the `seed` argument is `None`, and an internal global + `keras.random.SeedGenerator` is used. The `seed` argument can be + used to ensure deterministic (repeatable) random number generation. + Note that passing an integer as the `seed` value will produce the + same random values for each call. To generate different random + values for repeated calls, an instance of + `keras.random.SeedGenerator` must be provided as the `seed` value. + Remark concerning the JAX backend: When tracing functions with the + JAX backend the global `keras.random.SeedGenerator` is not + supported. Therefore, during tracing the default value seed=None + will produce an error, and a `seed` argument must be provided. 
Returns: A 2-D tensor with (batch_size, num_samples). @@ -94,14 +104,19 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None): supported. If not specified, `keras.config.floatx()` is used, which defaults to `float32` unless you configured it otherwise (via `keras.config.set_floatx(float_dtype)`) - seed: A Python integer or instance of - `keras.random.SeedGenerator`. - Used to make the behavior of the initializer - deterministic. Note that an initializer seeded with an integer - or None (unseeded) will produce the same random values - across multiple calls. To get different random values - across multiple calls, use as seed an instance - of `keras.random.SeedGenerator`. + seed: Optional Python integer or instance of + `keras.random.SeedGenerator`. + By default, the `seed` argument is `None`, and an internal global + `keras.random.SeedGenerator` is used. The `seed` argument can be + used to ensure deterministic (repeatable) random number generation. + Note that passing an integer as the `seed` value will produce the + same random values for each call. To generate different random + values for repeated calls, an instance of + `keras.random.SeedGenerator` must be provided as the `seed` value. + Remark concerning the JAX backend: When tracing functions with the + JAX backend the global `keras.random.SeedGenerator` is not + supported. Therefore, during tracing the default value seed=None + will produce an error, and a `seed` argument must be provided. """ if dtype and not backend.is_float_dtype(dtype): raise ValueError( @@ -133,14 +148,19 @@ def randint(shape, minval, maxval, dtype="int32", seed=None): supported. If not specified, `keras.config.floatx()` is used, which defaults to `float32` unless you configured it otherwise (via `keras.config.set_floatx(float_dtype)`) - seed: A Python integer or instance of - `keras.random.SeedGenerator`. - Used to make the behavior of the initializer - deterministic. 
Note that an initializer seeded with an integer - or None (unseeded) will produce the same random values - across multiple calls. To get different random values - across multiple calls, use as seed an instance - of `keras.random.SeedGenerator`. + seed: Optional Python integer or instance of + `keras.random.SeedGenerator`. + By default, the `seed` argument is `None`, and an internal global + `keras.random.SeedGenerator` is used. The `seed` argument can be + used to ensure deterministic (repeatable) random number generation. + Note that passing an integer as the `seed` value will produce the + same random values for each call. To generate different random + values for repeated calls, an instance of + `keras.random.SeedGenerator` must be provided as the `seed` value. + Remark concerning the JAX backend: When tracing functions with the + JAX backend the global `keras.random.SeedGenerator` is not + supported. Therefore, during tracing the default value seed=None + will produce an error, and a `seed` argument must be provided. """ if dtype and not backend.is_int_dtype(dtype): raise ValueError( @@ -169,14 +189,19 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): supported. If not specified, `keras.config.floatx()` is used, which defaults to `float32` unless you configured it otherwise (via `keras.config.set_floatx(float_dtype)`) - seed: A Python integer or instance of - `keras.random.SeedGenerator`. - Used to make the behavior of the initializer - deterministic. Note that an initializer seeded with an integer - or None (unseeded) will produce the same random values - across multiple calls. To get different random values - across multiple calls, use as seed an instance - of `keras.random.SeedGenerator`. + seed: Optional Python integer or instance of + `keras.random.SeedGenerator`. + By default, the `seed` argument is `None`, and an internal global + `keras.random.SeedGenerator` is used. 
The `seed` argument can be + used to ensure deterministic (repeatable) random number generation. + Note that passing an integer as the `seed` value will produce the + same random values for each call. To generate different random + values for repeated calls, an instance of + `keras.random.SeedGenerator` must be provided as the `seed` value. + Remark concerning the JAX backend: When tracing functions with the + JAX backend the global `keras.random.SeedGenerator` is not + supported. Therefore, during tracing the default value seed=None + will produce an error, and a `seed` argument must be provided. """ return backend.random.truncated_normal( shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed @@ -198,14 +223,19 @@ def shuffle(x, axis=0, seed=None): x: The tensor to be shuffled. axis: An integer specifying the axis along which to shuffle. Defaults to `0`. - seed: A Python integer or instance of - `keras.random.SeedGenerator`. - Used to make the behavior of the initializer - deterministic. Note that an initializer seeded with an integer - or None (unseeded) will produce the same random values - across multiple calls. To get different random values - across multiple calls, use as seed an instance - of `keras.random.SeedGenerator`. + seed: Optional Python integer or instance of + `keras.random.SeedGenerator`. + By default, the `seed` argument is `None`, and an internal global + `keras.random.SeedGenerator` is used. The `seed` argument can be + used to ensure deterministic (repeatable) random number generation. + Note that passing an integer as the `seed` value will produce the + same random values for each call. To generate different random + values for repeated calls, an instance of + `keras.random.SeedGenerator` must be provided as the `seed` value. + Remark concerning the JAX backend: When tracing functions with the + JAX backend the global `keras.random.SeedGenerator` is not + supported. 
Therefore, during tracing the default value seed=None + will produce an error, and a `seed` argument must be provided. """ return backend.random.shuffle(x, axis=axis, seed=seed) @@ -221,14 +251,19 @@ def gamma(shape, alpha, dtype=None, seed=None): supported. If not specified, `keras.config.floatx()` is used, which defaults to `float32` unless you configured it otherwise (via `keras.config.set_floatx(float_dtype)`). - seed: A Python integer or instance of - `keras.random.SeedGenerator`. - Used to make the behavior of the initializer - deterministic. Note that an initializer seeded with an integer - or None (unseeded) will produce the same random values - across multiple calls. To get different random values - across multiple calls, use as seed an instance - of `keras.random.SeedGenerator`. + seed: Optional Python integer or instance of + `keras.random.SeedGenerator`. + By default, the `seed` argument is `None`, and an internal global + `keras.random.SeedGenerator` is used. The `seed` argument can be + used to ensure deterministic (repeatable) random number generation. + Note that passing an integer as the `seed` value will produce the + same random values for each call. To generate different random + values for repeated calls, an instance of + `keras.random.SeedGenerator` must be provided as the `seed` value. + Remark concerning the JAX backend: When tracing functions with the + JAX backend the global `keras.random.SeedGenerator` is not + supported. Therefore, during tracing the default value seed=None + will produce an error, and a `seed` argument must be provided. """ return backend.random.gamma(shape, alpha=alpha, dtype=dtype, seed=seed) @@ -251,14 +286,19 @@ def binomial(shape, counts, probabilities, dtype=None, seed=None): supported. If not specified, `keras.config.floatx()` is used, which defaults to `float32` unless you configured it otherwise (via `keras.config.set_floatx(float_dtype)`). - seed: A Python integer or instance of - `keras.random.SeedGenerator`. 
- Used to make the behavior of the initializer - deterministic. Note that an initializer seeded with an integer - or None (unseeded) will produce the same random values - across multiple calls. To get different random values - across multiple calls, use as seed an instance - of `keras.random.SeedGenerator`. + seed: Optional Python integer or instance of + `keras.random.SeedGenerator`. + By default, the `seed` argument is `None`, and an internal global + `keras.random.SeedGenerator` is used. The `seed` argument can be + used to ensure deterministic (repeatable) random number generation. + Note that passing an integer as the `seed` value will produce the + same random values for each call. To generate different random + values for repeated calls, an instance of + `keras.random.SeedGenerator` must be provided as the `seed` value. + Remark concerning the JAX backend: When tracing functions with the + JAX backend the global `keras.random.SeedGenerator` is not + supported. Therefore, during tracing the default value seed=None + will produce an error, and a `seed` argument must be provided. """ return backend.random.binomial( shape, @@ -286,14 +326,19 @@ def beta(shape, alpha, beta, dtype=None, seed=None): supported. If not specified, `keras.config.floatx()` is used, which defaults to `float32` unless you configured it otherwise (via `keras.config.set_floatx(float_dtype)`). - seed: A Python integer or instance of - `keras.random.SeedGenerator`. - Used to make the behavior of the initializer - deterministic. Note that an initializer seeded with an integer - or None (unseeded) will produce the same random values - across multiple calls. To get different random values - across multiple calls, use as seed an instance - of `keras.random.SeedGenerator`. + seed: Optional Python integer or instance of + `keras.random.SeedGenerator`. + By default, the `seed` argument is `None`, and an internal global + `keras.random.SeedGenerator` is used. 
The `seed` argument can be + used to ensure deterministic (repeatable) random number generation. + Note that passing an integer as the `seed` value will produce the + same random values for each call. To generate different random + values for repeated calls, an instance of + `keras.random.SeedGenerator` must be provided as the `seed` value. + Remark concerning the JAX backend: When tracing functions with the + JAX backend the global `keras.random.SeedGenerator` is not + supported. Therefore, during tracing the default value seed=None + will produce an error, and a `seed` argument must be provided. """ return backend.random.beta( shape=shape, alpha=alpha, beta=beta, dtype=dtype, seed=seed diff --git a/keras/src/random/seed_generator.py b/keras/src/random/seed_generator.py index 841f729f3b4..ad429e90bf3 100644 --- a/keras/src/random/seed_generator.py +++ b/keras/src/random/seed_generator.py @@ -11,14 +11,24 @@ @keras_export("keras.random.SeedGenerator") class SeedGenerator: - """Generates variable seeds upon each call to a RNG-using function. - - In Keras, all RNG-using methods (such as `keras.random.normal()`) - are stateless, meaning that if you pass an integer seed to them - (such as `seed=42`), they will return the same values at each call. - In order to get different values at each call, you must use a - `SeedGenerator` instead as the seed argument. The `SeedGenerator` - object is stateful. + """Generates variable seeds upon each call to a function generating + random numbers. + + In Keras, all random number generators (such as + `keras.random.normal()`) are stateless, meaning that if you pass an + integer seed to them (such as `seed=42`), they will return the same + values for repeated calls. To get different values for each + call, a `SeedGenerator` providing the state of the random generator + has to be used. + Note that all the random number generators have a default seed of None, + which implies that an internal global SeedGenerator is used. 
+ If you need to decouple the RNG from the global state you can provide + a local `SeedGenerator` with either a deterministic or random initial + state. + Remark concerning the JAX backend: Note that the use of a local + `SeedGenerator` as seed argument is required for JIT compilation of + RNG with the JAX backend, because the use of global state is not + supported. Example: