diff --git a/examples/pendulum/01_learning.py b/examples/pendulum/01_learning.py index 9114c92..741abfb 100644 --- a/examples/pendulum/01_learning.py +++ b/examples/pendulum/01_learning.py @@ -60,7 +60,7 @@ def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model: # Load test data LEARNING_DATA_PATH = "data/free_fall_2.csv" -data_array = np.genfromtxt(LEARNING_DATA_PATH, delimiter=",", skip_header=100, skip_footer=1000) +data_array = np.genfromtxt(LEARNING_DATA_PATH, delimiter=",", skip_header=10, skip_footer=2500) timespan = data_array[:, 0] - data_array[0, 0] sampling = np.mean(np.diff(timespan)) angle = data_array[:, 1] @@ -90,21 +90,6 @@ def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model: ) -@jax.jit -def rollout_errors(parameters, states, controls): - # TODO: Use the full trajecttory in shouting not only las point - interval_initial_states = states[::HORIZON] - interval_terminal_states = states[HORIZON + 1 :][::HORIZON] - interval_controls = jnp.reshape(controls, (N_INTERVALS, HORIZON)) - batched_rollout = jax.vmap(rollout_trajectory, in_axes=(None, None, 0, 0)) - batched_states_trajectories = batched_rollout(parameters, mjx_model, interval_initial_states, interval_controls) - predicted_terminal_points = batched_states_trajectories[:, -1, :] - loss = jnp.mean( - optax.l2_loss(predicted_terminal_points[:-1], interval_terminal_states) - ) # + 0.05*jnp.mean(optax.huber_loss(parameters, jnp.zeros_like(parameters))) - return loss - - @jax.jit def rollout_errors(parameters, states, controls): # TODO: Use the full trajecttory in shouting not only las point diff --git a/examples/pendulum/02_simulation_error.py b/examples/pendulum/02_simulation_error.py index 500eaa6..f97978f 100644 --- a/examples/pendulum/02_simulation_error.py +++ b/examples/pendulum/02_simulation_error.py @@ -53,8 +53,8 @@ def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model: mjx_model = mjx.put_model(model) # Load test data -TEST_DATA_PATH = "data/free_fall_2.csv" -data_array = np.genfromtxt(TEST_DATA_PATH, delimiter=",", skip_header=100, skip_footer=2500) +TEST_DATA_PATH = "data/harmonic_input_2.csv" +data_array = np.genfromtxt(TEST_DATA_PATH, delimiter=",", skip_header=10, skip_footer=2500) timespan = data_array[:, 0] - data_array[0, 0] sampling = np.mean(np.diff(timespan)) angle = data_array[:, 1] @@ -63,7 +63,7 @@ def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model: model.opt.timestep = sampling -HORIZON = 100 +HORIZON = 50 N_INTERVALS = len(timespan) // HORIZON - 1 timespan = timespan[: N_INTERVALS * HORIZON] angle = angle[: N_INTERVALS * HORIZON] diff --git a/examples/pendulum/models/pendulum.xml b/examples/pendulum/models/pendulum.xml index 8c141b6..c210f34 100644 --- a/examples/pendulum/models/pendulum.xml +++ b/examples/pendulum/models/pendulum.xml @@ -6,9 +6,9 @@ - + - + diff --git a/examples/pendulum/plots/learning_results.png b/examples/pendulum/plots/learning_results.png index 7959aa4..c1dd909 100644 Binary files a/examples/pendulum/plots/learning_results.png and b/examples/pendulum/plots/learning_results.png differ