Skip to content

Commit

Permalink
Merge pull request #39 from OliEfr/master
Browse files Browse the repository at this point in the history
fix check for self._init_step_no evaluating to false if 0
  • Loading branch information
robfiras authored Sep 9, 2024
2 parents 39a101c + cb26731 commit 41a020c
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion loco_mujoco/environments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def setup(self, obs):
if self.trajectories is not None:
if self._random_start:
sample = self.trajectories.reset_trajectory()
elif self._init_step_no:
elif self._init_step_no is not None:
traj_len = self.trajectories.trajectory_length
n_traj = self.trajectories.number_of_trajectories
assert self._init_step_no <= traj_len * n_traj
Expand Down
2 changes: 1 addition & 1 deletion loco_mujoco/environments/gymnasium.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,4 @@ def _convert_space(space):
low = np.min(space.low)
high = np.max(space.high)
shape = space.shape
return Box(low, high, shape, np.float64)
return Box(low, high, shape, np.float64)
4 changes: 2 additions & 2 deletions loco_mujoco/environments/humanoids/base_humanoid_4_ages.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ def setup(self, obs):
sample = self.trajectories.reset_trajectory(traj_no=traj_no)
else:
sample = self.trajectories.reset_trajectory()
elif self._init_step_no:
elif self._init_step_no is not None:
traj_len = self.trajectories.trajectory_length
n_traj = self.trajectories.nnumber_of_trajectories
n_traj = self.trajectories.number_of_trajectories
assert self._init_step_no <= traj_len * n_traj
substep_no = int(self._init_step_no % traj_len)
traj_no = int(self._init_step_no / traj_len)
Expand Down
4 changes: 2 additions & 2 deletions loco_mujoco/environments/quadrupeds/unitreeA1.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,9 @@ def setup(self, obs):
if self.setup_random_rot:
angle = np.random.uniform(0, 2 * np.pi)
sample = rotate_obs(sample, angle, *self._get_relevant_idx_rotation())
elif self._init_step_no:
elif self._init_step_no is not None:
traj_len = self.trajectories.trajectory_length
n_traj = self.trajectories.nnumber_of_trajectories
n_traj = self.trajectories.number_of_trajectories
assert self._init_step_no <= traj_len * n_traj
substep_no = int(self._init_step_no % traj_len)
traj_no = int(self._init_step_no / traj_len)
Expand Down

0 comments on commit 41a020c

Please sign in to comment.