Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(wrh): add taxi env latest version and dqn config #807

Merged
merged 2 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 37 | [tabmwp](https://promptpg.github.io/explore.html) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/tabmwp/tabmwp.jpeg) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/tabmwp) <br> env tutorial <br> 环境指南 |
| 38 | [frozen_lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/frozen_lake/FrozenLake.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/frozen_lake) <br> env tutorial <br> 环境指南 |
| 39 | [ising_model](https://github.com/mlii/mfrl/tree/master/examples/ising_model) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/ising_env/ising_env.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/ising_env) <br> env tutorial <br> [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/ising_model_zh.html) |
| 40 | [taxi](https://www.gymlibrary.dev/environments/toy_text/taxi/) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/taxi/Taxi-v3_episode_0.gif) | dizoo link <br> env tutorial <br> 环境指南 |
| 40 | [taxi](https://www.gymlibrary.dev/environments/toy_text/taxi/) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/taxi/Taxi-v3_episode_0.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/taxi/envs) <br> [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/taxi.html) <br> [环境指南](https://di-engine-docs.readthedocs.io/zh-cn/latest/13_envs/taxi_zh.html) |



Expand Down
46 changes: 26 additions & 20 deletions dizoo/taxi/config/taxi_dqn_config.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,45 @@
from easydict import EasyDict

taxi_dqn_config = dict(
exp_name='taxi_seed0',
exp_name='taxi_dqn_seed0',
env=dict(
collector_env_num=8,
evaluator_env_num=8,
n_evaluator_episode=10,
max_episode_steps=300,
env_id="Taxi-v3"
evaluator_env_num=8,
n_evaluator_episode=8,
stop_value=20,
max_episode_steps=60,
env_id="Taxi-v3"
),
policy=dict(
cuda=True,
load_path="./taxi_dqn_seed0/ckpt/ckpt_best.pth.tar",
model=dict(
obs_shape=4,
obs_shape=34,
action_shape=6,
encoder_hidden_size_list=[256, 128, 64]
encoder_hidden_size_list=[128, 128]
),
random_collect_size=5000,
nstep=3,
discount_factor=0.98,
discount_factor=0.99,
learn=dict(
update_per_collect=5,
batch_size=128,
learning_rate=0.001,
update_per_collect=10,
batch_size=64,
learning_rate=0.0001,
learner=dict(
hook=dict(
log_show_after_iter=1000,
)
),
),
collect=dict(n_sample=10),
eval=dict(evaluator=dict(eval_freq=5, )),
collect=dict(n_sample=32),
eval=dict(evaluator=dict(eval_freq=1000, )),
other=dict(
eps=dict(
type="linear",
start=0.8,
end=0.1,
decay=10000
),
replay_buffer=dict(replay_buffer_size=20000,),
start=1,
end=0.05,
decay=3000000
),
replay_buffer=dict(replay_buffer_size=100000,),
),
)
)
Expand All @@ -55,4 +61,4 @@

if __name__ == "__main__":
from ding.entry import serial_pipeline
serial_pipeline((main_config, create_config), max_env_step=5000, seed=0)
serial_pipeline((main_config, create_config), max_env_step=3000000, seed=0)
10 changes: 7 additions & 3 deletions dizoo/taxi/envs/taxi_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep:
def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
if replay_path is None:
replay_path = './video'
if not os.path.exists(replay_path):
os.makedirs(replay_path)
if not os.path.exists(replay_path):
os.makedirs(replay_path)
self._replay_path = replay_path
self._save_replay = True
self._save_replay_count = 0
Expand All @@ -118,7 +118,11 @@ def random_action(self) -> np.ndarray:
#todo encode the state into a vector
def _encode_taxi(self, obs: np.ndarray) -> np.ndarray:
taxi_row, taxi_col, passenger_location, destination = self._env.unwrapped.decode(obs)
return to_ndarray([taxi_row, taxi_col, passenger_location, destination])
encoded_obs = np.zeros(34)
encoded_obs[5 * taxi_row + taxi_col] = 1
encoded_obs[25 + passenger_location] = 1
encoded_obs[30 + destination] = 1
return to_ndarray(encoded_obs)

@property
def observation_space(self) -> Space:
Expand Down
4 changes: 2 additions & 2 deletions dizoo/taxi/envs/test_taxi_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_naive(self):
env.seed(314, dynamic_seed=False)
assert env._seed == 314
obs = env.reset()
assert obs.shape == (4, )
assert obs.shape == (34, )
for _ in range(5):
env.reset()
np.random.seed(314)
Expand All @@ -32,7 +32,7 @@ def test_naive(self):
print(f"Your timestep in wrapped mode is: {timestep}")
assert isinstance(timestep.obs, np.ndarray)
assert isinstance(timestep.done, bool)
assert timestep.obs.shape == (4, )
assert timestep.obs.shape == (34, )
assert timestep.reward.shape == (1, )
assert timestep.reward >= env.reward_space.low
assert timestep.reward <= env.reward_space.high
Expand Down
Loading