Example using vmap/pmap from Jax? #18570

I've written a custom Jax training loop, but unfortunately my script doesn't seem to be using very much GPU memory. I've tried increasing the `batch_size`, but that doesn't seem to make much difference. Hence I thought I'd try to increase throughput with something like vmap, or run some models in parallel (where applicable) with something like pmap (or even something like `flax.linen.vmap()`). Are there any examples of how to do this? I've come across this guide, but it only appears to be for multiple devices.

I might be misunderstanding my problem, but many thanks for any help! :)

Comments
Can you provide a minimal example?
I'm trying to reproduce SAC with Keras Core + Jax. I'm following another example and am struggling to re-create the following.
You should be able to use
Hi @AakashKumarNain, many thanks for your help, I really appreciate it. Regarding a code example: to learn about Keras 3 + Jax, I thought I'd have a go at recreating this Flax SAC example (with these results, as per vwxyzjn/cleanrl#300); here is my attempt. The code is still a work in progress, but I've struggled to recreate this [I'm also struggling to recreate this network; my attempt is here].
I've tried to create a minimal example of the thing I'm trying to achieve. Essentially, I'm not sure how to stack and map over Keras 3 variables/models:

```python
import os

os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing Keras

import jax
import jax.numpy as jnp
import keras_core as keras

x = keras.ops.ones((2, 1))

inputs = keras.Input(shape=(1,))
outputs = keras.layers.Dense(1)(inputs)
model1 = keras.Model(inputs, outputs)
model2 = keras.Model(inputs, outputs)

# HOW TO STACK THE MODELS TO MAP OVER THEM?
trainables = [model1.trainable_variables, model2.trainable_variables]
non_trainables = [model1.non_trainable_variables, model2.non_trainable_variables]

# THESE FAIL (arbitrarily using `model1.stateless_call`)
jax.vmap(model1.stateless_call, in_axes=(0, 0, None))(trainables, non_trainables, x)
jax.tree_util.tree_map(model1.stateless_call, trainables, non_trainables, jnp.stack([x, x]))
```
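For what it's worth, both calls likely fail because the outer Python list is pytree *structure*, not a mappable axis: `jax.vmap` maps the leading axis of each leaf array, so `stateless_call` ends up receiving a list of two variable lists rather than arrays shaped like a single model's variables. A quick, hypothetical way to see the mismatch, reusing the names from the snippet above:

```python
# The nested list has a different pytree structure (and twice the leaves)
# compared to a single model's variable list, which `stateless_call` rejects.
print(jax.tree_util.tree_structure(model1.trainable_variables))
print(jax.tree_util.tree_structure(trainables))
```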
Thanks for the code examples. I wasn't able to dig into the first code. For the last code, here is how I ran the stateless call via `vmap`:

```python
import numpy as np
import jax
import jax.numpy as jnp
import keras_core as keras
from keras_core import layers, Model


def build_model():
    inputs = layers.Input(shape=(1,), batch_size=2)
    x = layers.Dense(64, activation="relu")(inputs)
    x = layers.Dense(1)(x)
    model = Model(inputs, x, name="useless_model")
    return model


keras.backend.clear_session()
model = build_model()
model.summary()

trainable_vars = model.trainable_variables
non_trainable_vars = model.non_trainable_variables
inputs = jnp.array(np.random.rand(2, 1).astype(np.float32))

# Broadcast the variables to every mapped call; map only over the inputs.
vmapped = jax.vmap(model.stateless_call, in_axes=(None, None, 0), out_axes=(0, None))
out = vmapped(trainable_vars, non_trainable_vars, inputs)

# Output:
# (Array([[[-0.00094579]],
#
#         [[-0.00088299]]], dtype=float32),
#  [])
```

PS: IMO you are approaching the problem in a wrong fashion though. If I were in your place, I would rather
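As an aside on the design choice in that snippet: `in_axes=(None, None, 0)` broadcasts the variable lists to every mapped call and maps only over the leading (batch) axis of `inputs`, while `out_axes=(0, None)` stacks the model outputs and passes the unmapped non-trainable variables straight through. Since the original question also asked about `pmap`, the device-parallel analogue would presumably look like the sketch below, under the assumption that the leading axis of `inputs` matches the local device count; this is untested here:

```python
# Hypothetical pmap analogue of the vmap call above: each device receives one
# slice of `inputs`, while the variables are broadcast to all devices.
# Requires inputs.shape[0] == jax.local_device_count().
pmapped = jax.pmap(model.stateless_call, in_axes=(None, None, 0), out_axes=(0, None))
out = pmapped(trainable_vars, non_trainable_vars, inputs)
```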
Thanks very much for the example and your help @AakashKumarNain. What I'm trying to achieve is running 2 different sets of model weights through the same model structure in parallel, e.g. something like this (but it doesn't work):

```python
trainables = [model1.trainable_variables, model2.trainable_variables]
non_trainables = [model1.non_trainable_variables, model2.non_trainable_variables]

jax.vmap(model.stateless_call, in_axes=(0, 0, None))(trainables, non_trainables, x)
# This produces the error:
# ValueError: Argument `non_trainable_variables` must be a list of tensors
# corresponding 1:1 to Functional().non_trainable_variables. Received list
# with length 2, but expected 0 variables.
```

I think I'm creating the Jax container/pytree incorrectly; I'm not sure if I need to stack the variables in some other way?
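One pattern that seems to address exactly this is to stack the two weight sets leaf-by-leaf, so each leaf gains a leading axis of size 2 and the pytree structure matches a single model's variable list, then `vmap` over that leading axis. A sketch, under the assumptions that the two models are built independently (so their weights are distinct) and that `.value` exposes a Keras variable's underlying jnp array on the JAX backend:

```python
import os

os.environ["KERAS_BACKEND"] = "jax"

import jax
import jax.numpy as jnp
import keras_core as keras


def make_model():
    # Build a fresh graph each call so the two models have independent weights.
    inputs = keras.Input(shape=(1,))
    outputs = keras.layers.Dense(1)(inputs)
    return keras.Model(inputs, outputs)


model1, model2 = make_model(), make_model()
x = keras.ops.ones((2, 1))

# Stack the weight sets leaf-by-leaf: each leaf gains a leading axis of size 2,
# so the result has the same pytree structure as one model's variable list.
trainables = jax.tree_util.tree_map(
    lambda a, b: jnp.stack([a, b]),
    [v.value for v in model1.trainable_variables],
    [v.value for v in model2.trainable_variables],
)
non_trainables = jax.tree_util.tree_map(
    lambda a, b: jnp.stack([a, b]),
    [v.value for v in model1.non_trainable_variables],
    [v.value for v in model2.non_trainable_variables],
)

# Map over the stacked weights (axis 0) and broadcast the same input to both.
outputs, _ = jax.vmap(model1.stateless_call, in_axes=(0, 0, None))(
    trainables, non_trainables, x
)
print(outputs.shape)  # (2, 2, 1): 2 weight sets, batch of 2, 1 output feature
```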
Thanks for everyone's help. I think I'll close this and try to create more specific/cleaner examples as/if needed. Thanks again!