Skip to content

Commit

Permalink
Add mirror DST
Browse files Browse the repository at this point in the history
  • Loading branch information
ffelten committed Jan 15, 2024
1 parent 91ab390 commit 907b6ad
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 7 deletions.
12 changes: 11 additions & 1 deletion mo_gymnasium/envs/deep_sea_treasure/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from gymnasium.envs.registration import register

from mo_gymnasium.envs.deep_sea_treasure.deep_sea_treasure import CONCAVE_MAP
from mo_gymnasium.envs.deep_sea_treasure.deep_sea_treasure import (
CONCAVE_MAP,
MIRRORED_MAP,
)


register(
Expand All @@ -15,3 +18,10 @@
max_episode_steps=100,
kwargs={"dst_map": CONCAVE_MAP},
)

register(
id="deep-sea-treasure-mirrored-v0",
entry_point="mo_gymnasium.envs.deep_sea_treasure.deep_sea_treasure:DeepSeaTreasure",
max_episode_steps=100,
kwargs={"dst_map": MIRRORED_MAP},
)
46 changes: 40 additions & 6 deletions mo_gymnasium/envs/deep_sea_treasure/deep_sea_treasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,23 @@
np.array([124.0, -19]),
]

# As in Felten et al. 2022, same PF as concave, just harder map
MIRRORED_MAP = np.array(
[
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, -10, -10, 2.0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, -10, -10, -10, -10, 3.0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, -10, -10, -10, -10, -10, -10, 5.0, 8.0, 16.0, 0, 0, 0, 0],
[0, 0, 0, 0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 0, 0, 0, 0],
[0, 0, 0, 0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 0, 0, 0, 0],
[0, 0, 0, 0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 24.0, 50.0, 0, 0],
[0, 0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 0, 0],
[0, 0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 74.0, 0],
[0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 124.0],
]
)


class DeepSeaTreasure(gym.Env, EzPickle):
"""
Expand Down Expand Up @@ -115,8 +132,16 @@ def __init__(self, render_mode: Optional[str] = None, dst_map=DEFAULT_MAP, float

# The map of the deep sea treasure (convex version)
self.sea_map = dst_map
if np.all(dst_map == DEFAULT_MAP):
self.map_name = "convex"
elif np.all(dst_map == CONCAVE_MAP):
self.map_name = "concave"
elif np.all(dst_map == MIRRORED_MAP):
self.map_name = "mirrored"
else:
raise ValueError("Invalid map")
print(f"Using {self.map_name} map")
self._pareto_front = CONVEX_FRONT if np.all(dst_map == DEFAULT_MAP) else CONCAVE_FRONT
assert self.sea_map.shape == DEFAULT_MAP.shape, "The map's shape must be 11x11"

self.dir = {
0: np.array([-1, 0], dtype=np.int32), # up
Expand All @@ -130,7 +155,7 @@ def __init__(self, render_mode: Optional[str] = None, dst_map=DEFAULT_MAP, float
if self.float_state:
self.observation_space = Box(low=0.0, high=1.0, shape=(2,), dtype=obs_type)
else:
self.observation_space = Box(low=0, high=10, shape=(2,), dtype=obs_type)
self.observation_space = Box(low=0, high=len(self.sea_map[0]), shape=(2,), dtype=obs_type)

# action space specification: 1 dimension, 0 up, 1 down, 2 left, 3 right
self.action_space = Discrete(4)
Expand All @@ -144,11 +169,15 @@ def __init__(self, render_mode: Optional[str] = None, dst_map=DEFAULT_MAP, float
self.current_state = np.array([0, 0], dtype=np.int32)

# pygame
self.window_size = (min(64 * self.sea_map.shape[1], 512), min(64 * self.sea_map.shape[0], 512))
ratio = self.sea_map.shape[1] / self.sea_map.shape[0]
padding = 10
self.pix_inside = (min(64 * self.sea_map.shape[1], 512) * ratio, min(64 * self.sea_map.shape[0], 512))
# adding some padding on the sides
self.window_size = (self.pix_inside[0] + 2 * padding, self.pix_inside[1])
# The size of a single grid square in pixels
self.pix_square_size = (
self.window_size[1] // self.sea_map.shape[1] + 1,
self.window_size[0] // self.sea_map.shape[0] + 1,
self.pix_inside[0] // self.sea_map.shape[1] + 1,
self.pix_inside[1] // self.sea_map.shape[0] + 1, # watch out for axis inversions here
)
self.window = None
self.clock = None
Expand Down Expand Up @@ -257,7 +286,12 @@ def _get_state(self):
def reset(self, seed=None, **kwargs):
super().reset(seed=seed)

self.current_state = np.array([0, 0], dtype=np.int32)
if self.map_name == "convex" or self.map_name == "concave":
self.current_state = np.array([0, 0], dtype=np.int32)
elif self.map_name == "mirrored":
self.current_state = np.array([0, 10], dtype=np.int32)
else:
raise ValueError("Invalid map")
self.step_count = 0.0
state = self._get_state()
if self.render_mode == "human":
Expand Down

0 comments on commit 907b6ad

Please sign in to comment.