diff --git a/momadm_benchmarks/envs/item_gathering/item_gathering.py b/momadm_benchmarks/envs/item_gathering/item_gathering.py index 28c7a4ec..a567fc79 100644 --- a/momadm_benchmarks/envs/item_gathering/item_gathering.py +++ b/momadm_benchmarks/envs/item_gathering/item_gathering.py @@ -15,15 +15,18 @@ import functools import random from copy import deepcopy +from os import path # from gymnasium.utils import EzPickle from typing_extensions import override import numpy as np +import pygame from gymnasium.logger import warn from gymnasium.spaces import Box, Discrete from pettingzoo.utils import wrappers +from momadm_benchmarks.envs.item_gathering.asset_utils import del_colored, get_colored from momadm_benchmarks.envs.item_gathering.map_utils import DEFAULT_MAP, randomise_map from momadm_benchmarks.utils.conversions import mo_parallel_to_aec from momadm_benchmarks.utils.env import MOParallelEnv @@ -75,7 +78,7 @@ class MOItemGathering(MOParallelEnv): These attributes should not be changed after initialization. """ - metadata = {"render_modes": ["human"], "name": "moitemgathering_v0"} + metadata = {"render_modes": ["human", "rgb_array"], "name": "moitemgathering_v0", "render_fps": 50} def __init__( self, @@ -117,7 +120,6 @@ def __init__( self.terminations = {agent: False for agent in self.agents} self.truncations = {agent: False for agent in self.agents} self.action_spaces = dict(zip(self.agents, [Discrete(len(ACTIONS))] * len(self.agent_positions))) - print(self.action_spaces) # observation space is a 2D array, the same size as the grid # 0 for empty, 1 for the current agent, 2 for other agents, 3 for objective 1, 4 for objective 2, ... @@ -141,6 +143,8 @@ def __init__( self.item_dict = {} for i, item in enumerate(all_map_entries[0][np.where(all_map_entries[0] > 2)]): self.item_dict[item] = i + + print(self.item_dict) indices_of_items = np.argwhere(all_map_entries[0] > 2).flatten() item_counts = np.take(all_map_entries[1], indices_of_items) self.num_objectives = len(item_counts) @@ -151,6 +155,19 @@ def __init__( zip(self.agents, [Box(low=0, high=max(item_counts), shape=(self.num_objectives,))] * len(self.agent_positions)) ) + # pygame + self.size = 5 + self.cell_size = (64, 64) + self.window_size = ( + self.env_map.shape[1] * self.cell_size[1], + self.env_map.shape[0] * self.cell_size[0], + ) + self.clock = None + self.agent_imgs = [] + self.item_imgs = [] + self.map_bg_imgs = [] + self.window = None + # this cache ensures that same space object is returned for the same agent # allows action space seeding to work as expected @functools.lru_cache(maxsize=None) @@ -173,9 +190,76 @@ def render(self): if self.render_mode is None: warn("You are calling render method without specifying any render mode.") return + if self.window is None: + pygame.init() + get_colored("item", len(self.item_dict)) + get_colored("agent", len(self.agents)) + + if self.render_mode == "human": + pygame.display.init() + pygame.display.set_caption("Item Gathering") + self.window = pygame.display.set_mode(self.window_size) + else: + self.window = pygame.Surface(self.window_size) + + if self.clock is None: + self.clock = pygame.time.Clock() + + if not self.agent_imgs: + agents = [path.join(path.dirname(__file__), "assets/agent.png")] + for i in range(len(self.agents) - 1): + agents.append(path.join(path.dirname(__file__), f"assets/colored/agent{i}.png")) + self.agent_imgs = [pygame.transform.scale(pygame.image.load(f_name), self.cell_size) for f_name in agents] + if not self.map_bg_imgs: + map_bg_imgs = [ + path.join(path.dirname(__file__), "assets/map_bg1.png"), + path.join(path.dirname(__file__), "assets/map_bg2.png"), + ] + self.map_bg_imgs = [ + pygame.transform.scale(pygame.image.load(f_name), self.cell_size) for f_name in map_bg_imgs + ] + if not self.item_imgs: + items = [path.join(path.dirname(__file__), "assets/item.png")] + for i in range(len(self.item_dict) - 1): + items.append(path.join(path.dirname(__file__), f"assets/colored/item{i}.png")) + self.item_imgs = [ + pygame.transform.scale(pygame.image.load(f_name), (0.6 * self.cell_size[0], 0.6 * self.cell_size[1])) + for f_name in items + ] + + for i in range(self.env_map.shape[0]): + for j in range(self.env_map.shape[1]): + # background + check_board_mask = i % 2 ^ j % 2 + self.window.blit( + self.map_bg_imgs[check_board_mask], + np.array([j, i]) * self.cell_size[0], + ) + + # agents + if [i, j] in self.agent_positions.tolist(): + ind = self.agent_positions.tolist().index([i, j]) + self.window.blit(self.agent_imgs[ind], tuple(np.array([j, i]) * self.cell_size[0])) + + # items + elif int(self.env_map[i, j]) > 2: + ind = ( + int(self.env_map[i, j]) - list(self.item_dict.keys())[0] + ) # item n will have will have 0th (n-n) index + self.window.blit(self.item_imgs[ind], tuple(np.array([j + 0.22, i + 0.25]) * self.cell_size[0])) + + if self.render_mode == "human": + pygame.event.pump() + pygame.display.update() + self.clock.tick(self.metadata["render_fps"]) + elif self.render_mode == "rgb_array": # rgb_array + return np.transpose(np.array(pygame.surfarray.pixels3d(self.window)), axes=(1, 0, 2)) @override def close(self): + if self.render_mode is not None: + del_colored("item", len(self.item_dict)) + del_colored("agent", len(self.possible_agents)) pass @override