Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

human + rgb_array rendering modes #12

Merged
merged 3 commits into from
Nov 7, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 86 additions & 2 deletions momadm_benchmarks/envs/item_gathering/item_gathering.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@
import functools
import random
from copy import deepcopy
from os import path

# from gymnasium.utils import EzPickle
from typing_extensions import override

import numpy as np
import pygame
from gymnasium.logger import warn
from gymnasium.spaces import Box, Discrete
from pettingzoo.utils import wrappers

from momadm_benchmarks.envs.item_gathering.asset_utils import del_colored, get_colored
from momadm_benchmarks.envs.item_gathering.map_utils import DEFAULT_MAP, randomise_map
from momadm_benchmarks.utils.conversions import mo_parallel_to_aec
from momadm_benchmarks.utils.env import MOParallelEnv
Expand Down Expand Up @@ -75,7 +78,7 @@ class MOItemGathering(MOParallelEnv):
These attributes should not be changed after initialization.
"""

metadata = {"render_modes": ["human"], "name": "moitemgathering_v0"}
metadata = {"render_modes": ["human", "rgb_array"], "name": "moitemgathering_v0", "render_fps": 50}

def __init__(
self,
Expand Down Expand Up @@ -117,7 +120,6 @@ def __init__(
self.terminations = {agent: False for agent in self.agents}
self.truncations = {agent: False for agent in self.agents}
self.action_spaces = dict(zip(self.agents, [Discrete(len(ACTIONS))] * len(self.agent_positions)))
print(self.action_spaces)

# observation space is a 2D array, the same size as the grid
# 0 for empty, 1 for the current agent, 2 for other agents, 3 for objective 1, 4 for objective 2, ...
Expand All @@ -141,6 +143,8 @@ def __init__(
self.item_dict = {}
for i, item in enumerate(all_map_entries[0][np.where(all_map_entries[0] > 2)]):
self.item_dict[item] = i

print(self.item_dict)
indices_of_items = np.argwhere(all_map_entries[0] > 2).flatten()
item_counts = np.take(all_map_entries[1], indices_of_items)
self.num_objectives = len(item_counts)
Expand All @@ -151,6 +155,19 @@ def __init__(
zip(self.agents, [Box(low=0, high=max(item_counts), shape=(self.num_objectives,))] * len(self.agent_positions))
)

# pygame
self.size = 5
self.cell_size = (64, 64)
self.window_size = (
self.env_map.shape[1] * self.cell_size[1],
self.env_map.shape[0] * self.cell_size[0],
)
self.clock = None
self.agent_imgs = []
self.item_imgs = []
self.map_bg_imgs = []
self.window = None

# this cache ensures that same space object is returned for the same agent
# allows action space seeding to work as expected
@functools.lru_cache(maxsize=None)
Expand All @@ -173,9 +190,76 @@ def render(self):
if self.render_mode is None:
warn("You are calling render method without specifying any render mode.")
return
if self.window is None:
pygame.init()
get_colored("item", len(self.item_dict))
get_colored("agent", len(self.agents))

if self.render_mode == "human":
pygame.display.init()
pygame.display.set_caption("Item Gathering")
self.window = pygame.display.set_mode(self.window_size)
else:
self.window = pygame.Surface(self.window_size)

if self.clock is None:
self.clock = pygame.time.Clock()

if not self.agent_imgs:
agents = [path.join(path.dirname(__file__), "assets/agent.png")]
for i in range(len(self.agents) - 1):
agents.append(path.join(path.dirname(__file__), f"assets/colored/agent{i}.png"))
self.agent_imgs = [pygame.transform.scale(pygame.image.load(f_name), self.cell_size) for f_name in agents]
if not self.map_bg_imgs:
map_bg_imgs = [
path.join(path.dirname(__file__), "assets/map_bg1.png"),
path.join(path.dirname(__file__), "assets/map_bg2.png"),
]
self.map_bg_imgs = [
pygame.transform.scale(pygame.image.load(f_name), self.cell_size) for f_name in map_bg_imgs
]
if not self.item_imgs:
items = [path.join(path.dirname(__file__), "assets/item.png")]
for i in range(len(self.item_dict) - 1):
items.append(path.join(path.dirname(__file__), f"assets/colored/item{i}.png"))
self.item_imgs = [
pygame.transform.scale(pygame.image.load(f_name), (0.6 * self.cell_size[0], 0.6 * self.cell_size[1]))
for f_name in items
]

for i in range(self.env_map.shape[0]):
for j in range(self.env_map.shape[1]):
# background
check_board_mask = i % 2 ^ j % 2
self.window.blit(
self.map_bg_imgs[check_board_mask],
np.array([j, i]) * self.cell_size[0],
)

# agents
if [i, j] in self.agent_positions.tolist():
ind = self.agent_positions.tolist().index([i, j])
self.window.blit(self.agent_imgs[ind], tuple(np.array([j, i]) * self.cell_size[0]))

# items
elif int(self.env_map[i, j]) > 2:
ind = (
int(self.env_map[i, j]) - list(self.item_dict.keys())[0]
) # item n will have will have 0th (n-n) index
self.window.blit(self.item_imgs[ind], tuple(np.array([j + 0.22, i + 0.25]) * self.cell_size[0]))

if self.render_mode == "human":
pygame.event.pump()
pygame.display.update()
self.clock.tick(self.metadata["render_fps"])
elif self.render_mode == "rgb_array": # rgb_array
return np.transpose(np.array(pygame.surfarray.pixels3d(self.window)), axes=(1, 0, 2))

@override
def close(self):
if self.render_mode is not None:
del_colored("item", len(self.item_dict))
del_colored("agent", len(self.possible_agents))
pass

@override
Expand Down
Loading