From e8897ff6779d0740d907774084d9b69e860a473e Mon Sep 17 00:00:00 2001 From: Nidhish Shah <55269918+nidhishs@users.noreply.github.com> Date: Thu, 5 Jan 2023 11:36:17 +0100 Subject: [PATCH] Added vector env support to StepAPICompatibility wrapper. (#238) --- gymnasium/wrappers/step_api_compatibility.py | 13 +++++-------- tests/wrappers/test_step_compatibility.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/gymnasium/wrappers/step_api_compatibility.py b/gymnasium/wrappers/step_api_compatibility.py index 06557c0a8..87d057f62 100644 --- a/gymnasium/wrappers/step_api_compatibility.py +++ b/gymnasium/wrappers/step_api_compatibility.py @@ -1,10 +1,7 @@ """Implementation of StepAPICompatibility wrapper class for transforming envs between new and old step API.""" import gymnasium as gym from gymnasium.logger import deprecation -from gymnasium.utils.step_api_compatibility import ( - convert_to_done_step_api, - convert_to_terminated_truncated_step_api, -) +from gymnasium.utils.step_api_compatibility import step_api_compatibility class StepAPICompatibility(gym.Wrapper): @@ -36,6 +33,7 @@ def __init__(self, env: gym.Env, output_truncation_bool: bool = True): output_truncation_bool (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API) """ super().__init__(env) + self.is_vector_env = isinstance(env.unwrapped, gym.vector.VectorEnv) self.output_truncation_bool = output_truncation_bool if not self.output_truncation_bool: deprecation( @@ -52,7 +50,6 @@ def step(self, action): (observation, reward, terminated, truncated, info) or (observation, reward, done, info) """ step_returns = self.env.step(action) - if self.output_truncation_bool: - return convert_to_terminated_truncated_step_api(step_returns) - else: - return convert_to_done_step_api(step_returns) + return step_api_compatibility( + step_returns, self.output_truncation_bool, self.is_vector_env + ) diff --git a/tests/wrappers/test_step_compatibility.py b/tests/wrappers/test_step_compatibility.py index f4c7f465c..79cf6451c 100644 --- a/tests/wrappers/test_step_compatibility.py +++ b/tests/wrappers/test_step_compatibility.py @@ -1,7 +1,9 @@ +import numpy as np import pytest import gymnasium as gym from gymnasium.spaces import Discrete +from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv from gymnasium.wrappers import StepAPICompatibility @@ -54,6 +56,20 @@ def test_step_compatibility_to_old_api(env): assert isinstance(done, bool) +@pytest.mark.parametrize("vector_env", [SyncVectorEnv, AsyncVectorEnv]) +def test_vector_env_step_compatibility_to_old_api(vector_env): + num_envs = 2 + env = vector_env([NewStepEnv for _ in range(num_envs)]) + old_env = StepAPICompatibility(env, False) + + step_returns = old_env.step([0] * num_envs) + assert len(step_returns) == 4 + _, _, dones, _ = step_returns + assert isinstance(dones, np.ndarray) + for done in dones: + assert isinstance(done, np.bool_) + + @pytest.mark.parametrize("apply_api_compatibility", [None, True, False]) def test_step_compatibility_in_make(apply_api_compatibility): gym.register("OldStepEnv-v0", entry_point=OldStepEnv)