You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Thank you for the impressive work in releasing the minimal_training_example code. I an trying to save the model from minimal_training_example so it can be used in minimal_inference after fine-tuning it with my own dataset.
I encountered an issue when trying to save the model in minimal_training_example with this code.
import jax
import jax.numpy as jnp
import tensorflow as tf
from flax.training import train_state
from jax.experimental import jax2tf
import flax.linen as nn
import optax
# Assuming RT1X is your Flax model and state is your TrainState object
class TFModel(tf.Module):
def __init__(self, state, model):
super().__init__()
self.state = state
self.model = model
# Convert JAX parameters to TensorFlow variables
self.params_vars = tf.nest.map_structure(tf.Variable, self.state.params)
# Keep the wrapped state as a flat list (needed in TensorFlow fine-tuning).
self.vars = tf.nest.flatten(self.params_vars)
# Convert the predict function
self.predict_fn = jax2tf.convert(
self.model.apply,
polymorphic_shapes=[
"dict(params=...)", # the variables dictionary
"dict(image=(batch_size, seq_len, height, width, channels), natural_language_embedding=(batch_size, seq_len, embedding_dim))", # observation
"dict(world_vector=(batch_size, seq_len, vector_dim), rotation_delta=(batch_size, seq_len, delta_dim), gripper_closedness_action=(batch_size, seq_len, action_dim), base_displacement_vertical_rotation=(batch_size, seq_len, rotation_dim), base_displacement_vector=(batch_size, seq_len, displacement_dim), terminate_episode=(batch_size, seq_len, terminate_dim))", # action
"()", # train (boolean)
"()", # mutable (list or other structure)
"dict(params=(), dropout=(), random=())" # rngs dictionary
]
)
@tf.function(input_signature=[
tf.TensorSpec(shape=(None, 15, 300, 300, 3), dtype=tf.float32), # image
tf.TensorSpec(shape=(None, 15, 512), dtype=tf.float32), # natural_language_embedding
tf.TensorSpec(shape=(None, 15, 3), dtype=tf.float32), # world_vector
tf.TensorSpec(shape=(None, 15, 3), dtype=tf.float32), # rotation_delta
tf.TensorSpec(shape=(None, 15, 1), dtype=tf.float32), # gripper_closedness_action
tf.TensorSpec(shape=(None, 15, 1), dtype=tf.float32), # base_displacement_vertical_rotation
tf.TensorSpec(shape=(None, 15, 2), dtype=tf.float32), # base_displacement_vector
tf.TensorSpec(shape=(None, 15, 3), dtype=tf.int32), # terminate_episode
tf.TensorSpec(shape=(2,), dtype=tf.uint32), # params_rng
tf.TensorSpec(shape=(2,), dtype=tf.uint32), # dropout_rng
tf.TensorSpec(shape=(2,), dtype=tf.uint32) # random_rng
])
def predict(self, image, natural_language_embedding, world_vector, rotation_delta, gripper_closedness_action, base_displacement_vertical_rotation, base_displacement_vector, terminate_episode, params_rng, dropout_rng, random_rng):
obs = {
"image": image,
"natural_language_embedding": natural_language_embedding,
}
act = {
"world_vector": world_vector,
"rotation_delta": rotation_delta,
"gripper_closedness_action": gripper_closedness_action,
"base_displacement_vertical_rotation": base_displacement_vertical_rotation,
"base_displacement_vector": base_displacement_vector,
"terminate_episode": terminate_episode,
}
params = tf.nest.pack_sequence_as(self.state.params, self.vars)
rngs = {
"params": params_rng,
"dropout": dropout_rng,
"random": random_rng
}
return self.predict_fn(
{'params': params},
obs,
act,
train=False,
mutable=[],
rngs=rngs
)
# Initialize the RT1 model and create a TrainState
rt1x_model = RT1(
num_image_tokens=NUM_IMAGE_TOKENS,
num_action_tokens=NUM_ACTION_TOKENS,
layer_size=LAYER_SIZE,
vocab_size=VOCAB_SIZE,
use_token_learner=True,
world_vector_range=(-2.0, 2.0)
)
# Dummy optimizer state for demonstration purposes
dummy_params = rt1x_model.init(jax.random.PRNGKey(0), {"image": jnp.ones((1, 15, 300, 300, 3)), "natural_language_embedding": jnp.ones((1, 15, 512))}, {"world_vector": jnp.ones((1, 15, 3)), "rotation_delta": jnp.ones((1, 15, 3)), "gripper_closedness_action": jnp.ones((1, 15, 1)), "base_displacement_vertical_rotation": jnp.ones((1, 15, 1)), "base_displacement_vector": jnp.ones((1, 15, 2)), "terminate_episode": jnp.ones((1, 15, 3), dtype=jnp.int32)}, train=False)["params"]
state = train_state.TrainState.create(
apply_fn=rt1x_model.apply,
params=dummy_params,
tx=optax.adam(1e-3)
)
# Instantiate the TensorFlow Model
tf_model_instance = TFModel(state, rt1x_model)
# Save the TensorFlow model
tf.saved_model.save(tf_model_instance, "/home/user/Downloads/openxembodiment/saved_checkpoint")
however, I encountered ValueError.
ValueError: pytree structure error: different lengths of tuple at key path
export.symbolic_args_specs shapes_specs
At that key path, the prefix pytree export.symbolic_args_specs shapes_specs has a subtree of type tuple of length 6, but the full pytree has a subtree of the same type but of length 3.
Can anyone help me with this?
The text was updated successfully, but these errors were encountered:
Rachealthong
changed the title
ValueError vwhen saving the model
How to save the model from minimal_training_example to be used in mininam_inference after fine-tuning it?
Aug 5, 2024
Rachealthong
changed the title
How to save the model from minimal_training_example to be used in mininam_inference after fine-tuning it?
How to save the model from minimal_training_example to be used in minimal_inference after fine-tuning it?
Aug 5, 2024
Thank you for the impressive work in releasing the minimal_training_example code. I an trying to save the model from minimal_training_example so it can be used in minimal_inference after fine-tuning it with my own dataset.
I encountered an issue when trying to save the model in minimal_training_example with this code.
however, I encountered ValueError.
Can anyone help me with this?
The text was updated successfully, but these errors were encountered: