Add optax.tree_utils.tree_batch_shape. #1161
Conversation
@rdyro How does this look? Happy to make any needed changes.
Hey, sorry for the delayed response. Adding batch is definitely useful in practice, but I'm concerned about API uniformity with the other tree_like methods:
Can we go with a workaround for the time being? (I realize it's not as concise 😞)

```python
a = {"a": jnp.ones(10), "c": jnp.zeros(4), "d": jnp.array(1.0)}
bs = (10,)
batch_tree = jax.tree.map(lambda x: jax.ShapeDtypeStruct(bs + x.shape, x.dtype), a)
optax.tree_utils.tree_random_like(random.key(0), batch_tree)
```
How about adding a helper function like the following?

```python
def tree_batch(a, shape, dtype=None, sharding=None):
  return jax.tree.map(lambda a: jax.ShapeDtypeStruct(
      shape + a.shape,
      dtype=dtype or a.dtype,
      sharding=sharding or a.sharding,
  ), a)
```

Then we can use:

```python
from jax import random
from optax import tree_utils as otu

otu.tree_random_like(random.key(0), otu.tree_batch(tree, shape))
```
That sounds like a great idea! Would you be willing to contribute this addition? Can you call it something a little more verbose than `tree_batch`?
Force-pushed from 8fa173b to 3fe2b27
@rdyro Done. I thought it might be more helpful for other potential uses to keep the actual values instead of returning `jax.ShapeDtypeStruct`s.
@rdyro Is the above OK?
This commit adds a `batch_shape` argument to `optax.tree_utils.tree_random_like`. This is useful when sampling multiple perturbations, such as for evolution strategies.