Skip to content

Commit

Permalink
Updated Unitree has fallen method.
Browse files Browse the repository at this point in the history
- has fallen method acts on the unmodified state.
- the dataset is not checked anymore for terminal states in the unitree.
  • Loading branch information
robfiras committed Oct 17, 2023
1 parent 376f90f commit f533020
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions loco_mujoco/environments/quadrupeds/unitreeA1.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,6 @@ def create_dataset(self, ignore_keys=None):
dataset = self.trajectories.create_dataset(ignore_keys=ignore_keys,
state_callback=self._modify_observation_callback,
state_callback_params=state_callback_params)
# check that all state in the dataset satisfy the has fallen method.
for state in dataset["states"]:
has_fallen, msg = self._has_fallen(state, return_err_msg=True)
if has_fallen:
err_msg = "Some of the states in the created dataset are terminal states. " \
"This should not happen.\n\nViolations:\n"
err_msg += msg
raise ValueError(err_msg)
else:
raise ValueError("No trajectory was passed to the environment. "
"To create a dataset pass a trajectory first.")
Expand Down Expand Up @@ -336,13 +328,21 @@ def _has_fallen(self, obs, return_err_msg=False):
"""

trunk_euler = self._get_from_obs(obs, ["q_trunk_list", "q_trunk_tilt"])
trunk_height = self._get_from_obs(obs, ["q_trunk_tz"])

trunk_list_condition = (trunk_euler[0] < -0.2793) or (trunk_euler[0] > 0.2793)
trunk_tilt_condition = (trunk_euler[1] < -0.192) or (trunk_euler[1] > 0.192)
trunk_height_condition = trunk_height[0] < -.24
trunk_condition = (trunk_height_condition)
trunk_condition = (trunk_list_condition or trunk_tilt_condition or trunk_height_condition)

if return_err_msg:
error_msg = ""
if trunk_height_condition:
if trunk_list_condition:
error_msg += "trunk_list_condition violated.\n"
elif trunk_tilt_condition:
error_msg += "trunk_tilt_condition violated.\n"
elif trunk_height_condition:
error_msg += "trunk_height_condition violated. %f \n" % trunk_height

return trunk_condition, error_msg
Expand Down

0 comments on commit f533020

Please sign in to comment.