Potential error in documentation of keras.random.uniform. #20569

Open
roebel opened this issue Nov 30, 2024 · 3 comments · May be fixed by #20575
roebel commented Nov 30, 2024

In the documentation of keras.random.uniform I find:

seed: A Python integer or instance of
    `keras.random.SeedGenerator`.
    Used to make the behavior of the initializer
    deterministic. Note that an initializer seeded with an integer
    or None (unseeded) will produce the same random values
    across multiple calls. To get different random values
    across multiple calls, use as seed an instance
    of `keras.random.SeedGenerator`.

For me this implies that

keras.random.uniform(shape=(), seed=None) == keras.random.uniform(shape=(), seed=None)

This, however, is not the case.

See the colab here, which shows that calling keras.random.uniform(shape=(), seed=None) produces a sequence of different random numbers.

On the other hand, seed=1 really does create the same number for all calls in Keras (but not in TensorFlow), which is what I was expecting from the documentation.
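
For reference, a minimal sketch of what I observe (eager execution, assuming the Keras 3 API):

import keras

# seed=None: the documentation suggests identical values across calls,
# but each call returns a different value.
a = keras.random.uniform(shape=(), seed=None)
b = keras.random.uniform(shape=(), seed=None)
print(float(a) == float(b))  # False

# seed=1: every call returns the same value.
c = keras.random.uniform(shape=(), seed=1)
d = keras.random.uniform(shape=(), seed=1)
print(float(c) == float(d))  # True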

Am I misunderstanding this or is the documentation about seed=None wrong?

@fchollet (Contributor)

It's a bit subtle because the behavior isn't quite the same when executing eagerly and when compiling.

seed=None gets converted to a random int seed. When executing eagerly, you get a different random int seed each time, so you see a different output each time. But when you compile, the int seed generated at function tracing time gets baked into the compiled graph, and so your compiled function will output the same value each time it's executed.

So if you want to generate different values at each execution even when compiling, you really need seed=SeedGenerator(). The documentation is accurate when running compiled functions. It is not quite accurate for eager execution though. If you can propose some edits to improve it, please do.
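
For illustration, a sketch of that recommended pattern (my example, assuming the TensorFlow backend): because a SeedGenerator keeps its state in a backend variable, the state keeps advancing even inside a compiled function.

import tensorflow as tf
import keras

seed_gen = keras.random.SeedGenerator(seed=42)

@tf.function
def sample():
    # The SeedGenerator's state lives in a variable, so it advances
    # across executions even though the graph is traced only once.
    return keras.random.uniform(shape=(), seed=seed_gen)

print(sample().numpy(), sample().numpy())  # two different values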


roebel commented Dec 1, 2024

Ok, I can suggest a reformulation, but I'd first like to confirm that we agree on what is happening. Following your explanation, I looked this up in the code, and I found here the implementation of the uniform function for the TensorFlow backend:

def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
    dtype = dtype or floatx()
    seed = _cast_seed(draw_seed(seed))
...

with draw_seed (here) using a global seed generator when seed is None:

...
    elif seed is None:
        return global_seed_generator().next(ordered=False)
...

If I read the global_seed_generator implementation correctly, it works the same way for all backends, in compiled or eager mode, except for the JAX backend, where the use of a global SeedGenerator is not supported. So I would conclude that for seed=None a global seed generator is used, and random numbers will not be the same across calls.
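
In other words (an illustrative sketch of mine, not the actual implementation), passing seed=None behaves roughly as if all calls shared one module-level SeedGenerator:

import keras

# Hypothetical stand-in for Keras's internal global seed generator.
_global_generator = keras.random.SeedGenerator(seed=123)

# Each call draws fresh state from the shared generator, so consecutive
# calls (like consecutive calls with seed=None) produce different values.
u1 = keras.random.uniform(shape=(), seed=_global_generator)
u2 = keras.random.uniform(shape=(), seed=_global_generator)
# u1 != u2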

If you are fine with this conclusion, then I would suggest changing

      Note that an initializer seeded with an integer
      or None (unseeded) will produce the same random values
      across multiple calls. To get different random values
      across multiple calls, use as seed an instance
      of `keras.random.SeedGenerator`.

to

      Note that passing an integer as a seed value will always produce
      the same random values across multiple calls. To get different
      random values across multiple calls, you need to use an instance
      of `keras.random.SeedGenerator` as the seed value. If `None` is
      passed as the seed, a global (internal) `SeedGenerator` will be
      used. Note, however, that the use of the global
      `keras.random.SeedGenerator` is not supported when tracing
      functions for JIT compilation with the JAX backend.

I could create a pull request if this is fine with you. I would then also propose changing the documentation of the SeedGenerator class to say that when seed is set to None, the Python random generator is used to create a seed value, and that reproducible results can then be achieved using keras.utils.set_random_seed.
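
For example (a sketch of what I mean, assuming keras.utils.set_random_seed seeds Python's random module before the global generator is first created):

import keras

# Seed all global RNGs at program start; the global SeedGenerator then
# derives its seed deterministically, so seed=None draws become
# reproducible across runs of the script.
keras.utils.set_random_seed(1234)
x = keras.random.uniform(shape=(), seed=None)
print(float(x))  # same value on every run of this script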


fchollet commented Dec 1, 2024

You're right. The very first implementation of seed=None converted it to a random int, but at some point I changed it to use a global SeedGenerator (in order to avoid a common bug where people don't seed their RNG calls and yet expect different values at each call). However, with the JAX backend this will fail (with a very explicit and clear error message), since any element of state must be tracked locally.

Your description is accurate.

roebel linked a pull request (#20575) Dec 1, 2024 that will close this issue