Skip to content

Commit

Permalink
add bathed short simulation
Browse files Browse the repository at this point in the history
  • Loading branch information
simeon-ned committed Sep 26, 2024
1 parent c8be148 commit d6c3714
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 39 deletions.
8 changes: 4 additions & 4 deletions examples/opti_loss.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@
}
],
"source": [
"theta2logchol(true_parameters)\n"
"theta2logchol(true_parameters)"
]
},
{
Expand Down Expand Up @@ -413,7 +413,7 @@
"plt.title(\"Losses over time\")\n",
"plt.fill_between(jnp.arange(horizon, len(log)), results[1].min(axis=0), results[1].max(axis=0), alpha=0.5)\n",
"plt.plot(jnp.arange(horizon, len(log)), results[1].mean(axis=0))\n",
"plt.show()\n"
"plt.show()"
]
},
{
Expand Down Expand Up @@ -469,7 +469,7 @@
"for i in range(parameters_history.shape[0]):\n",
" history.append(vmap_logchol2theta(parameters_history[i]))\n",
"\n",
"history = jnp.array(history)\n"
"history = jnp.array(history)"
]
},
{
Expand Down Expand Up @@ -523,7 +523,7 @@
" # plt.fill_between(jnp.arange(horizon, len(log)), min_parameter, max_parameter, alpha=0.5)\n",
" plt.plot(jnp.arange(horizon, len(log)), current_parameter.mean(axis=0))\n",
" plt.plot(jnp.arange(horizon, len(log)), jnp.repeat(true_parameters[i], len(log) - horizon), label=\"True\")\n",
" plt.legend()\n"
" plt.legend()"
]
},
{
Expand Down
145 changes: 112 additions & 33 deletions examples/pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,6 @@
import optax


# Initialize random key
key = jax.random.PRNGKey(0)

# Load the model
MJCF_PATH = "../data/models/pendulum/pendulum.xml"
model = mujoco.MjModel.from_xml_path(MJCF_PATH)
data = mujoco.MjData(model)
model.opt.integrator = 1

# Setting up constraint solver to ensure differentiability and faster simulations
model.opt.solver = 2 # 2 corresponds to Newton solver
model.opt.iterations = 2
model.opt.ls_iterations = 10

mjx_model = mjx.put_model(model)

# Load test data
TEST_DATA_PATH = "../data/trajectories/pendulum/free_fall_2.csv"
data_array = np.genfromtxt(TEST_DATA_PATH, delimiter=",", skip_header=100, skip_footer=2500)
timespan = data_array[:, 0] - data_array[0, 0]
sampling = np.mean(np.diff(timespan))
angle = data_array[:, 1]
velocity = data_array[:, 2]
control = data_array[:, 3]

model.opt.timestep = sampling


@jax.jit
def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model:
"""Map new parameters to the model."""
Expand Down Expand Up @@ -76,19 +48,119 @@ def step_fn(state, control):
return states


# Initialize random key
key = jax.random.PRNGKey(0)

# Load the model
MJCF_PATH = "../data/models/pendulum/pendulum.xml"
model = mujoco.MjModel.from_xml_path(MJCF_PATH)
data = mujoco.MjData(model)
model.opt.integrator = 1

# Setting up constraint solver to ensure differentiability and faster simulations
model.opt.solver = 2 # 2 corresponds to Newton solver
model.opt.iterations = 2
model.opt.ls_iterations = 10

mjx_model = mjx.put_model(model)

# Load test data
TEST_DATA_PATH = "../data/trajectories/pendulum/free_fall_2.csv"
data_array = np.genfromtxt(TEST_DATA_PATH, delimiter=",", skip_header=100, skip_footer=2500)
timespan = data_array[:, 0] - data_array[0, 0]
sampling = np.mean(np.diff(timespan))
angle = data_array[:, 1]
velocity = data_array[:, 2]
control = data_array[:, 3]

model.opt.timestep = sampling

HORIZON = 100
N_INTERVALS = len(timespan) // HORIZON - 1
timespan = timespan[: N_INTERVALS * HORIZON]
angle = angle[: N_INTERVALS * HORIZON]
velocity = velocity[: N_INTERVALS * HORIZON]
control = control[: N_INTERVALS * HORIZON]

# Prepare data for simulation and optimization
initial_state = jnp.array([angle[0], velocity[0]])
true_trajectory = jnp.column_stack((angle, velocity))
control_inputs = jnp.array(control)

interval_true_trajectory = true_trajectory[::HORIZON]
interval_controls = control_inputs.reshape(N_INTERVALS, HORIZON)

# Get default parameters from the model
default_parameters = jnp.concatenate(
[theta2logchol(get_dynamic_parameters(mjx_model, 1)), mjx_model.dof_damping, mjx_model.dof_frictionloss]
)

# Simulation with XML parameters
xml_trajectory = rollout_trajectory(default_parameters, mjx_model, initial_state, control_inputs)
# //////////////////////////////////////
# SIMULATION BATCHES: THIS WILL BE HANDY IN OPTIMIZATION

# Vectorize over both initial states and control inputs
batched_rollout = jax.jit(jax.vmap(rollout_trajectory, in_axes=(None, None, 0, 0)))

# Create a batch of initial states
key, subkey = jax.random.split(key)
batch_initial_states = jax.random.uniform(subkey, (N_INTERVALS, 2), minval=-0.1, maxval=0.1) + initial_state
# Create a batch of control input sequences
key, subkey = jax.random.split(key)
batch_control_inputs = jax.random.normal(subkey, (N_INTERVALS, HORIZON)) * 0.1 # + control_inputs
# Run warm up for batched rollout
t1 = perf_counter()
batched_trajectories = batched_rollout(default_parameters, mjx_model, batch_initial_states, batch_control_inputs)
t2 = perf_counter()
print(f"Batch simulation time: {t2 - t1} seconds")

# Run batched rollout on shor horizon data from pendulum
interval_initial_states = true_trajectory[::HORIZON]
interval_controls = control_inputs.reshape(N_INTERVALS, HORIZON)
t1 = perf_counter()
batched_states_trajectories = batched_rollout(
default_parameters * 0.7, mjx_model, interval_initial_states, interval_controls
)
t2 = perf_counter()
print(f"Batch simulation time: {t2 - t1} seconds")

batched_states_trajectories = np.array(batched_states_trajectories).reshape(N_INTERVALS * HORIZON, 2)

# Plotting simulation results for batсhed state trajectories
plt.figure(figsize=(10, 5))

plt.subplot(2, 2, 1)
plt.plot(timespan, angle, label="Actual Angle", color="black", linestyle="dashed", linewidth=2)
plt.plot(timespan, batched_states_trajectories[:, 0], alpha=0.5, color="blue", label="Simulated Angle")
plt.ylabel("Angle (rad)")
plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
plt.legend()
plt.title("Pendulum Dynamics - Bathed State Trajectories")

plt.subplot(2, 2, 3)
plt.plot(timespan, velocity, label="Actual Velocity", color="black", linestyle="dashed", linewidth=2)
plt.plot(timespan, batched_states_trajectories[:, 1], alpha=0.5, color="blue", label="Simulated Velocity")
plt.xlabel("Time (s)")
plt.ylabel("Velocity (rad/s)")
plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
plt.legend()

# Add phase portrait
plt.subplot(1, 2, 2)
plt.plot(angle, velocity, label="Actual", color="black", linestyle="dashed", linewidth=2)
plt.plot(
batched_states_trajectories[:, 0], batched_states_trajectories[:, 1], alpha=0.5, color="blue", label="Simulated"
)
plt.xlabel("Angle (rad)")
plt.ylabel("Angular Velocity (rad/s)")
plt.title("Phase Portrait")
plt.grid(color="black", linestyle="--", linewidth=1.0, alpha=0.4)
plt.legend()

plt.tight_layout()
plt.show()

# //////////////////////////////////////////////////
# PARAMETRIC BATCHES
# Create a batch of 200 randomized parameters
num_batches = 200
key, subkey1, subkey2, subkey3 = jax.random.split(key, 4)
Expand All @@ -115,13 +187,16 @@ def step_fn(state, control):
batch_parameters = batch_parameters.at[:, -2].set(randomized_damping)
batch_parameters = batch_parameters.at[:, -1].set(randomized_dry_friction)


# Define a batched version of rollout_trajectory using vmap
batched_rollout = jax.jit(jax.vmap(rollout_trajectory, in_axes=(0, None, None, None)))
batched_parameters_rollout = jax.jit(jax.vmap(rollout_trajectory, in_axes=(0, None, None, None)))

# Simulation with XML parameters
xml_trajectory = rollout_trajectory(default_parameters, mjx_model, initial_state, control_inputs)

# Simulate trajectories with randomized parameters using vmap
t1 = perf_counter()
randomized_trajectories = batched_rollout(batch_parameters, mjx_model, initial_state, control_inputs)
randomized_trajectories = batched_parameters_rollout(batch_parameters, mjx_model, initial_state, control_inputs)
t2 = perf_counter()

print(f"Simulation with randomized parameters using vmap took {t2-t1:.2f} seconds.")
Expand Down Expand Up @@ -187,10 +262,14 @@ def step_fn(state, control):

# Simulate trajectories with randomized parameters using vmap
t1 = perf_counter()
randomized_trajectories = batched_rollout(batch_parameters, mjx_model, initial_state, control_inputs)
randomized_trajectories = batched_parameters_rollout(batch_parameters, mjx_model, initial_state, control_inputs)
t2 = perf_counter()
print(f"Simulation with randomized parameters using vmap took {t2-t1:.2f} seconds.")


# TODO: OPTIMIZATION


# Optimization

# # Error function
Expand Down
10 changes: 8 additions & 2 deletions mujoco_sysid/mjx/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,13 @@ def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model:
return model


# @jax.jit(static_argnames=['parameters_map'])
def parametric_step(
parameters: jnp.ndarray, parameters_map: Callable, model: mjx.Model, state: jnp.ndarray, control: jnp.ndarray
parameters: jnp.ndarray,
model: mjx.Model,
state: jnp.ndarray,
control: jnp.ndarray,
parameters_map: Callable,
) -> jnp.ndarray:
"""
Perform a step with new parameter mapping.
Expand All @@ -150,12 +155,13 @@ def parametric_step(
return jnp.concatenate([data.qpos, data.qvel])


# @jax.jit(static_argnames=['parameters_map'])
def rollout_trajectory(
parameters: jnp.ndarray,
parameters_map: Callable,
model: mjx.Model,
initial_state: jnp.ndarray,
control_inputs: jnp.ndarray,
parameters_map: Callable,
) -> jnp.ndarray:
"""
Rollout a trajectory given parameters, initial state, and control inputs.
Expand Down

0 comments on commit d6c3714

Please sign in to comment.