
How to save the model from minimal_training_example to be used in minimal_inference after fine-tuning it? #89

Rachealthong opened this issue Aug 4, 2024 · 0 comments

Rachealthong commented Aug 4, 2024

Thank you for the impressive work on releasing the minimal_training_example code. I am trying to save the model from minimal_training_example so it can be used in minimal_inference after fine-tuning it on my own dataset.

I encountered an issue when trying to save the model in minimal_training_example with the following code.


import jax
import jax.numpy as jnp
import tensorflow as tf
from flax.training import train_state
from jax.experimental import jax2tf
import flax.linen as nn
import optax

# Assuming rt1x_model is your Flax RT1 model (instantiated below) and state is your TrainState object

class TFModel(tf.Module):
    def __init__(self, state, model):
        super().__init__()
        self.state = state
        self.model = model

        # Convert JAX parameters to TensorFlow variables
        self.params_vars = tf.nest.map_structure(tf.Variable, self.state.params)
        
        # Keep the wrapped parameters as a flat list (needed for fine-tuning in TensorFlow).
        self.vars = tf.nest.flatten(self.params_vars)
        
        # Convert the predict function
        self.predict_fn = jax2tf.convert(
            self.model.apply,
            polymorphic_shapes=[
                "dict(params=...)",  # the variables dictionary
                "dict(image=(batch_size, seq_len, height, width, channels), natural_language_embedding=(batch_size, seq_len, embedding_dim))",  # observation
                "dict(world_vector=(batch_size, seq_len, vector_dim), rotation_delta=(batch_size, seq_len, delta_dim), gripper_closedness_action=(batch_size, seq_len, action_dim), base_displacement_vertical_rotation=(batch_size, seq_len, rotation_dim), base_displacement_vector=(batch_size, seq_len, displacement_dim), terminate_episode=(batch_size, seq_len, terminate_dim))",  # action
                "()",  # train (boolean)
                "()",  # mutable (list or other structure)
                "dict(params=(), dropout=(), random=())"  # rngs dictionary
            ]
        )

    @tf.function(input_signature=[
        tf.TensorSpec(shape=(None, 15, 300, 300, 3), dtype=tf.float32),  # image
        tf.TensorSpec(shape=(None, 15, 512), dtype=tf.float32),  # natural_language_embedding
        tf.TensorSpec(shape=(None, 15, 3), dtype=tf.float32),  # world_vector
        tf.TensorSpec(shape=(None, 15, 3), dtype=tf.float32),  # rotation_delta
        tf.TensorSpec(shape=(None, 15, 1), dtype=tf.float32),  # gripper_closedness_action
        tf.TensorSpec(shape=(None, 15, 1), dtype=tf.float32),  # base_displacement_vertical_rotation
        tf.TensorSpec(shape=(None, 15, 2), dtype=tf.float32),  # base_displacement_vector
        tf.TensorSpec(shape=(None, 15, 3), dtype=tf.int32),  # terminate_episode
        tf.TensorSpec(shape=(2,), dtype=tf.uint32),  # params_rng
        tf.TensorSpec(shape=(2,), dtype=tf.uint32),  # dropout_rng
        tf.TensorSpec(shape=(2,), dtype=tf.uint32)   # random_rng
    ])
    def predict(self, image, natural_language_embedding, world_vector, rotation_delta, gripper_closedness_action, base_displacement_vertical_rotation, base_displacement_vector, terminate_episode, params_rng, dropout_rng, random_rng):
        obs = {
            "image": image,
            "natural_language_embedding": natural_language_embedding,
        }
        act = {
            "world_vector": world_vector,
            "rotation_delta": rotation_delta,
            "gripper_closedness_action": gripper_closedness_action,
            "base_displacement_vertical_rotation": base_displacement_vertical_rotation,
            "base_displacement_vector": base_displacement_vector,
            "terminate_episode": terminate_episode,
        }
        params = tf.nest.pack_sequence_as(self.state.params, self.vars)
        rngs = {
            "params": params_rng,
            "dropout": dropout_rng,
            "random": random_rng
        }
        return self.predict_fn(
            {'params': params},
            obs,
            act,
            train=False,
            mutable=[],
            rngs=rngs
        )

# Initialize the RT1 model and create a TrainState. RT1 and the constants
# (NUM_IMAGE_TOKENS, NUM_ACTION_TOKENS, LAYER_SIZE, VOCAB_SIZE) come from
# minimal_training_example.
rt1x_model = RT1(
    num_image_tokens=NUM_IMAGE_TOKENS,
    num_action_tokens=NUM_ACTION_TOKENS,
    layer_size=LAYER_SIZE,
    vocab_size=VOCAB_SIZE,
    use_token_learner=True,
    world_vector_range=(-2.0, 2.0)
)

# Initialize the parameters with dummy inputs (for demonstration purposes)
dummy_params = rt1x_model.init(
    jax.random.PRNGKey(0),
    {
        "image": jnp.ones((1, 15, 300, 300, 3)),
        "natural_language_embedding": jnp.ones((1, 15, 512)),
    },
    {
        "world_vector": jnp.ones((1, 15, 3)),
        "rotation_delta": jnp.ones((1, 15, 3)),
        "gripper_closedness_action": jnp.ones((1, 15, 1)),
        "base_displacement_vertical_rotation": jnp.ones((1, 15, 1)),
        "base_displacement_vector": jnp.ones((1, 15, 2)),
        "terminate_episode": jnp.ones((1, 15, 3), dtype=jnp.int32),
    },
    train=False,
)["params"]
state = train_state.TrainState.create(
    apply_fn=rt1x_model.apply,
    params=dummy_params,
    tx=optax.adam(1e-3)
)

# Instantiate the TensorFlow Model
tf_model_instance = TFModel(state, rt1x_model)

# Save the TensorFlow model
tf.saved_model.save(tf_model_instance, "/home/user/Downloads/openxembodiment/saved_checkpoint")

However, I encountered a ValueError:

ValueError: pytree structure error: different lengths of tuple at key path
        export.symbolic_args_specs shapes_specs
    At that key path, the prefix pytree export.symbolic_args_specs shapes_specs has a subtree of type tuple of length 6, but the full pytree has a subtree of the same type but of length 3.
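
My best guess is that the polymorphic_shapes list does not line up with the arguments jax2tf actually sees: I pass six shape specs, but only the three positional arguments of model.apply (the variables dict, obs, and act) are paired against them, since train, mutable, and rngs go in as keyword arguments. That would explain the "length 6 vs. length 3" mismatch. Also, as far as I understand, each entry must be a shape string such as "(b, ...)" or None, not a description of the dict structure. A wrapper along these lines might make the structures match; this is an untested sketch, and the _apply name and the "(b, ...)" specs are just my assumption:

# Untested sketch: expose exactly the positional arguments that the shape
# specs describe, and close over train/mutable instead of passing them
# through the converted function.
def _apply(params, obs, act, rngs):
    return self.model.apply(
        {"params": params}, obs, act, train=False, rngs=rngs
    )

self.predict_fn = jax2tf.convert(
    _apply,
    polymorphic_shapes=[
        None,        # params: concrete shapes, nothing polymorphic
        "(b, ...)",  # obs: polymorphic batch dimension on every leaf
        "(b, ...)",  # act: polymorphic batch dimension on every leaf
        None,        # rngs: fixed-shape PRNG keys
    ],
)

predict would then call self.predict_fn(params, obs, act, rngs) positionally, so the number of shape specs matches the number of arguments. On the minimal_inference side I was planning to load the result with tf.saved_model.load and call predict, if that is the intended route.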

Can anyone help me with this?
