Skip to content

Commit

Permalink
Add rendering for fruit_tree env (#81)
Browse files Browse the repository at this point in the history
* Add drowing lines, change color to green

* Add drowing lines, change color to green

* Add pygame font init

* Initialize font

* Empty-Commit

* Add a margin at the top of the window

* Add a margin at the top of the window

* Update pyright version

* Remove unused node.png, revert docstring

* Run gen_gifs script

* Add fruit-tree.gif
  • Loading branch information
tomekster authored Jan 28, 2024
1 parent f5e87cf commit 2bf89bf
Show file tree
Hide file tree
Showing 21 changed files with 151 additions and 8 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ repos:
hooks:
- id: flake8
args:
- '--per-file-ignores=*/__init__.py:F401'
- "--per-file-ignores=*/__init__.py:F401"
- --ignore=E203,W503,E741
- --max-complexity=30
- --max-line-length=456
Expand Down Expand Up @@ -64,6 +64,6 @@ repos:
language: node
pass_filenames: false
types: [python]
additional_dependencies: ["pyright"]
additional_dependencies: ["pyright@1.1.347"]
args:
- --project=pyproject.toml
Binary file modified docs/_static/videos/breakable-bottles.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/deep-sea-treasure-concave.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/deep-sea-treasure-mirrored.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/deep-sea-treasure.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/four-room.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/videos/fruit-tree.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/minecart-deterministic.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/minecart.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/mo-halfcheetah.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/mo-hopper.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/mo-lunar-lander.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/mo-mountaincar.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/mo-mountaincarcontinuous.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/mo-reacher.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/mo-supermario.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/resource-gathering.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/water-reservoir.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added mo_gymnasium/envs/fruit_tree/assets/agent.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added mo_gymnasium/envs/fruit_tree/assets/node_blue.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
155 changes: 149 additions & 6 deletions mo_gymnasium/envs/fruit_tree/fruit_tree.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Environment from https://github.com/RunzheYang/MORL/blob/master/synthetic/envs/fruit_tree.py
from typing import List
from os import path
from typing import List, Optional

import gymnasium as gym
import numpy as np
import pygame
from gymnasium import spaces
from gymnasium.utils import EzPickle

Expand Down Expand Up @@ -264,16 +266,16 @@ class FruitTreeEnv(gym.Env, EzPickle):
The episode terminates when the agent reaches a leaf node.
"""

def __init__(self, depth=6):
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

def __init__(self, depth=6, render_mode: Optional[str] = None):
assert depth in [5, 6, 7], "Depth must be 5, 6 or 7."
EzPickle.__init__(self, depth)

self.render_mode = render_mode
self.reward_dim = 6
self.tree_depth = depth # zero based depth
branches = np.zeros((int(2**self.tree_depth - 1), self.reward_dim))
# fruits = np.random.randn(2**self.tree_depth, self.reward_dim)
# fruits = np.abs(fruits) / np.linalg.norm(fruits, 2, 1, True)
# print(fruits*10)
fruits = np.array(FRUITS[str(depth)])
self.tree = np.concatenate([branches, fruits])

Expand All @@ -288,9 +290,35 @@ def __init__(self, depth=6):
self.current_state = np.array([0, 0], dtype=np.int32)
self.terminal = False

# pygame
self.row_height = 20
self.top_margin = 15

# Add margin at the bottom to account for the node rewards
self.window_size = (1200, self.row_height * self.tree_depth + 150)
self.window_padding = 15 # padding on the left and right of the window
self.node_square_size = np.array([10, 10], dtype=np.int32)
self.font_size = 12
pygame.font.init()
self.font = pygame.font.SysFont(None, self.font_size)

self.window = None
self.clock = None
self.node_img = None
self.agent_img = None

def get_ind(self, pos):
"""Given the pos = current_state = [row_ind, pos_in_row]
return the index of the node in the tree array"""
return int(2 ** pos[0] - 1) + pos[1]

def ind_to_state(self, ind):
"""Given the index of the node in the tree array return the
current_state = [row_ind, pos_in_row]"""
x = int(np.log2(ind + 1))
y = ind - 2**x + 1
return np.array([x, y], dtype=np.int32)

def get_tree_value(self, pos):
return np.array(self.tree[self.get_ind(pos)], dtype=np.float32)

Expand Down Expand Up @@ -325,5 +353,120 @@ def step(self, action):
reward = self.get_tree_value(self.current_state)
if self.current_state[0] == self.tree_depth:
self.terminal = True

return self.current_state.copy(), reward, self.terminal, False, {}

def get_pos_in_window(self, row, index_in_row):
"""Given the row and index_in_row of the node
calculate its position in the window in pixels"""
window_width = self.window_size[0] - 2 * self.window_padding
distance_between_nodes = window_width / (2 ** (row))
pos_x = self.window_padding + (index_in_row + 0.5) * distance_between_nodes
pos_y = row * self.row_height
return np.array([pos_x, pos_y])

def render(self):
if self.render_mode is None:
assert self.spec is not None
gym.logger.warn(
"You are calling render method without specifying render mode."
"You can specify the render_mode at initialization, "
f'e.g. mo_gym.make("{self.spec.id}", render_mode="rgb_array")'
)
return

if self.clock is None and self.render_mode == "human":
self.clock = pygame.time.Clock()

if self.window is None:
pygame.init()

if self.render_mode == "human":
pygame.display.init()
pygame.display.set_caption("Fruit Tree")
self.window = pygame.display.set_mode(self.window_size)
self.clock.tick(self.metadata["render_fps"])
else:
self.window = pygame.Surface(self.window_size)

if self.node_img is None:
filename = path.join(path.dirname(__file__), "assets", "node_blue.png")
self.node_img = pygame.transform.scale(pygame.image.load(filename), self.node_square_size)
self.node_img = pygame.transform.flip(self.node_img, flip_x=True, flip_y=False)

if self.agent_img is None:
filename = path.join(path.dirname(__file__), "assets", "agent.png")
self.agent_img = pygame.transform.scale(pygame.image.load(filename), self.node_square_size)

canvas = pygame.Surface(self.window_size)
canvas.fill((255, 255, 255)) # White

# draw branches
for ind, node in enumerate(self.tree):
row, index_in_row = self.ind_to_state(ind)
node_pos = self.get_pos_in_window(row, index_in_row)
if row < self.tree_depth:
# Get childerns' positions and draw branches
child1_pos = self.get_pos_in_window(row + 1, 2 * index_in_row)
child2_pos = self.get_pos_in_window(row + 1, 2 * index_in_row + 1)
half_square = self.node_square_size / 2
pygame.draw.line(canvas, (90, 82, 85), node_pos + half_square, child1_pos + half_square, 1)
pygame.draw.line(canvas, (90, 82, 85), node_pos + half_square, child2_pos + half_square, 1)

for ind, node in enumerate(self.tree):
row, index_in_row = self.ind_to_state(ind)
if (row, index_in_row) == tuple(self.current_state):
img = self.agent_img
font_color = (164, 0, 0) # Red digits for agent node
else:
img = self.node_img
if ind % 2:
font_color = (250, 128, 114) # Green
else:
font_color = (45, 72, 101) # Dark Blue

node_pos = self.get_pos_in_window(row, index_in_row)

canvas.blit(img, np.array(node_pos))

# Print node values at the bottom of the tree
if row == self.tree_depth:
odd_nodes_values_offset = 0.5 * (ind % 2)
values_imgs = [self.font.render(f"{val:.2f}", True, font_color) for val in node]
for i, val_img in enumerate(values_imgs):
canvas.blit(val_img, node_pos + np.array([-5, (i + 1 + odd_nodes_values_offset) * 1.5 * self.font_size]))

background = pygame.Surface(self.window_size)
background.fill((255, 255, 255)) # White
background.blit(canvas, (0, self.top_margin))

self.window.blit(background, (0, 0))

if self.render_mode == "human":
pygame.event.pump()
pygame.display.update()
self.clock.tick(self.metadata["render_fps"])
elif self.render_mode == "rgb_array":
return np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2))

background = pygame.Surface(self.window_size)
background.fill((255, 255, 255)) # White

background.blit(canvas, (0, self.top_margin))

self.window.blit(background, (0, 0))


if __name__ == "__main__":
import time

import mo_gymnasium as mo_gym

env = mo_gym.make("fruit-tree", depth=6, render_mode="human")
env.reset()
while True:
env.render()
obs, r, terminal, truncated, info = env.step(env.action_space.sample())
if terminal or truncated:
env.render()
time.sleep(2)
env.reset()

0 comments on commit 2bf89bf

Please sign in to comment.