From b34c25b368aad76a23b58044d986df2b1cfede5f Mon Sep 17 00:00:00 2001 From: Ian Fan Date: Wed, 11 Jan 2023 20:09:37 +0000 Subject: [PATCH] Fix type annotations of `callable` to `Callable` (#259) --- gymnasium/envs/registration.py | 2 +- gymnasium/utils/play.py | 2 +- gymnasium/vector/__init__.py | 9 +++++---- gymnasium/vector/async_vector_env.py | 8 ++++---- gymnasium/vector/utils/misc.py | 7 ++++++- gymnasium/vector/utils/numpy_utils.py | 4 ++-- tests/spaces/test_spaces.py | 4 ++-- tests/testing_env.py | 7 ++++--- tests/utils/test_env_checker.py | 8 ++++---- tests/utils/test_passive_env_checker.py | 6 +++--- 10 files changed, 32 insertions(+), 25 deletions(-) diff --git a/gymnasium/envs/registration.py b/gymnasium/envs/registration.py index b6bae1eec..625acf624 100644 --- a/gymnasium/envs/registration.py +++ b/gymnasium/envs/registration.py @@ -45,7 +45,7 @@ ) -def load(name: str) -> callable: +def load(name: str) -> Callable: """Loads an environment with name and returns an environment creation function. Args: diff --git a/gymnasium/utils/play.py b/gymnasium/utils/play.py index 1ea4d69d9..79a92c729 100644 --- a/gymnasium/utils/play.py +++ b/gymnasium/utils/play.py @@ -321,7 +321,7 @@ class PlayPlot: """ def __init__( - self, callback: callable, horizon_timesteps: int, plot_names: List[str] + self, callback: Callable, horizon_timesteps: int, plot_names: List[str] ): """Constructor of :class:`PlayPlot`. diff --git a/gymnasium/vector/__init__.py b/gymnasium/vector/__init__.py index 95a8df1dd..b3c8db795 100644 --- a/gymnasium/vector/__init__.py +++ b/gymnasium/vector/__init__.py @@ -1,7 +1,8 @@ """Module for vector environments.""" -from typing import Iterable, List, Optional, Union +from typing import Callable, Iterable, List, Optional, Union import gymnasium as gym +from gymnasium.core import Env from gymnasium.vector.async_vector_env import AsyncVectorEnv from gymnasium.vector.sync_vector_env import SyncVectorEnv from gymnasium.vector.vector_env import VectorEnv, VectorEnvWrapper @@ -14,7 +15,7 @@ def make( id: str, num_envs: int = 1, asynchronous: bool = True, - wrappers: Optional[Union[callable, List[callable]]] = None, + wrappers: Optional[Union[Callable[[Env], Env], List[Callable[[Env], Env]]]] = None, disable_env_checker: Optional[bool] = None, **kwargs, ) -> VectorEnv: @@ -43,12 +44,12 @@ def make( The vectorized environment. """ - def create_env(env_num: int): + def create_env(env_num: int) -> Callable[[], Env]: """Creates an environment that can enable or disable the environment checker.""" # If the env_num > 0 then disable the environment checker otherwise use the parameter _disable_env_checker = True if env_num > 0 else disable_env_checker - def _make_env(): + def _make_env() -> Env: env = gym.envs.registration.make( id, disable_env_checker=_disable_env_checker, diff --git a/gymnasium/vector/async_vector_env.py b/gymnasium/vector/async_vector_env.py index b1e12d8ee..ffb22d94d 100644 --- a/gymnasium/vector/async_vector_env.py +++ b/gymnasium/vector/async_vector_env.py @@ -4,13 +4,13 @@ import time from copy import deepcopy from enum import Enum -from typing import List, Optional, Sequence, Tuple, Union +from typing import Callable, List, Optional, Sequence, Tuple, Union import numpy as np import gymnasium as gym from gymnasium import logger -from gymnasium.core import ObsType +from gymnasium.core import Env, ObsType from gymnasium.error import ( AlreadyPendingCallError, ClosedEnvironmentError, @@ -59,14 +59,14 @@ class AsyncVectorEnv(VectorEnv): def __init__( self, - env_fns: Sequence[callable], + env_fns: Sequence[Callable[[], Env]], observation_space: Optional[gym.Space] = None, action_space: Optional[gym.Space] = None, shared_memory: bool = True, copy: bool = True, context: Optional[str] = None, daemon: bool = True, - worker: Optional[callable] = None, + worker: Optional[Callable] = None, ): """Vectorized environment that runs multiple environments in parallel. diff --git a/gymnasium/vector/utils/misc.py b/gymnasium/vector/utils/misc.py index 7965aafe2..c8cd1f368 100644 --- a/gymnasium/vector/utils/misc.py +++ b/gymnasium/vector/utils/misc.py @@ -1,6 +1,11 @@ """Miscellaneous utilities.""" +from __future__ import annotations + import contextlib import os +from collections.abc import Callable + +from gymnasium.core import Env __all__ = ["CloudpickleWrapper", "clear_mpi_env_vars"] @@ -9,7 +14,7 @@ class CloudpickleWrapper: """Wrapper that uses cloudpickle to pickle and unpickle the result.""" - def __init__(self, fn: callable): + def __init__(self, fn: Callable[[], Env]): """Cloudpickle wrapper for a function.""" self.fn = fn diff --git a/gymnasium/vector/utils/numpy_utils.py b/gymnasium/vector/utils/numpy_utils.py index f8e7f9236..b5e3265ec 100644 --- a/gymnasium/vector/utils/numpy_utils.py +++ b/gymnasium/vector/utils/numpy_utils.py @@ -1,7 +1,7 @@ """Numpy utility functions: concatenate space samples and create empty array.""" from collections import OrderedDict from functools import singledispatch -from typing import Iterable, Union +from typing import Callable, Iterable, Union import numpy as np @@ -84,7 +84,7 @@ def _concatenate_custom(space, items, out): @singledispatch def create_empty_array( - space: Space, n: int = 1, fn: callable = np.zeros + space: Space, n: int = 1, fn: Callable[..., np.ndarray] = np.zeros ) -> Union[tuple, dict, np.ndarray]: """Create an empty (possibly nested) numpy array. diff --git a/tests/spaces/test_spaces.py b/tests/spaces/test_spaces.py index 932601e19..54151fa05 100644 --- a/tests/spaces/test_spaces.py +++ b/tests/spaces/test_spaces.py @@ -3,7 +3,7 @@ import json # note: ujson fails this test due to float equality import pickle import tempfile -from typing import List, Union +from typing import Callable, List, Union import numpy as np import pytest @@ -283,7 +283,7 @@ def test_space_sample_mask(space: Space, mask, n_trials: int = 100): elif isinstance(space, MultiDiscrete): # Due to the multi-axis capability of MultiDiscrete, these functions need to be recursive and that the expected / observed numpy are of non-regular shapes def _generate_frequency( - _dim: Union[np.ndarray, int], _mask, func: callable + _dim: Union[np.ndarray, int], _mask, func: Callable ) -> List: if isinstance(_dim, np.ndarray): return [ diff --git a/tests/testing_env.py b/tests/testing_env.py index a066f4306..5896c2680 100644 --- a/tests/testing_env.py +++ b/tests/testing_env.py @@ -2,6 +2,7 @@ from __future__ import annotations import types +from collections.abc import Callable from typing import Any import gymnasium as gym @@ -45,9 +46,9 @@ def __init__( self, action_space: spaces.Space = spaces.Box(0, 1, (1,)), observation_space: spaces.Space = spaces.Box(0, 1, (1,)), - reset_func: callable = basic_reset_func, - step_func: callable = new_step_func, - render_func: callable = basic_render_func, + reset_func: Callable = basic_reset_func, + step_func: Callable = new_step_func, + render_func: Callable = basic_render_func, metadata: dict[str, Any] = {"render_modes": []}, render_mode: str | None = None, spec: EnvSpec = EnvSpec( diff --git a/tests/utils/test_env_checker.py b/tests/utils/test_env_checker.py index 498ede073..74809d1dd 100644 --- a/tests/utils/test_env_checker.py +++ b/tests/utils/test_env_checker.py @@ -1,7 +1,7 @@ """Tests that the `env_checker` runs as expects and all errors are possible.""" import re import warnings -from typing import Tuple, Union +from typing import Callable, Tuple, Union import numpy as np import pytest @@ -106,7 +106,7 @@ def _reset_default_seed(self: GenericTestEnv, seed="Error", options=None): ], ], ) -def test_check_reset_seed(test, func: callable, message: str): +def test_check_reset_seed(test, func: Callable, message: str): """Tests the check reset seed function works as expected.""" if test is UserWarning: with pytest.warns( @@ -175,7 +175,7 @@ def _return_info_not_dict(self, seed=None, options=None): ], ], ) -def test_check_reset_return_type(test, func: callable, message: str): +def test_check_reset_return_type(test, func: Callable, message: str): """Tests the check `env.reset()` function has a correct return type.""" with pytest.raises(test, match=f"^{re.escape(message)}$"): @@ -194,7 +194,7 @@ def test_check_reset_return_type(test, func: callable, message: str): ], ], ) -def test_check_reset_return_info_deprecation(test, func: callable, message: str): +def test_check_reset_return_info_deprecation(test, func: Callable, message: str): """Tests that return_info has been correct deprecated as an argument to `env.reset()`.""" with pytest.warns(test, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"): diff --git a/tests/utils/test_passive_env_checker.py b/tests/utils/test_passive_env_checker.py index 96710ca40..ee0a4295f 100644 --- a/tests/utils/test_passive_env_checker.py +++ b/tests/utils/test_passive_env_checker.py @@ -1,6 +1,6 @@ import re import warnings -from typing import Dict, Union +from typing import Callable, Dict, Union import numpy as np import pytest @@ -297,7 +297,7 @@ def _reset_result(self, seed=None, options=None): ], ], ) -def test_passive_env_reset_checker(test, func: callable, message: str, kwargs: Dict): +def test_passive_env_reset_checker(test, func: Callable, message: str, kwargs: Dict): """Tests the passive env reset check""" if test is UserWarning: with pytest.warns( @@ -376,7 +376,7 @@ def _modified_step( ], ) def test_passive_env_step_checker( - test: Union[UserWarning, type], func: callable, message: str + test: Union[UserWarning, type], func: Callable, message: str ): """Tests the passive env step checker.""" if test is UserWarning: