Skip to content

Commit

Permalink
MuJoco-v5 environments replace ndarray.flat.copy() with `ndarray.fl…
Browse files Browse the repository at this point in the history
…atten()` (#815)
  • Loading branch information
Kallinteris-Andreas authored Dec 5, 2023
1 parent 2790321 commit b879469
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 28 deletions.
6 changes: 3 additions & 3 deletions gymnasium/envs/mujoco/ant_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,14 +404,14 @@ def step(self, action):
return observation, reward, terminated, False, info

def _get_obs(self):
position = self.data.qpos.flat.copy()
velocity = self.data.qvel.flat.copy()
position = self.data.qpos.flatten()
velocity = self.data.qvel.flatten()

if self._exclude_current_positions_from_observation:
position = position[2:]

if self._include_cfrc_ext_in_observation:
contact_force = self.contact_forces[1:].flat.copy()
contact_force = self.contact_forces[1:].flatten()
return np.concatenate((position, velocity, contact_force))
else:
return np.concatenate((position, velocity))
Expand Down
4 changes: 2 additions & 2 deletions gymnasium/envs/mujoco/half_cheetah_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,8 @@ def step(self, action):
return observation, reward, False, False, info

def _get_obs(self):
position = self.data.qpos.flat.copy()
velocity = self.data.qvel.flat.copy()
position = self.data.qpos.flatten()
velocity = self.data.qvel.flatten()

if self._exclude_current_positions_from_observation:
position = position[1:]
Expand Down
4 changes: 2 additions & 2 deletions gymnasium/envs/mujoco/hopper_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,8 @@ def terminated(self):
return terminated

def _get_obs(self):
position = self.data.qpos.flat.copy()
velocity = np.clip(self.data.qvel.flat.copy(), -10, 10)
position = self.data.qpos.flatten()
velocity = np.clip(self.data.qvel.flatten(), -10, 10)

if self._exclude_current_positions_from_observation:
position = position[1:]
Expand Down
12 changes: 6 additions & 6 deletions gymnasium/envs/mujoco/humanoid_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,24 +454,24 @@ def terminated(self):
return terminated

def _get_obs(self):
position = self.data.qpos.flat.copy()
velocity = self.data.qvel.flat.copy()
position = self.data.qpos.flatten()
velocity = self.data.qvel.flatten()

if self._include_cinert_in_observation is True:
com_inertia = self.data.cinert[1:].flat.copy()
com_inertia = self.data.cinert[1:].flatten()
else:
com_inertia = np.array([])
if self._include_cvel_in_observation is True:
com_velocity = self.data.cvel[1:].flat.copy()
com_velocity = self.data.cvel[1:].flatten()
else:
com_velocity = np.array([])

if self._include_qfrc_actuator_in_observation is True:
actuator_forces = self.data.qfrc_actuator[6:].flat.copy()
actuator_forces = self.data.qfrc_actuator[6:].flatten()
else:
actuator_forces = np.array([])
if self._include_cfrc_ext_in_observation is True:
external_contact_forces = self.data.cfrc_ext[1:].flat.copy()
external_contact_forces = self.data.cfrc_ext[1:].flatten()
else:
external_contact_forces = np.array([])

Expand Down
12 changes: 6 additions & 6 deletions gymnasium/envs/mujoco/humanoidstandup_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,24 +405,24 @@ def __init__(
}

def _get_obs(self):
position = self.data.qpos.flat.copy()
velocity = self.data.qvel.flat.copy()
position = self.data.qpos.flatten()
velocity = self.data.qvel.flatten()

if self._include_cinert_in_observation is True:
com_inertia = self.data.cinert[1:].flat.copy()
com_inertia = self.data.cinert[1:].flatten()
else:
com_inertia = np.array([])
if self._include_cvel_in_observation is True:
com_velocity = self.data.cvel[1:].flat.copy()
com_velocity = self.data.cvel[1:].flatten()
else:
com_velocity = np.array([])

if self._include_qfrc_actuator_in_observation is True:
actuator_forces = self.data.qfrc_actuator[6:].flat.copy()
actuator_forces = self.data.qfrc_actuator[6:].flatten()
else:
actuator_forces = np.array([])
if self._include_cfrc_ext_in_observation is True:
external_contact_forces = self.data.cfrc_ext[1:].flat.copy()
external_contact_forces = self.data.cfrc_ext[1:].flatten()
else:
external_contact_forces = np.array([])

Expand Down
4 changes: 2 additions & 2 deletions gymnasium/envs/mujoco/pusher_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,8 @@ def reset_model(self):
def _get_obs(self):
return np.concatenate(
[
self.data.qpos.flat[:7],
self.data.qvel.flat[:7],
self.data.qpos.flatten()[:7],
self.data.qvel.flatten()[:7],
self.get_body_com("tips_arm"),
self.get_body_com("object"),
self.get_body_com("goal"),
Expand Down
6 changes: 3 additions & 3 deletions gymnasium/envs/mujoco/reacher_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,13 +231,13 @@ def reset_model(self):
return self._get_obs()

def _get_obs(self):
theta = self.data.qpos.flat[:2]
theta = self.data.qpos.flatten()[:2]
return np.concatenate(
[
np.cos(theta),
np.sin(theta),
self.data.qpos.flat[2:],
self.data.qvel.flat[:2],
self.data.qpos.flatten()[2:],
self.data.qvel.flatten()[:2],
(self.get_body_com("fingertip") - self.get_body_com("target"))[:2],
]
)
4 changes: 2 additions & 2 deletions gymnasium/envs/mujoco/swimmer_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,8 @@ def step(self, action):
return observation, reward, False, False, info

def _get_obs(self):
position = self.data.qpos.flat.copy()
velocity = self.data.qvel.flat.copy()
position = self.data.qpos.flatten()
velocity = self.data.qvel.flatten()

if self._exclude_current_positions_from_observation:
position = position[2:]
Expand Down
4 changes: 2 additions & 2 deletions gymnasium/envs/mujoco/walker2d_v5.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,8 @@ def terminated(self):
return terminated

def _get_obs(self):
position = self.data.qpos.flat.copy()
velocity = np.clip(self.data.qvel.flat.copy(), -10, 10)
position = self.data.qpos.flatten()
velocity = np.clip(self.data.qvel.flatten(), -10, 10)

if self._exclude_current_positions_from_observation:
position = position[1:]
Expand Down

0 comments on commit b879469

Please sign in to comment.