Skip to content

Commit

Permalink
erge branch 'ts/add_fruit_tree_env_render' of https://github.com/tome…
Browse files Browse the repository at this point in the history
…kster/MO-Gymnasium into ts/add_fruit_tree_env_render
  • Loading branch information
tomekster committed Jan 18, 2024
2 parents 6b25cb6 + c1026b6 commit d86a29f
Showing 1 changed file with 37 additions and 37 deletions.
74 changes: 37 additions & 37 deletions mo_gymnasium/envs/fruit_tree/fruit_tree.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# Environment from https://github.com/RunzheYang/MORL/blob/master/synthetic/envs/fruit_tree.py
from typing import List
from os import path
from typing import List, Optional

import gymnasium as gym
import numpy as np
import pygame
from gymnasium import spaces
from gymnasium.utils import EzPickle
import pygame
from os import path
from typing import Optional


FRUITS = {
Expand Down Expand Up @@ -248,17 +247,21 @@ class FruitTreeEnv(gym.Env, EzPickle):
"""
## Description
Full binary tree of depth d=5,6 or 7. Every leaf contains a fruit with a value for the nutrients Protein, Carbs, Fats, Vitamins, Minerals and Water.
Full binary tree of depth d=5,6 or 7. Every leaf contains a fruit with
a value for the nutrients Protein, Carbs, Fats, Vitamins,
Minerals and Water.
From [Yang et al. 2019](https://arxiv.org/pdf/1908.08342.pdf).
## Observation Space
Discrete space of size 2^d-1, where d is the depth of the tree.
## Action Space
The agent can chose to go left or right at every node. The action space is therefore a discrete space of size 2.
The agent can chose to go left or right at every node.
The action space is therefore a discrete space of size 2.
## Reward Space
Each leaf node contains a 6-dimensional vector containing the nutrients of the fruit. The agent receives a reward for each nutrient it collects.
Each leaf node contains a 6-dimensional vector containing the nutrients of
the fruit. The agent receives a reward for each nutrient it collects.
## Starting State
The agent starts at the root node (0, 0).
Expand All @@ -267,8 +270,6 @@ class FruitTreeEnv(gym.Env, EzPickle):
The episode terminates when the agent reaches a leaf node.
"""

metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 1}

def __init__(self, depth=6, render_mode: Optional[str] = None):
assert depth in [5, 6, 7], "Depth must be 5, 6 or 7."
EzPickle.__init__(self, depth)
Expand All @@ -295,25 +296,27 @@ def __init__(self, depth=6, render_mode: Optional[str] = None):

# pygame
self.row_height = 20
self.window_size = (1200, self.row_height * self.tree_depth + 150) # add margin at the bottom to account for the node rewards
# Add margin at the bottom to account for the node rewards
self.window_size = (1200, self.row_height * self.tree_depth + 150)
self.node_square_size = np.array([10, 10], dtype=np.int32)
self.window_padding = 15 # padding on the left and right of the window
self.font_size = 12
self.window_padding = 15 # padding on the left and right of the window
self.font_size = 12
self.font = pygame.font.Font(None, self.font_size)

self.window = None
self.node_img = None
self.agent_img = None

def get_ind(self, pos):
""" Given the pos = current_state = [row_ind, pos_in_row]
"""Given the pos = current_state = [row_ind, pos_in_row]
return the index of the node in the tree array"""
return int(2 ** pos[0] - 1) + pos[1]

def ind_to_state(self, ind):
""" Given the index of the node in the tree array return the
"""Given the index of the node in the tree array return the
current_state = [row_ind, pos_in_row]"""
x = int(np.log2(ind + 1))
y = ind - 2 ** x + 1
y = ind - 2**x + 1
return np.array([x, y], dtype=np.int32)

def get_tree_value(self, pos):
Expand Down Expand Up @@ -350,13 +353,13 @@ def step(self, action):
reward = self.get_tree_value(self.current_state)
if self.current_state[0] == self.tree_depth:
self.terminal = True

return self.current_state.copy(), reward, self.terminal, False, {}

def get_pos_in_window(self, row, index_in_row):
""" Given the row and index_in_row of the node
"""Given the row and index_in_row of the node
calculate its position in the window in pixels"""
distance_between_nodes = (self.window_size[0] - 2 * self.window_padding) / (2 ** (row))
window_width = self.window_size[0] - 2 * self.window_padding
distance_between_nodes = window_width / (2 ** (row))
pos_x = self.window_padding + (index_in_row + 0.5) * distance_between_nodes
pos_y = row * self.row_height
return np.array([pos_x, pos_y])
Expand All @@ -365,7 +368,7 @@ def render(self):
if self.render_mode is None:
assert self.spec is not None
gym.logger.warn(
"You are calling render method without specifying any render mode. "
"You are calling render method without specifying render mode."
"You can specify the render_mode at initialization, "
f'e.g. mo_gym.make("{self.spec.id}", render_mode="rgb_array")'
)
Expand All @@ -388,9 +391,6 @@ def render(self):
if self.agent_img is None:
filename = path.join(path.dirname(__file__), "assets", "agent.png")
self.agent_img = pygame.transform.scale(pygame.image.load(filename), self.node_square_size)

# self.font = pygame.font.Font(path.join(path.dirname(__file__), "assets", "Minecraft.ttf"), 20)
self.font = pygame.font.Font(None, self.font_size)

canvas = pygame.Surface(self.window_size)
canvas.fill((0, 0, 0))
Expand All @@ -399,13 +399,13 @@ def render(self):

for ind, node in enumerate(self.tree):
row, index_in_row = self.ind_to_state(ind)

if (row, index_in_row) == tuple(self.current_state):
img = self.agent_img
font_color = (255, 0, 0) # Red digits for agent node
font_color = (255, 0, 0) # Red digits for agent node
else:
img = self.node_img
font_color = (0, 255, 0) # Green digits for non-agent nodes
font_color = (0, 255, 0) # Green digits for non-agent nodes

node_pos = self.get_pos_in_window(row, index_in_row)

Expand All @@ -416,32 +416,32 @@ def render(self):
child1_pos = self.get_pos_in_window(row + 1, 2 * index_in_row)
child2_pos = self.get_pos_in_window(row + 1, 2 * index_in_row + 1)
half_square = self.node_square_size / 2
pygame.draw.line(self.window, (255,255,255), node_pos + half_square, child1_pos + half_square, 1)
pygame.draw.line(self.window, (255,255,255), node_pos + half_square, child2_pos + half_square, 1)
pygame.draw.line(self.window, (255, 255, 255), node_pos + half_square, child1_pos + half_square, 1)
pygame.draw.line(self.window, (255, 255, 255), node_pos + half_square, child2_pos + half_square, 1)
else:
# Print node values at the bottom of the tree
values_imgs = [self.font.render(f'{val:.2f}', True, font_color) for val in node]
values_imgs = [self.font.render(f"{val:.2f}", True, font_color) for val in node]
for i, val_img in enumerate(values_imgs):
self.window.blit(val_img, node_pos + np.array([- 5, (i+1)*self.font_size]))
self.window.blit(val_img, node_pos + np.array([-5, (i + 1) * self.font_size]))

if self.render_mode == "human":
pygame.event.pump()
pygame.display.update()
self.clock.tick(self.metadata["render_fps"])
elif self.render_mode == "rgb_array":
return np.transpose(np.array(pygame.surfarray.pixels3d(self.window)), axes=(1, 0, 2))



if __name__ == "__main__":
import mo_gymnasium as mo_gym
import time

import mo_gymnasium as mo_gym

env = mo_gym.make("fruit-tree", depth=6, render_mode="human")
terminated = False
env.reset()
while True:
env.render()
obs, r, terminated, truncated, info = env.step(env.action_space.sample())
if terminated or truncated:
obs, r, terminal, truncated, info = env.step(env.action_space.sample())
if terminal or truncated:
env.render()
time.sleep(2)
env.reset()

0 comments on commit d86a29f

Please sign in to comment.