feature(wrh): taxi_dqn_config.py update #802

Closed · wants to merge 18 commits
3 changes: 3 additions & 0 deletions README.md
@@ -317,6 +317,9 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 37 | [tabmwp](https://promptpg.github.io/explore.html) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/tabmwp/tabmwp.jpeg) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/tabmwp) <br> env tutorial <br> 环境指南 |
| 38 | [frozen_lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/frozen_lake/FrozenLake.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/frozen_lake) <br> env tutorial <br> 环境指南 |
| 39 | [ising_model](https://github.com/mlii/mfrl/tree/master/examples/ising_model) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/ising_env/ising_env.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/ising_env) <br> env tutorial <br> [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/ising_model_zh.html) |
| 40 | [taxi](https://www.gymlibrary.dev/environments/toy_text/taxi/) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/taxi/Taxi-v3_episode_0.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/taxi) <br> [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/taxi.html) <br> [环境指南](https://di-engine-docs.readthedocs.io/en/latest/13_envs/taxi_zh.html) |



![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space

Binary file added dizoo/taxi/Taxi-v3_episode_0.gif
1 change: 1 addition & 0 deletions dizoo/taxi/__init__.py
@@ -0,0 +1 @@
from .envs import *
1 change: 1 addition & 0 deletions dizoo/taxi/config/__init__.py
@@ -0,0 +1 @@
from .taxi_dqn_config import main_config, create_config
64 changes: 64 additions & 0 deletions dizoo/taxi/config/taxi_dqn_config.py
@@ -0,0 +1,64 @@
from easydict import EasyDict

taxi_dqn_config = dict(
    exp_name='taxi_dqn_seed0',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=8,
        n_evaluator_episode=8,
        stop_value=20,
        max_episode_steps=60,
        env_id="Taxi-v3"
    ),
    policy=dict(
        cuda=True,
        model=dict(
            obs_shape=34,
            action_shape=6,
            encoder_hidden_size_list=[128, 128]
        ),
        random_collect_size=5000,
        nstep=3,
        discount_factor=0.99,
        learn=dict(
            update_per_collect=10,
            batch_size=64,
            learning_rate=0.0001,
            learner=dict(
                hook=dict(
                    log_show_after_iter=1000,
                )
            ),
        ),
        collect=dict(n_sample=32),
        eval=dict(evaluator=dict(eval_freq=1000, )),
        other=dict(
            eps=dict(
                type="linear",
                start=1,
                end=0.05,
                decay=3000000
            ),
            replay_buffer=dict(replay_buffer_size=100000, ),
        ),
    )
)
taxi_dqn_config = EasyDict(taxi_dqn_config)
main_config = taxi_dqn_config

taxi_dqn_create_config = dict(
    env=dict(
        type="taxi",
        import_names=["dizoo.taxi.envs.taxi_env"]
    ),
    env_manager=dict(type='base'),
    policy=dict(type='dqn'),
    replay_buffer=dict(type='deque', import_names=['ding.data.buffer.deque_buffer_wrapper']),
)

taxi_dqn_create_config = EasyDict(taxi_dqn_create_config)
create_config = taxi_dqn_create_config

if __name__ == "__main__":
    from ding.entry import serial_pipeline
    serial_pipeline((main_config, create_config), max_env_step=3000000, seed=20)
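
For context on the `other.eps` block above: it requests linear epsilon-greedy decay from 1 to 0.05 over the same 3,000,000 env steps that `serial_pipeline` runs for. A minimal sketch of such a schedule (an assumption about how DI-engine's `linear` type interpolates, not code from this PR):

```python
# Minimal sketch of a linear epsilon schedule, assuming DI-engine's
# type="linear" interpolates from `start` to `end` over `decay` env steps.
def linear_eps(step: int, start: float = 1.0, end: float = 0.05, decay: int = 3000000) -> float:
    frac = min(step / decay, 1.0)  # clamp so epsilon stays at `end` after decay finishes
    return start + frac * (end - start)

print(linear_eps(0))        # 1.0   (pure exploration at the start)
print(linear_eps(1500000))  # 0.525 (halfway through the 3M-step budget)
print(linear_eps(3000000))  # 0.05  (floor for the rest of training)
```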
38 changes: 38 additions & 0 deletions dizoo/taxi/entry/taxi_dqn_deploy.py
@@ -0,0 +1,38 @@
import gym
import torch
from easydict import EasyDict

from ding.config import compile_config
from ding.envs import DingEnvWrapper
from ding.model import DQN
from ding.policy import DQNPolicy, single_env_forward_wrapper
from dizoo.taxi.config.taxi_dqn_config import create_config, main_config
from dizoo.taxi.envs.taxi_env import TaxiEnv


def main(main_config: EasyDict, create_config: EasyDict, ckpt_path: str) -> None:
    main_config.exp_name = 'taxi_dqn_seed0_deploy'
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    env = TaxiEnv(cfg.env)
    env.enable_save_replay(replay_path=f'./{main_config.exp_name}/video')
    model = DQN(**cfg.policy.model)
    state_dict = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(state_dict['model'])
    policy = DQNPolicy(cfg.policy, model=model).eval_mode
    forward_fn = single_env_forward_wrapper(policy.forward)
    obs = env.reset()
    returns = 0.
    while True:
        action = forward_fn(obs)
        obs, rew, done, info = env.step(action)
        returns += rew
        if done:
            break
    print(f'Deploy is finished, final episode return is: {returns}')


if __name__ == "__main__":
    main(
        main_config=main_config,
        create_config=create_config,
        ckpt_path='./taxi_dqn_seed0/ckpt/ckpt_best.pth.tar'
    )
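
Side note: the loop above evaluates a single episode. A hypothetical variant (not part of this PR) that averages over several episodes gives a less noisy estimate of the deployed policy:

```python
# Hypothetical helper: average return over several episodes, reusing the
# `env` and `forward_fn` objects constructed in main() above.
def evaluate(env, forward_fn, n_episodes: int = 5) -> float:
    total = 0.
    for _ in range(n_episodes):
        obs = env.reset()
        while True:
            action = forward_fn(obs)
            obs, rew, done, info = env.step(action)
            total += rew.item()  # rew is a shape-(1,) float32 array
            if done:
                break
    return total / n_episodes
```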
1 change: 1 addition & 0 deletions dizoo/taxi/envs/__init__.py
@@ -0,0 +1 @@
from .taxi_env import TaxiEnv
168 changes: 168 additions & 0 deletions dizoo/taxi/envs/taxi_env.py
@@ -0,0 +1,168 @@
from typing import List, Optional
import os

from easydict import EasyDict
from gym.spaces import Space, Discrete
from gym.spaces.box import Box
import gym
import numpy as np
import imageio

from ditk import logging
from ding.envs.env.base_env import BaseEnv, BaseEnvTimestep
from ding.torch_utils import to_ndarray
from ding.utils import ENV_REGISTRY


@ENV_REGISTRY.register('taxi', force_overwrite=True)
class TaxiEnv(BaseEnv):

    def __init__(self, cfg: EasyDict) -> None:
        self._cfg = cfg
        assert self._cfg.env_id == "Taxi-v3", "Your environment name is not Taxi-v3!"
        self._init_flag = False
        self._replay_path = None
        self._save_replay = False
        self._frames = []

    def reset(self) -> np.ndarray:
        if not self._init_flag:
            self._env = gym.make(
                id=self._cfg.env_id,
                render_mode="single_rgb_array",
                max_episode_steps=self._cfg.max_episode_steps
            )
            self._observation_space = self._env.observation_space
            self._action_space = self._env.action_space
            self._reward_space = Box(
                low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
            )
            self._init_flag = True
        if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
            np_seed = 100 * np.random.randint(1, 1000)
            self._env_seed = self._seed + np_seed
        elif hasattr(self, '_seed'):
            self._env_seed = self._seed
        if hasattr(self, '_seed'):
            obs = self._env.reset(seed=self._env_seed)
        else:
            obs = self._env.reset()

        if self._save_replay:
            picture = self._env.render()
            self._frames.append(picture)
        self._eval_episode_return = 0.
        obs = self._encode_taxi(obs).astype(np.float32)
        return obs

    def close(self) -> None:
        if self._init_flag:
            self._env.close()
        self._init_flag = False

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)

    def step(self, action: np.ndarray) -> BaseEnvTimestep:
        assert isinstance(action, np.ndarray), type(action)
        action = action.item()
        obs, rew, done, info = self._env.step(action)
        self._eval_episode_return += rew
        obs = self._encode_taxi(obs)
        rew = to_ndarray([rew])  # transformed to an array with shape (1, )
        if self._save_replay:
            picture = self._env.render()
            self._frames.append(picture)
        if done:
            info['eval_episode_return'] = self._eval_episode_return
            if self._save_replay:
                assert self._replay_path is not None, "Replay path is not set: call enable_save_replay first."
                path = os.path.join(
                    self._replay_path, '{}_episode_{}.gif'.format(self._cfg.env_id, self._save_replay_count)
                )
                self.frames_to_gif(self._frames, path)
                self._frames = []
                self._save_replay_count += 1
        rew = rew.astype(np.float32)
        obs = obs.astype(np.float32)
        return BaseEnvTimestep(obs, rew, done, info)

    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
        if replay_path is None:
            replay_path = './video'
        if not os.path.exists(replay_path):
            os.makedirs(replay_path)
        self._replay_path = replay_path
        self._save_replay = True
        self._save_replay_count = 0

    def random_action(self) -> np.ndarray:
        random_action = self.action_space.sample()
        if isinstance(random_action, np.ndarray):
            pass
        elif isinstance(random_action, int):
            random_action = to_ndarray([random_action], dtype=np.int64)
        elif isinstance(random_action, dict):
            random_action = to_ndarray(random_action)
        else:
            raise TypeError(
                '`random_action` should be either int/np.ndarray or dict of int/np.ndarray, but get {}: {}'.format(
                    type(random_action), random_action
                )
            )
        return random_action

    def _encode_taxi(self, obs: np.ndarray) -> np.ndarray:
        # Encode the discrete Taxi-v3 state id as a 34-dim one-hot-style vector:
        # 25 dims for the 5x5 taxi position, 5 for the passenger location, 4 for the destination.
        taxi_row, taxi_col, passenger_location, destination = self._env.unwrapped.decode(obs)
        encoded_obs = np.zeros(34)
        encoded_obs[5 * taxi_row + taxi_col] = 1
        encoded_obs[25 + passenger_location] = 1
        encoded_obs[30 + destination] = 1
        return to_ndarray(encoded_obs)

    @property
    def observation_space(self) -> Space:
        return self._observation_space

    @property
    def action_space(self) -> Space:
        return self._action_space

    @property
    def reward_space(self) -> Space:
        return self._reward_space

    def __repr__(self) -> str:
        return "DI-engine Taxi-v3 Env"

    @staticmethod
    def frames_to_gif(frames: List[imageio.core.util.Array], gif_path: str, duration: float = 0.1) -> None:
        """
        Overview:
            Convert a list of frames into a GIF.
        Arguments:
            - frames (:obj:`List[imageio.core.util.Array]`): A list of frames, each frame is an image.
            - gif_path (:obj:`str`): The path to save the GIF file.
            - duration (:obj:`float`): Duration between each frame in the GIF (seconds).
        """
        # Save all frames as temporary image files
        temp_image_files = []
        for i, frame in enumerate(frames):
            temp_image_file = f"frame_{i}.png"  # temporary file name
            imageio.imwrite(temp_image_file, frame)  # save the frame as a PNG file
            temp_image_files.append(temp_image_file)

        # Use imageio to convert temporary image files to GIF
        with imageio.get_writer(gif_path, mode='I', duration=duration) as writer:
            for temp_image_file in temp_image_files:
                image = imageio.imread(temp_image_file)
                writer.append_data(image)

        # Clean up temporary image files
        for temp_image_file in temp_image_files:
            os.remove(temp_image_file)
        logging.info(f"GIF saved as {gif_path}")
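
To make `_encode_taxi` concrete: Taxi-v3 states are single integers that gym can decode into (taxi_row, taxi_col, passenger_location, destination), and the 34 dimensions split as 25 + 5 + 4. A standalone sketch of the same encoding (assumes a plain `gym` install with `Taxi-v3` registered; `env.unwrapped.decode` is the standard accessor for this env):

```python
import gym
import numpy as np

env = gym.make("Taxi-v3")
state = env.reset()
if isinstance(state, tuple):  # newer gym versions return (obs, info)
    state = state[0]

# Mirror of TaxiEnv._encode_taxi on a raw integer state id.
taxi_row, taxi_col, passenger_location, destination = env.unwrapped.decode(state)
vec = np.zeros(34, dtype=np.float32)
vec[5 * taxi_row + taxi_col] = 1  # 25 dims: taxi cell on the 5x5 grid
vec[25 + passenger_location] = 1  # 5 dims: passenger at R/G/Y/B or in the taxi
vec[30 + destination] = 1         # 4 dims: destination R/G/Y/B
assert vec.shape == (34, ) and vec.sum() == 3
```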
41 changes: 41 additions & 0 deletions dizoo/taxi/envs/test_taxi_env.py
@@ -0,0 +1,41 @@
import numpy as np
import pytest
from easydict import EasyDict
from dizoo.taxi import TaxiEnv


@pytest.mark.envtest
class TestTaxiEnv:

    def test_naive(self):
        env = TaxiEnv(
            EasyDict({
                "env_id": "Taxi-v3",
                "max_episode_steps": 300
            })
        )
        env.seed(314, dynamic_seed=False)
        assert env._seed == 314
        obs = env.reset()
        assert obs.shape == (34, )
        for _ in range(5):
            env.reset()
        np.random.seed(314)
        print('=' * 60)
        for i in range(10):
            # Both ``env.random_action()`` and sampling from ``env.action_space``
            # generate legal random actions.
            if i < 5:
                random_action = np.array([env.action_space.sample()])
            else:
                random_action = env.random_action()
            timestep = env.step(random_action)
            print(f"Your timestep in wrapped mode is: {timestep}")
            assert isinstance(timestep.obs, np.ndarray)
            assert isinstance(timestep.done, bool)
            assert timestep.obs.shape == (34, )
            assert timestep.reward.shape == (1, )
            assert timestep.reward >= env.reward_space.low
            assert timestep.reward <= env.reward_space.high
        print(env.observation_space, env.action_space, env.reward_space)
        env.close()
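
These checks are gated behind the `envtest` marker, so they can be run in isolation with `pytest -m envtest dizoo/taxi/envs/test_taxi_env.py` (assuming the repo's standard pytest setup).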