Skip to content

Commit

Permalink
Fix: have access to terminal_observation in infos (#233)
Browse files Browse the repository at this point in the history
  • Loading branch information
elliottower authored Nov 28, 2023
2 parents fa38cdd + e49c5e1 commit 7a16fe8
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
7 changes: 6 additions & 1 deletion supersuit/vector/markov_vector_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,14 @@ def step(self, actions):
infs = [infos.get(agent, {}) for agent in self.par_env.possible_agents]

if env_done:
observations, infs = self.reset()
observations, reset_infs = self.reset()
else:
observations = self.concat_obs(observations)
# empty infos for reset infs
reset_infs = [{} for _ in range(len(self.par_env.possible_agents))]
# combine standard infos and reset infos
infs = [{**inf, **reset_inf} for inf, reset_inf in zip(infs, reset_infs)]

assert (
self.black_death or self.par_env.agents == self.par_env.possible_agents
), "MarkovVectorEnv does not support environments with varying numbers of active agents unless black_death is set to True"
Expand Down
26 changes: 26 additions & 0 deletions test/test_vector/test_pettingzoo_to_vec.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy

import numpy as np
import pytest
from pettingzoo.butterfly import knights_archers_zombies_v10
from pettingzoo.mpe import simple_spread_v3, simple_world_comm_v3
Expand Down Expand Up @@ -89,3 +90,28 @@ def test_env_black_death_wrapper():
for i in range(300):
actions = [env.action_space.sample() for i in range(env.num_envs)]
obss, rews, terms, truncs, infos = env.step(actions)


def test_terminal_obs_are_returned():
"""
If we reach (and pass) the end of the episode, the last observation is returned in the info dict.
"""
max_cycles = 300
env = knights_archers_zombies_v10.parallel_env(spawn_rate=50, max_cycles=300)
env = black_death_v3(env)
env = pettingzoo_env_to_vec_env_v1(env)
env.reset(seed=42)

# run past max_cycles or until terminated - causing the env to reset and continue
for _ in range(0, max_cycles + 10):
actions = [env.action_space.sample() for i in range(env.num_envs)]
_, _, terms, truncs, infos = env.step(actions)

env_done = (np.array(terms) | np.array(truncs)).all()

if env_done:
# check we have infos for all agents
assert len(infos) == len(env.par_env.possible_agents)
# check infos contain terminal_observation
for info in infos:
assert "terminal_observation" in info

0 comments on commit 7a16fe8

Please sign in to comment.