Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for different vector autoreset modes #1227

Merged
87 changes: 78 additions & 9 deletions gymnasium/vector/async_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
read_from_shared_memory,
write_to_shared_memory,
)
from gymnasium.vector.vector_env import ArrayType, VectorEnv
from gymnasium.vector.vector_env import ArrayType, AutoresetMode, VectorEnv


__all__ = ["AsyncVectorEnv", "AsyncState"]
Expand Down Expand Up @@ -101,6 +101,7 @@ def __init__(
| None
) = None,
observation_mode: str | Space = "same",
autoreset_mode: str | AutoresetMode = AutoresetMode.NEXT_STEP,
):
"""Vectorized environment that runs multiple environments in parallel.

Expand All @@ -120,6 +121,7 @@ def __init__(
'different' defines that there can be multiple observation spaces with different parameters though requires the same shape and dtype,
warning, may raise unexpected errors. Passing a ``Tuple[Space, Space]`` object allows defining a custom ``single_observation_space`` and
``observation_space``, warning, may raise unexpected errors.
autoreset_mode: The Autoreset Mode used, see todo for more details.

Warnings:
worker is an advanced mode option. It provides a high degree of flexibility and a high chance
Expand All @@ -135,7 +137,15 @@ def __init__(
self.env_fns = env_fns
self.shared_memory = shared_memory
self.copy = copy
self.context = context
self.daemon = daemon
self.worker = worker
self.observation_mode = observation_mode
self.autoreset_mode = (
autoreset_mode
if isinstance(autoreset_mode, AutoresetMode)
else AutoresetMode(autoreset_mode)
)

self.num_envs = len(env_fns)

Expand All @@ -145,6 +155,7 @@ def __init__(

# As we support `make_vec(spec)`, we can't include `spec = dummy_env.spec`, as this doesn't guarantee that we can actually recreate the vector env.
self.metadata = dummy_env.metadata
self.metadata["autoreset_mode"] = self.autoreset_mode
self.render_mode = dummy_env.render_mode

self.single_action_space = dummy_env.action_space
Expand Down Expand Up @@ -211,6 +222,7 @@ def __init__(
parent_pipe,
_obs_buffer,
self.error_queue,
self.autoreset_mode,
),
)

Expand Down Expand Up @@ -287,9 +299,32 @@ def reset_async(
str(self._state.value),
)

for pipe, env_seed in zip(self.parent_pipes, seed):
env_kwargs = {"seed": env_seed, "options": options}
pipe.send(("reset", env_kwargs))
if options is not None and "reset_mask" in options:
reset_mask = options.pop("reset_mask")
assert isinstance(
reset_mask, np.ndarray
), f"`options['reset_mask': mask]` must be a numpy array, got {type(reset_mask)}"
assert reset_mask.shape == (
self.num_envs,
), f"`options['reset_mask': mask]` must have shape `({self.num_envs},)`, got {reset_mask.shape}"
assert (
reset_mask.dtype == np.bool_
), f"`options['reset_mask': mask]` must have `dtype=np.bool_`, got {reset_mask.dtype}"
assert np.any(
reset_mask
), f"`options['reset_mask': mask]` must contain a boolean array, got reset_mask={reset_mask}"

for pipe, env_seed, env_reset in zip(self.parent_pipes, seed, reset_mask):
if env_reset:
env_kwargs = {"seed": env_seed, "options": options}
pipe.send(("reset", env_kwargs))
else:
pipe.send(("reset-noop", None))
else:
for pipe, env_seed in zip(self.parent_pipes, seed):
env_kwargs = {"seed": env_seed, "options": options}
pipe.send(("reset", env_kwargs))

self._state = AsyncState.WAITING_RESET

def reset_wait(
Expand Down Expand Up @@ -688,11 +723,13 @@ def _async_worker(
parent_pipe: Connection,
shared_memory: multiprocessing.Array | dict[str, Any] | tuple[Any, ...],
error_queue: Queue,
autoreset_mode: AutoresetMode,
):
env = env_fn()
observation_space = env.observation_space
action_space = env.action_space
autoreset = False
observation = None

parent_pipe.close()

Expand All @@ -709,19 +746,51 @@ def _async_worker(
observation = None
autoreset = False
pipe.send(((observation, info), True))
elif command == "reset-noop":
pipe.send(((observation, {}), True))
elif command == "step":
if autoreset:
observation, info = env.reset()
reward, terminated, truncated = 0, False, False
else:
if autoreset_mode == AutoresetMode.NEXT_STEP:
if autoreset:
observation, info = env.reset()
reward, terminated, truncated = 0, False, False
else:
(
observation,
reward,
terminated,
truncated,
info,
) = env.step(data)
autoreset = terminated or truncated
elif autoreset_mode == AutoresetMode.SAME_STEP:
(
observation,
reward,
terminated,
truncated,
info,
) = env.step(data)
autoreset = terminated or truncated

if terminated or truncated:
reset_observation, reset_info = env.reset()

info = {
"final_info": info,
"final_obs": observation,
**reset_info,
}
observation = reset_observation
elif autoreset_mode == AutoresetMode.DISABLED:
assert autoreset is False
(
observation,
reward,
terminated,
truncated,
info,
) = env.step(data)
else:
raise ValueError(f"Unexpected autoreset_mode: {autoreset_mode}")

if shared_memory:
write_to_shared_memory(
Expand Down
120 changes: 95 additions & 25 deletions gymnasium/vector/sync_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
create_empty_array,
iterate,
)
from gymnasium.vector.vector_env import ArrayType, VectorEnv
from gymnasium.vector.vector_env import ArrayType, AutoresetMode, VectorEnv


__all__ = ["SyncVectorEnv"]
Expand Down Expand Up @@ -65,6 +65,7 @@ def __init__(
env_fns: Iterator[Callable[[], Env]] | Sequence[Callable[[], Env]],
copy: bool = True,
observation_mode: str | Space = "same",
autoreset_mode: str | AutoresetMode = AutoresetMode.NEXT_STEP,
):
"""Vectorized environment that serially runs multiple environments.

Expand All @@ -74,13 +75,22 @@ def __init__(
observation_mode: Defines how environment observation spaces should be batched. 'same' defines that there should be ``n`` copies of identical spaces.
'different' defines that there can be multiple observation spaces with the same length but different high/low values batched together. Passing a ``Space`` object
allows the user to set some custom observation space mode not covered by 'same' or 'different.'
autoreset_mode: The Autoreset Mode used, see todo for more details.

Raises:
RuntimeError: If the observation space of some sub-environment does not match observation_space
(or, by default, the observation space of the first sub-environment).
"""
self.copy = copy
super().__init__()

self.env_fns = env_fns
self.copy = copy
self.observation_mode = observation_mode
self.autoreset_mode = (
autoreset_mode
if isinstance(autoreset_mode, AutoresetMode)
else AutoresetMode(autoreset_mode)
)

# Initialise all sub-environments
self.envs = [env_fn() for env_fn in env_fns]
Expand All @@ -89,6 +99,7 @@ def __init__(
# As we support `make_vec(spec)`, we can't include `spec = self.envs[0].spec`, as this doesn't guarantee that we can actually recreate the vector env.
self.num_envs = len(self.envs)
self.metadata = self.envs[0].metadata
self.metadata["autoreset_mode"] = self.autoreset_mode
self.render_mode = self.envs[0].render_mode

self.single_action_space = self.envs[0].action_space
Expand Down Expand Up @@ -130,6 +141,7 @@ def __init__(
), f"Sub-environment action space doesn't make the `single_action_space`, action_space={env.action_space}, single_action_space={self.single_action_space}"

# Initialise attributes used in `step` and `reset`
self._env_obs = [None for _ in range(self.num_envs)]
self._observations = create_empty_array(
self.single_observation_space, n=self.num_envs, fn=np.zeros
)
Expand Down Expand Up @@ -175,23 +187,52 @@ def reset(
len(seed) == self.num_envs
), f"If seeds are passed as a list the length must match num_envs={self.num_envs} but got length={len(seed)}."

self._terminations = np.zeros((self.num_envs,), dtype=np.bool_)
self._truncations = np.zeros((self.num_envs,), dtype=np.bool_)

observations, infos = [], {}
for i, (env, single_seed) in enumerate(zip(self.envs, seed)):
env_obs, env_info = env.reset(seed=single_seed, options=options)
if options is not None and "reset_mask" in options:
reset_mask = options.pop("reset_mask")
assert isinstance(
reset_mask, np.ndarray
), f"`options['reset_mask': mask]` must be a numpy array, got {type(reset_mask)}"
assert reset_mask.shape == (
self.num_envs,
), f"`options['reset_mask': mask]` must have shape `({self.num_envs},)`, got {reset_mask.shape}"
assert (
reset_mask.dtype == np.bool_
), f"`options['reset_mask': mask]` must have `dtype=np.bool_`, got {reset_mask.dtype}"
assert np.any(
reset_mask
), f"`options['reset_mask': mask]` must contain a boolean array, got reset_mask={reset_mask}"

self._terminations[reset_mask] = False
self._truncations[reset_mask] = False
self._autoreset_envs[reset_mask] = False

infos = {}
for i, (env, single_seed, env_mask) in enumerate(
zip(self.envs, seed, reset_mask)
):
if env_mask:
self._env_obs[i], env_info = env.reset(
seed=single_seed, options=options
)

infos = self._add_info(infos, env_info, i)
else:
self._terminations = np.zeros((self.num_envs,), dtype=np.bool_)
self._truncations = np.zeros((self.num_envs,), dtype=np.bool_)
self._autoreset_envs = np.zeros((self.num_envs,), dtype=np.bool_)

infos = {}
for i, (env, single_seed) in enumerate(zip(self.envs, seed)):
self._env_obs[i], env_info = env.reset(
seed=single_seed, options=options
)

observations.append(env_obs)
infos = self._add_info(infos, env_info, i)
infos = self._add_info(infos, env_info, i)

# Concatenate the observations
self._observations = concatenate(
self.single_observation_space, observations, self._observations
self.single_observation_space, self._env_obs, self._observations
)

self._autoreset_envs = np.zeros((self.num_envs,), dtype=np.bool_)

return deepcopy(self._observations) if self.copy else self._observations, infos

def step(
Expand All @@ -204,29 +245,58 @@ def step(
"""
actions = iterate(self.action_space, actions)

observations, infos = [], {}
infos = {}
for i, action in enumerate(actions):
if self._autoreset_envs[i]:
env_obs, env_info = self.envs[i].reset()

self._rewards[i] = 0.0
self._terminations[i] = False
self._truncations[i] = False
else:
if self.autoreset_mode == AutoresetMode.NEXT_STEP:
if self._autoreset_envs[i]:
self._env_obs[i], env_info = self.envs[i].reset()

self._rewards[i] = 0.0
self._terminations[i] = False
self._truncations[i] = False
else:
(
self._env_obs[i],
self._rewards[i],
self._terminations[i],
self._truncations[i],
env_info,
) = self.envs[i].step(action)
elif self.autoreset_mode == AutoresetMode.DISABLED:
# assumes that the user has correctly autoreset
assert not self._autoreset_envs[i], f"{self._autoreset_envs=}"
(
env_obs,
self._env_obs[i],
self._rewards[i],
self._terminations[i],
self._truncations[i],
env_info,
) = self.envs[i].step(action)
elif self.autoreset_mode == AutoresetMode.SAME_STEP:
(
self._env_obs[i],
self._rewards[i],
self._terminations[i],
self._truncations[i],
env_info,
) = self.envs[i].step(action)

if self._terminations[i] or self._truncations[i]:
infos = self._add_info(
infos,
{"final_obs": self._env_obs[i], "final_info": env_info},
i,
)

self._env_obs[i], env_info = self.envs[i].reset()
else:
raise ValueError(f"Unexpected autoreset mode, {self.autoreset_mode}")

observations.append(env_obs)
infos = self._add_info(infos, env_info, i)

# Concatenate the observations
self._observations = concatenate(
self.single_observation_space, observations, self._observations
self.single_observation_space, self._env_obs, self._observations
)
self._autoreset_envs = np.logical_or(self._terminations, self._truncations)

Expand Down
Loading
Loading