Add optax.tree_utils.tree_batch_shape. #1161

Conversation

carlosgmartin (Contributor) commented on Dec 26, 2024

This commit adds a batch_shape argument to optax.tree_utils.tree_random_like.

This is useful when sampling multiple perturbations, such as for evolution strategies.
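
For context (not part of this PR), one way to obtain a batch of perturbations with the existing API is to vmap tree_random_like over split keys; the params tree and batch size below are made up for illustration:

import jax
import jax.numpy as jnp
from optax import tree_utils as otu

params = {"w": jnp.zeros((3, 2)), "b": jnp.zeros(3)}
keys = jax.random.split(jax.random.key(0), 8)  # one key per perturbation
# vmap adds a leading axis of length 8 to every leaf of the sampled tree.
perturbations = jax.vmap(lambda k: otu.tree_random_like(k, params))(keys)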

carlosgmartin (Contributor, Author) commented

@rdyro How does this look? Happy to make any needed changes.

rdyro (Collaborator) commented on Jan 27, 2025

Hey, sorry for the delayed response. Adding batching is definitely useful in practice, but I'm concerned about API uniformity with the other tree_*_like utilities:

tree_full_like
tree_ones_like
tree_random_like
tree_split_key_like
tree_zeros_like

Can we go with a workaround for the time being? (I realize it's not as concise 😞)

a = {"a": jnp.ones(10), "c": jnp.zeros(4), "d": jnp.array(1.0)}
bs = (10,)
batch_tree = jax.tree.map(lambda x: jax.ShapeDtypeStruct(bs + x.shape, x.dtype), a)
optax.tree_utils.tree_random_like(random.key(0), batch_tree)
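
With bs = (10,), every leaf of the sampled tree gains a leading axis of length 10: a["a"] yields shape (10, 10), a["c"] yields (10, 4), and the scalar a["d"] yields (10,).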

carlosgmartin (Contributor, Author) commented on Jan 27, 2025

How about adding a helper function like the following?

import jax

def tree_batch(a, shape, dtype=None, sharding=None):
  # Prepend `shape` to each leaf's shape, optionally overriding dtype/sharding.
  return jax.tree.map(lambda x: jax.ShapeDtypeStruct(
      shape + x.shape,
      dtype=dtype or x.dtype,
      sharding=sharding or x.sharding,
  ), a)

Then we can use it like this:

from jax import random
from optax import tree_utils as otu
otu.tree_random_like(random.key(0), otu.tree_batch(tree, shape))

rdyro (Collaborator) commented on Jan 27, 2025

That sounds like a great idea! Would you be willing to contribute this addition?

Could you call it something a little more verbose than tree_batch, maybe tree_batch_shape? I'm open to better ideas.

carlosgmartin force-pushed the tree_random_like_batch_shape branch from 8fa173b to 3fe2b27 on Jan 28, 2025
carlosgmartin changed the title from "Add batch_shape argument to optax.tree_utils.tree_random_like." to "Add optax.tree_utils.tree_batch_shape." on Jan 28, 2025
carlosgmartin (Contributor, Author) commented

@rdyro Done. I thought it might be more helpful for other potential uses to keep the actual values instead of returning ShapeDtypeStructs, so I've switched to using broadcast_to instead. (This uses zero strides, so it doesn't need additional memory.)
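
For readers following along, the broadcast_to-based helper is roughly the following; this is a sketch of the idea under the names used in this thread, not the exact merged code:

import jax
import jax.numpy as jnp
from optax import tree_utils as otu

def tree_batch_shape(tree, batch_shape):
  # Broadcast each leaf to batch_shape + leaf.shape; per the note above,
  # broadcasting uses zero strides, so the batch axis costs no extra memory.
  return jax.tree.map(
      lambda x: jnp.broadcast_to(x, tuple(batch_shape) + jnp.shape(x)), tree)

params = {"w": jnp.zeros((3, 2)), "b": jnp.zeros(3)}
# Sampling from the broadcast tree gives every leaf a leading axis of length 8.
batched = otu.tree_random_like(jax.random.key(0), tree_batch_shape(params, (8,)))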

carlosgmartin (Contributor, Author) commented

@rdyro Is the above ok?

rdyro (Collaborator) commented on Feb 2, 2025

> @rdyro Is the above ok?

Using broadcast_to makes this function more general, nice!

copybara-service bot merged commit b51d9a8 into google-deepmind:main on Feb 3, 2025
12 checks passed
carlosgmartin deleted the tree_random_like_batch_shape branch on Feb 3, 2025