Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

polish(nyz): polish atari envpool demo #702

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions ding/entry/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Callable, List, Any
from typing import Optional, Callable, List, Dict, Any

from ding.policy import PolicyFactory
from ding.worker import IMetric, MetricSerialEvaluator
Expand Down Expand Up @@ -46,7 +46,8 @@ def random_collect(
collector_env: 'BaseEnvManager', # noqa
commander: 'BaseSerialCommander', # noqa
replay_buffer: 'IBuffer', # noqa
postprocess_data_fn: Optional[Callable] = None
postprocess_data_fn: Optional[Callable] = None,
collect_kwargs: Optional[Dict] = None,
) -> None: # noqa
assert policy_cfg.random_collect_size > 0
if policy_cfg.get('transition_with_policy_data', False):
Expand All @@ -55,7 +56,8 @@ def random_collect(
action_space = collector_env.action_space
random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
collector.reset_policy(random_policy)
collect_kwargs = commander.step()
if collect_kwargs is None:
collect_kwargs = commander.step()
if policy_cfg.collect.collector.type == 'episode':
new_data = collector.collect(n_episode=policy_cfg.random_collect_size, policy_kwargs=collect_kwargs)
else:
Expand Down
4 changes: 1 addition & 3 deletions ding/envs/env_manager/envpool_env_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,8 @@ def launch(self) -> None:
seed=seed,
episodic_life=self._cfg.episodic_life,
reward_clip=self._cfg.reward_clip,
stack_num=self._cfg.stack_num,
gray_scale=self._cfg.gray_scale,
frame_skip=self._cfg.frame_skip
)
self.action_space = self._envs.action_space
self._closed = False
self.reset()

Expand Down
2 changes: 1 addition & 1 deletion ding/worker/collector/sample_serial_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class SampleSerialCollector(ISerialCollector):
envstep
"""

config = dict(deepcopy_obs=False, transform_obs=False, collect_print_freq=100)
config = dict(type='sample', deepcopy_obs=False, transform_obs=False, collect_print_freq=100)

def __init__(
self,
Expand Down
11 changes: 3 additions & 8 deletions dizoo/atari/config/serial/pong/pong_dqn_envpool_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
evaluator_batch_size=8,
n_evaluator_episode=8,
stop_value=20,
env_id='PongNoFrameskip-v4',
# 'ALE/Pong-v5' is also available, but it requires special settings after `gym.make`.
frame_stack=4,
env_id='Pong-v5',
),
policy=dict(
cuda=True,
priority=False,
random_collect_size=5000,
model=dict(
obs_shape=[4, 84, 84],
action_shape=6,
Expand Down Expand Up @@ -55,8 +54,4 @@
)
pong_dqn_envpool_create_config = EasyDict(pong_dqn_envpool_create_config)
create_config = pong_dqn_envpool_create_config

if __name__ == '__main__':
# Alternatively, run `ding -m serial -c pong_dqn_envpool_config.py -s 0` from the command line.
from ding.entry import serial_pipeline
serial_pipeline((main_config, create_config), seed=0)
# You can use `dizoo/atari/entry/atari_dqn_envpool_main.py` to run this config.
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
from ding.envs.env_manager.envpool_env_manager import PoolEnvManager
from ding.policy import DQNPolicy
from ding.entry import random_collect
from ding.model import DQN
from ding.utils import set_pkg_seed
from ding.rl_utils import get_epsilon_greedy_fn
from dizoo.atari.config.serial import pong_dqn_envpool_config


def main(cfg, seed=0, max_iterations=int(1e10)):
cfg.exp_name = 'atari_dqn_envpool'
def main(cfg, seed=2, max_iterations=int(1e10)):
cfg.exp_name = 'atari_dqn_envpool_pong_two_true_seed2'
cfg = compile_config(
cfg,
PoolEnvManager,
Expand All @@ -32,9 +33,6 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
# env wrappers
'episodic_life': True, # collector: True
'reward_clip': True, # collector: True
'gray_scale': cfg.env.get('gray_scale', True),
'stack_num': cfg.env.get('stack_num', 4),
'frame_skip': cfg.env.get('frame_skip', 4),
}
)
collector_env = PoolEnvManager(collector_env_cfg)
Expand All @@ -46,9 +44,6 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
# env wrappers
'episodic_life': False, # evaluator: False
'reward_clip': False, # evaluator: False
'gray_scale': cfg.env.get('gray_scale', True),
'stack_num': cfg.env.get('stack_num', 4),
'frame_skip': cfg.env.get('frame_skip', 4),
}
)
evaluator_env = PoolEnvManager(evaluator_env_cfg)
Expand All @@ -72,6 +67,9 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
eps_cfg = cfg.policy.other.eps
epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)

if cfg.policy.random_collect_size > 0:
collect_kwargs = {'eps': epsilon_greedy(collector.envstep)}
random_collect(cfg.policy, policy, collector, collector_env, {}, replay_buffer, collect_kwargs=collect_kwargs)
while True:
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
Expand All @@ -85,6 +83,8 @@ def main(cfg, seed=0, max_iterations=int(1e10)):
train_data = replay_buffer.sample(batch_size, learner.train_iter)
if train_data is not None:
learner.train(train_data, collector.envstep)
if collector.envstep >= int(3e6):
break


if __name__ == "__main__":
Expand Down
74 changes: 0 additions & 74 deletions dizoo/atari/entry/phoenix_fqf_main.py

This file was deleted.

74 changes: 0 additions & 74 deletions dizoo/atari/entry/phoenix_iqn_main.py

This file was deleted.

74 changes: 0 additions & 74 deletions dizoo/atari/entry/pong_fqf_main.py

This file was deleted.

Loading