
Example using vmap/pmap from Jax? #18570

Closed · asmith26 opened this issue Oct 7, 2023 · 8 comments
asmith26 (Contributor) commented Oct 7, 2023

I've written a custom Jax training loop, but unfortunately my script doesn't seem to be using very much GPU memory. I've tried increasing the batch_size, but that doesn't seem to make much difference. Hence I thought I'd try to increase the throughput with something like vmap, or run some models in parallel (where applicable) with something like pmap (or even something like flax.linen.vmap()).

Are there any examples of how to do this? I've come across this guide, but it only appears to cover multiple devices.

I might be misunderstanding my problem, but many thanks for any help! :)
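(For context, a toy sketch of the jax.vmap pattern in question, independent of the actual training script:)

import jax
import jax.numpy as jnp

# jax.vmap turns a per-example function into one that maps over a
# leading batch axis, so the whole batch runs in a single traced call.
def squared_norm(v):
    return jnp.sum(v ** 2)

batch = jnp.ones((8, 3))               # 8 examples of dimension 3
norms = jax.vmap(squared_norm)(batch)  # shape (8,)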

asmith26 changed the title from "Example using vmap from Jax?" to "Example using vmap/pmap from Jax?" on Oct 7, 2023
AakashKumarNain (Contributor) commented

Can you provide a minimal example?

asmith26 (Contributor, Author) commented Oct 7, 2023

I'm trying to reproduce SAC with Keras Core Jax, following another example, and I'm struggling to re-create the following vmap: https://github.com/vwxyzjn/cleanrl/blob/bbec22d1b60f54903ffaf4993df443f9001d0951/cleanrl/sac_continuous_action_jax.py#L128-L139

AakashKumarNain (Contributor) commented

You should be able to use vmap and pmap flawlessly in Keras 3. If you can share the code you have written, I can provide some suggestions on how to do it.

asmith26 (Contributor, Author) commented Oct 8, 2023

Hi @AakashKumarNain, many thanks for your help, I really appreciate it.

Regarding a code example: to learn about Keras 3 + Jax, I thought I'd have a go at recreating this Flax SAC example (with these results, as per vwxyzjn/cleanrl#300); here is my attempt.

The code is still a work in progress, but I've struggled to recreate this vmap (instead I'm just running the 2 QNetworks sequentially), so if you are able to help with this, it would be much appreciated.

[I'm also struggling to recreate this network; my attempt is here.]

asmith26 (Contributor, Author) commented Oct 8, 2023

I've tried to create a minimal example of what I'm trying to achieve. Essentially, I'm not sure how to stack and map over Keras 3 variables/models:

import os

os.environ["KERAS_BACKEND"] = "jax"
import keras_core as keras
import jax
import jax.numpy as jnp

x = keras.ops.ones((2, 1))
inputs = keras.Input(shape=(1,))
outputs = keras.layers.Dense(1)(inputs)
# Note: both Models are built from the same `outputs` tensor, so they
# share the same Dense layer (and hence the same weights).
model1 = keras.Model(inputs, outputs)
model2 = keras.Model(inputs, outputs)

# HOW CAN THE MODELS' VARIABLES BE STACKED SO THEY CAN BE MAPPED OVER?
trainables = [model1.trainable_variables, model2.trainable_variables]
non_trainables = [model1.non_trainable_variables, model2.non_trainable_variables]

# THESE FAIL (arbitrarily using `model1.stateless_call`)
jax.vmap(model1.stateless_call, in_axes=(0, 0, None))(trainables, non_trainables, x)
jax.tree_util.tree_map(model1.stateless_call, trainables, non_trainables, jnp.stack([x, x]))

AakashKumarNain (Contributor) commented

Thanks for the code examples. I wasn't able to dig into the first piece of code, but for the last one, here is how I ran the stateless call via vmap.

import os

os.environ["KERAS_BACKEND"] = "jax"
import numpy as np
import jax
import jax.numpy as jnp
import keras_core as keras
from keras_core import layers
from keras_core.models import Model

def build_model():
    inputs = layers.Input(shape=(1,), batch_size=2)
    x = layers.Dense(64, activation="relu")(inputs)
    x = layers.Dense(1)(x)
    model = Model(inputs, x, name="useless_model")
    return model

keras.backend.clear_session()
model = build_model()
model.summary()

trainable_vars = model.trainable_variables
non_trainable_vars = model.non_trainable_variables
inputs = jnp.array(np.random.rand(2, 1).astype(np.float32))

# Weights are shared (in_axes=None); only the inputs are mapped over axis 0.
vmapped = jax.vmap(model.stateless_call, in_axes=(None, None, 0), out_axes=(0, None))
out = vmapped(trainable_vars, non_trainable_vars, inputs)

## output:
## (Array([[[-0.00094579]],
## 
##        [[-0.00088299]]], dtype=float32),
## [])

PS: IMO you are approaching the problem the wrong way, though. If I were in your place, I would vmap the function that is computationally expensive and keep everything else as it is.
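(A minimal sketch of that suggestion: vmap a hypothetical per-example loss over the batch axis while sharing the parameters, rather than mapping over whole models.)

import jax
import jax.numpy as jnp

# Hypothetical per-example loss; the parameters are shared (in_axes=None)
# while the examples are mapped over axis 0.
def per_example_loss(w, b, x, y):
    pred = jnp.dot(x, w) + b
    return (pred - y) ** 2

batched_loss = jax.vmap(per_example_loss, in_axes=(None, None, 0, 0))

w, b = jnp.ones((3,)), jnp.zeros(())
xs, ys = jnp.ones((8, 3)), jnp.zeros((8,))
losses = batched_loss(w, b, xs, ys)  # shape (8,)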

asmith26 (Contributor, Author) commented Oct 9, 2023

Thanks very much for the example and your help @AakashKumarNain. What I'm trying to achieve is running 2 different sets of model weights through the same model structure in parallel, e.g. something like the following (which doesn't work):

trainables = [model1.trainable_variables, model2.trainable_variables]
non_trainables = [model1.non_trainable_variables, model2.non_trainable_variables]
jax.vmap(model1.stateless_call, in_axes=(0, 0, None))(trainables, non_trainables, x)

# this produces the error:
# ValueError: Argument `non_trainable_variables` must be a list of tensors corresponding 1:1 to Functional().non_trainable_variables. Received list with length 2, but expected 0 variables.

I think I'm creating the Jax container/pytree incorrectly; I'm not sure if I need to stack them in some other way?
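(A minimal sketch of one way this stacking can be done in JAX, reusing model1, model2, and x from the minimal example above and assuming both models share an identical architecture. The leaf-wise jnp.asarray conversion of the Keras variables is an assumption, not a documented Keras API:)

import jax
import jax.numpy as jnp

# Stack the two models' variable lists leaf-wise, so each leaf gains a
# leading axis of size 2 that vmap can map over.
stack = lambda a, b: jnp.stack([jnp.asarray(a), jnp.asarray(b)])
stacked_trainables = jax.tree_util.tree_map(
    stack, model1.trainable_variables, model2.trainable_variables
)
stacked_non_trainables = jax.tree_util.tree_map(
    stack, model1.non_trainable_variables, model2.non_trainable_variables
)

# Map over the stacked weight axis; the input x is shared (in_axes=None).
# Note: tree_map over two empty non-trainable lists yields an empty list,
# which matches the "expected 0 variables" structure from the error above.
vmapped = jax.vmap(model1.stateless_call, in_axes=(0, 0, None))
outputs, _ = vmapped(stacked_trainables, stacked_non_trainables, x)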

asmith26 (Contributor, Author) commented

Thanks for everyone's help. I think I'll close this and try to create more specific/cleaner examples as/if needed. Thanks again!
