Skip to content

Commit

Permalink
some fix in pendulum xml
Browse files Browse the repository at this point in the history
  • Loading branch information
simeon-ned committed Oct 1, 2024
1 parent 4576add commit a96a0b6
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 21 deletions.
17 changes: 1 addition & 16 deletions examples/pendulum/01_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model:

# Load test data
LEARNING_DATA_PATH = "data/free_fall_2.csv"
data_array = np.genfromtxt(LEARNING_DATA_PATH, delimiter=",", skip_header=100, skip_footer=1000)
data_array = np.genfromtxt(LEARNING_DATA_PATH, delimiter=",", skip_header=10, skip_footer=2500)
timespan = data_array[:, 0] - data_array[0, 0]
sampling = np.mean(np.diff(timespan))
angle = data_array[:, 1]
Expand Down Expand Up @@ -90,21 +90,6 @@ def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model:
)


@jax.jit
def rollout_errors(parameters, states, controls):
# TODO: Use the full trajecttory in shouting not only las point
interval_initial_states = states[::HORIZON]
interval_terminal_states = states[HORIZON + 1 :][::HORIZON]
interval_controls = jnp.reshape(controls, (N_INTERVALS, HORIZON))
batched_rollout = jax.vmap(rollout_trajectory, in_axes=(None, None, 0, 0))
batched_states_trajectories = batched_rollout(parameters, mjx_model, interval_initial_states, interval_controls)
predicted_terminal_points = batched_states_trajectories[:, -1, :]
loss = jnp.mean(
optax.l2_loss(predicted_terminal_points[:-1], interval_terminal_states)
) # + 0.05*jnp.mean(optax.huber_loss(parameters, jnp.zeros_like(parameters)))
return loss


@jax.jit
def rollout_errors(parameters, states, controls):
# TODO: Use the full trajecttory in shouting not only las point
Expand Down
6 changes: 3 additions & 3 deletions examples/pendulum/02_simulation_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model:
mjx_model = mjx.put_model(model)

# Load test data
TEST_DATA_PATH = "data/free_fall_2.csv"
data_array = np.genfromtxt(TEST_DATA_PATH, delimiter=",", skip_header=100, skip_footer=2500)
TEST_DATA_PATH = "data/harmonic_input_2.csv"
data_array = np.genfromtxt(TEST_DATA_PATH, delimiter=",", skip_header=10, skip_footer=2500)
timespan = data_array[:, 0] - data_array[0, 0]
sampling = np.mean(np.diff(timespan))
angle = data_array[:, 1]
Expand All @@ -63,7 +63,7 @@ def parameters_map(parameters: jnp.ndarray, model: mjx.Model) -> mjx.Model:

model.opt.timestep = sampling

HORIZON = 100
HORIZON = 50
N_INTERVALS = len(timespan) // HORIZON - 1
timespan = timespan[: N_INTERVALS * HORIZON]
angle = angle[: N_INTERVALS * HORIZON]
Expand Down
4 changes: 2 additions & 2 deletions examples/pendulum/models/pendulum.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
<worldbody>
<geom name="floor" size="1 1 0.1" type="plane" rgba="0.9 0.9 0.9 1"/>
<body name="pendulum_body" pos="0 0 0.5">
<joint name="hinge" pos="0 0 0" axis="0 1 0" damping="0.0005" frictionloss="0.000001"/>
<joint name="hinge" pos="0 0 0" axis="0 1 0" damping="0.001" frictionloss="0.0001"/>
<geom name="mass" size="0.01" rgba="1 0 0 1"/>
<inertial pos="0 0 -0.1" mass="0.1" diaginertia="0.0001 0.0001 0.0001"/>
<inertial pos="0 0 -0.1" mass="0.1" diaginertia="0.000001 0.000001 0.000001"/>
</body>
</worldbody>

Expand Down
Binary file modified examples/pendulum/plots/learning_results.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit a96a0b6

Please sign in to comment.