Thank you for the impressive work in releasing the minimal_training_example code. I an trying to save the model from minimal_training_example so it can be used in minimal_inference after fine-tuning it with my own dataset.
I encountered an issue when trying to save the model in minimal_training_example with this code.
import jax
import jax.numpy as jnp
import tensorflow as tf
from flax.training import train_state
from jax.experimental import jax2tf
import flax.linen as nn
import optax
# Assuming RT1X is your Flax model and state is your TrainState object
class TFModel(tf.Module):
def __init__(self, state, model):
super().__init__()
self.state = state
self.model = model
# Convert JAX parameters to TensorFlow variables
self.params_vars = tf.nest.map_structure(tf.Variable, self.state.params)
# Keep the wrapped state as a flat list (needed in TensorFlow fine-tuning).
self.vars = tf.nest.flatten(self.params_vars)
# Convert the predict function
self.predict_fn = jax2tf.convert(
self.model.apply,
polymorphic_shapes=[
"dict(params=...)", # the variables dictionary
"dict(image=(batch_size, seq_len, height, width, channels), natural_language_embedding=(batch_size, seq_len, embedding_dim))", # observation
"dict(world_vector=(batch_size, seq_len, vector_dim), rotation_delta=(batch_size, seq_len, delta_dim), gripper_closedness_action=(batch_size, seq_len, action_dim), base_displacement_vertical_rotation=(batch_size, seq_len, rotation_dim), base_displacement_vector=(batch_size, seq_len, displacement_dim), terminate_episode=(batch_size, seq_len, terminate_dim))", # action
"()", # train (boolean)
"()", # mutable (list or other structure)
"dict(params=(), dropout=(), random=())" # rngs dictionary
]
)
@tf.function(input_signature=[
tf.TensorSpec(shape=(None, 15, 300, 300, 3), dtype=tf.float32), # image
tf.TensorSpec(shape=(None, 15, 512), dtype=tf.float32), # natural_language_embedding
tf.TensorSpec(shape=(None, 15, 3), dtype=tf.float32), # world_vector
tf.TensorSpec(shape=(None, 15, 3), dtype=tf.float32), # rotation_delta
tf.TensorSpec(shape=(None, 15, 1), dtype=tf.float32), # gripper_closedness_action
tf.TensorSpec(shape=(None, 15, 1), dtype=tf.float32), # base_displacement_vertical_rotation
tf.TensorSpec(shape=(None, 15, 2), dtype=tf.float32), # base_displacement_vector
tf.TensorSpec(shape=(None, 15, 3), dtype=tf.int32), # terminate_episode
tf.TensorSpec(shape=(2,), dtype=tf.uint32), # params_rng
tf.TensorSpec(shape=(2,), dtype=tf.uint32), # dropout_rng
tf.TensorSpec(shape=(2,), dtype=tf.uint32) # random_rng
])
def predict(self, image, natural_language_embedding, world_vector, rotation_delta, gripper_closedness_action, base_displacement_vertical_rotation, base_displacement_vector, terminate_episode, params_rng, dropout_rng, random_rng):
obs = {
"image": image,
"natural_language_embedding": natural_language_embedding,
}
act = {
"world_vector": world_vector,
"rotation_delta": rotation_delta,
"gripper_closedness_action": gripper_closedness_action,
"base_displacement_vertical_rotation": base_displacement_vertical_rotation,
"base_displacement_vector": base_displacement_vector,
"terminate_episode": terminate_episode,
}
params = tf.nest.pack_sequence_as(self.state.params, self.vars)
rngs = {
"params": params_rng,
"dropout": dropout_rng,
"random": random_rng
}
return self.predict_fn(
{'params': params},
obs,
act,
train=False,
mutable=[],
rngs=rngs
)
# Initialize the RT1 model and create a TrainState
rt1x_model = RT1(
num_image_tokens=NUM_IMAGE_TOKENS,
num_action_tokens=NUM_ACTION_TOKENS,
layer_size=LAYER_SIZE,
vocab_size=VOCAB_SIZE,
use_token_learner=True,
world_vector_range=(-2.0, 2.0)
)
# Dummy optimizer state for demonstration purposes
dummy_params = rt1x_model.init(jax.random.PRNGKey(0), {"image": jnp.ones((1, 15, 300, 300, 3)), "natural_language_embedding": jnp.ones((1, 15, 512))}, {"world_vector": jnp.ones((1, 15, 3)), "rotation_delta": jnp.ones((1, 15, 3)), "gripper_closedness_action": jnp.ones((1, 15, 1)), "base_displacement_vertical_rotation": jnp.ones((1, 15, 1)), "base_displacement_vector": jnp.ones((1, 15, 2)), "terminate_episode": jnp.ones((1, 15, 3), dtype=jnp.int32)}, train=False)["params"]
state = train_state.TrainState.create(
apply_fn=rt1x_model.apply,
params=dummy_params,
tx=optax.adam(1e-3)
)
# Instantiate the TensorFlow Model
tf_model_instance = TFModel(state, rt1x_model)
# Save the TensorFlow model
tf.saved_model.save(tf_model_instance, "/home/user/Downloads/openxembodiment/saved_checkpoint")
however, I encountered ValueError.
ValueError: pytree structure error: different lengths of tuple at key path
export.symbolic_args_specs shapes_specs
At that key path, the prefix pytree export.symbolic_args_specs shapes_specs has a subtree of type tuple of length 6, but the full pytree has a subtree of the same type but of length 3.
Thank you for the impressive work in releasing the minimal_training_example code. I an trying to save the model from minimal_training_example so it can be used in minimal_inference after fine-tuning it with my own dataset.
I encountered an issue when trying to save the model in minimal_training_example with this code.
however, I encountered ValueError.
Can anyone help me with this?