diff --git a/momadm_benchmarks/envs/multiwalker/momultiwalker_v0.py b/momadm_benchmarks/envs/multiwalker/momultiwalker_v0.py new file mode 100644 index 00000000..c83439c2 --- /dev/null +++ b/momadm_benchmarks/envs/multiwalker/momultiwalker_v0.py @@ -0,0 +1,5 @@ +"""Multiwalker domain environment for multi-objective optimization.""" +from momadm_benchmarks.envs.multiwalker.multiwalker import env, parallel_env, raw_env + + +__all__ = ["env", "parallel_env", "raw_env"] diff --git a/momadm_benchmarks/envs/multiwalker/multiwalker.py b/momadm_benchmarks/envs/multiwalker/multiwalker.py index 5bdcc00c..99803465 100644 --- a/momadm_benchmarks/envs/multiwalker/multiwalker.py +++ b/momadm_benchmarks/envs/multiwalker/multiwalker.py @@ -11,25 +11,53 @@ from pettingzoo.utils import wrappers from momadm_benchmarks.envs.multiwalker.multiwalker_base import MOMultiWalkerEnv as _env +from momadm_benchmarks.utils.conversions import mo_aec_to_parallel from momadm_benchmarks.utils.env import MOAECEnv def env(**kwargs): - """Autowrapper for the multiwalker domain. + """Returns the env in `AEC` format. Args: **kwargs: keyword args to forward to the raw_env function. + Returns: + A fully wrapped AEC env. + """ + env = raw_env(**kwargs) + return env + + +def parallel_env(**kwargs): + """Returns the env in `parallel` format. + + Args: + **kwargs: keyword args to forward to the raw_env function. + + Returns: + A fully wrapped parallel env. + """ + env = raw_env(**kwargs) + env = mo_aec_to_parallel(env) + return env + + +def raw_env(**kwargs): + """Returns the wrapped env in `AEC` format. + + Args: + **kwargs: keyword args to forward to create the `MOMultiwalker` environment. + Returns: A fully wrapped env. """ - env = mo_env(**kwargs) + env = MOMultiwalker(**kwargs) env = wrappers.ClipOutOfBoundsWrapper(env) env = wrappers.OrderEnforcingWrapper(env) return env -class mo_env(MOAECEnv, pz_multiwalker): +class MOMultiwalker(MOAECEnv, pz_multiwalker): """Environment for MO Multiwalker problem domain. The init method takes in environment arguments and should define the following attributes: diff --git a/tests/all_modules.py b/tests/all_modules.py index 89a060a2..50fc4c51 100644 --- a/tests/all_modules.py +++ b/tests/all_modules.py @@ -1,6 +1,7 @@ from momadm_benchmarks.envs.beach_domain import mobeach_v0 - +from momadm_benchmarks.envs.multiwalker import momultiwalker_v0 all_environments = { "mobeach_v0": mobeach_v0, + "momultiwalker_v0": momultiwalker_v0, }