Skip to content

Commit

Permalink
(Issue #99) Fix disc_episode_returns off-by-one error (#100)
Browse files: browse the repository at this point in the history
  • Loading branch information
Katze2664 authored Aug 13, 2024
1 parent 0ec9b86 commit 7ea7536
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
3 changes: 2 additions & 1 deletion mo_gymnasium/utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -253,13 +253,14 @@ def step(self, action):
infos, dict
), f"`info` dtype is {type(infos)} while supported dtype is `dict`. This may be due to usage of other wrappers in the wrong order."
self.episode_returns += rewards
self.episode_lengths += 1

# CHANGE: The discounted returns are also computed here
self.disc_episode_returns += rewards * np.repeat(self.gamma**self.episode_lengths, self.reward_dim).reshape(
self.episode_returns.shape
)

self.episode_lengths += 1

dones = np.logical_or(terminations, truncations)
num_dones = np.sum(dones)
if num_dones:
Expand Down
13 changes: 6 additions & 7 deletions tests/test_wrappers.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -14,9 +14,9 @@ def go_to_8_3(env):
Goes to (8.2, -3) treasure, returns the rewards
"""
env.reset()
env.step(3) # right
env.step(1) # down
_, rewards, _, _, infos = env.step(1)
env.step(3) # action: right, rewards: [0, -1]
env.step(1) # action: down, rewards: [0, -1]
_, rewards, _, _, infos = env.step(1) # action: down, rewards: [8.2, -1]
return rewards, infos


Expand Down Expand Up @@ -98,10 +98,9 @@ def test_mo_record_ep_statistic():
assert info["episode"]["r"].shape == (2,)
assert info["episode"]["dr"].shape == (2,)
assert tuple(info["episode"]["r"]) == (np.float32(8.2), np.float32(-3.0))
assert tuple(np.round(info["episode"]["dr"], 2)) == (
np.float32(7.48),
np.float32(-2.82),
)
np.testing.assert_allclose(info["episode"]["dr"], [7.71538, -2.9109], rtol=0, atol=1e-2)
# 0 * 0.97**0 + 0 * 0.97**1 + 8.2 * 0.97**2 == 7.71538
# -1 * 0.97**0 + -1 * 0.97**1 + -1 * 0.97**2 == -2.9109
assert isinstance(info["episode"]["l"], np.int32)
assert info["episode"]["l"] == 3
assert isinstance(info["episode"]["t"], np.float32)
Expand Down

0 comments on commit 7ea7536

Please sign in to comment.