diff --git a/ding/entry/utils.py b/ding/entry/utils.py index bbfbaa83bd..d02715eff3 100644 --- a/ding/entry/utils.py +++ b/ding/entry/utils.py @@ -1,4 +1,4 @@ -from typing import Optional, Callable, List, Any +from typing import Optional, Callable, List, Dict, Any from ding.policy import PolicyFactory from ding.worker import IMetric, MetricSerialEvaluator @@ -46,7 +46,8 @@ def random_collect( collector_env: 'BaseEnvManager', # noqa commander: 'BaseSerialCommander', # noqa replay_buffer: 'IBuffer', # noqa - postprocess_data_fn: Optional[Callable] = None + postprocess_data_fn: Optional[Callable] = None, + collect_kwargs: Optional[Dict] = None, ) -> None: # noqa assert policy_cfg.random_collect_size > 0 if policy_cfg.get('transition_with_policy_data', False): @@ -55,7 +56,8 @@ def random_collect( action_space = collector_env.action_space random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space) collector.reset_policy(random_policy) - collect_kwargs = commander.step() + if collect_kwargs is None: + collect_kwargs = commander.step() if policy_cfg.collect.collector.type == 'episode': new_data = collector.collect(n_episode=policy_cfg.random_collect_size, policy_kwargs=collect_kwargs) else: diff --git a/ding/envs/env_manager/envpool_env_manager.py b/ding/envs/env_manager/envpool_env_manager.py index a8d1a4ae03..f235890d60 100644 --- a/ding/envs/env_manager/envpool_env_manager.py +++ b/ding/envs/env_manager/envpool_env_manager.py @@ -63,10 +63,8 @@ def launch(self) -> None: seed=seed, episodic_life=self._cfg.episodic_life, reward_clip=self._cfg.reward_clip, - stack_num=self._cfg.stack_num, - gray_scale=self._cfg.gray_scale, - frame_skip=self._cfg.frame_skip ) + self.action_space = self._envs.action_space self._closed = False self.reset() diff --git a/ding/worker/collector/sample_serial_collector.py b/ding/worker/collector/sample_serial_collector.py index 26db458edb..ccea8806cf 100644 --- a/ding/worker/collector/sample_serial_collector.py +++ b/ding/worker/collector/sample_serial_collector.py @@ -25,7 +25,7 @@ class SampleSerialCollector(ISerialCollector): envstep """ - config = dict(deepcopy_obs=False, transform_obs=False, collect_print_freq=100) + config = dict(type='sample', deepcopy_obs=False, transform_obs=False, collect_print_freq=100) def __init__( self, diff --git a/dizoo/atari/config/serial/pong/pong_dqn_envpool_config.py b/dizoo/atari/config/serial/pong/pong_dqn_envpool_config.py index 0b80e41548..cd7cd60d86 100644 --- a/dizoo/atari/config/serial/pong/pong_dqn_envpool_config.py +++ b/dizoo/atari/config/serial/pong/pong_dqn_envpool_config.py @@ -9,13 +9,12 @@ evaluator_batch_size=8, n_evaluator_episode=8, stop_value=20, - env_id='PongNoFrameskip-v4', - #'ALE/Pong-v5' is available. But special setting is needed after gym make. - frame_stack=4, + env_id='Pong-v5', ), policy=dict( cuda=True, priority=False, + random_collect_size=5000, model=dict( obs_shape=[4, 84, 84], action_shape=6, @@ -55,8 +54,4 @@ ) pong_dqn_envpool_create_config = EasyDict(pong_dqn_envpool_create_config) create_config = pong_dqn_envpool_create_config - -if __name__ == '__main__': - # or you can enter `ding -m serial -c pong_dqn_envpool_config.py -s 0` - from ding.entry import serial_pipeline - serial_pipeline((main_config, create_config), seed=0) +# You can use `dizoo/atari/entry/atari_dqn_envpool_main.py` to run this config. 
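Note on the hunks above: `random_collect` gains an optional `collect_kwargs` argument, so a caller can hand the collect policy its kwargs (for example the current epsilon of a DQN policy) directly instead of relying on `commander.step()`. Below is a minimal sketch of that call pattern, assuming the policy, collector, env manager, replay buffer and `epsilon_greedy` schedule are already built as in a standard DI-engine serial entry; `warm_up_buffer` is a hypothetical helper name, not part of this patch.

    # Hedged sketch (not part of the patch): drive random warm-up collection with an
    # explicit epsilon, as the new `collect_kwargs` argument of `random_collect` allows.
    from ding.entry import random_collect

    def warm_up_buffer(cfg, policy, collector, collector_env, replay_buffer, epsilon_greedy):
        """Fill the replay buffer with cfg.policy.random_collect_size random-policy transitions."""
        if cfg.policy.get('random_collect_size', 0) > 0:
            # Pass the current epsilon explicitly; the commander slot can then stay an empty
            # dict, because a non-None `collect_kwargs` skips the `commander.step()` call
            # inside random_collect.
            collect_kwargs = {'eps': epsilon_greedy(collector.envstep)}
            random_collect(
                cfg.policy, policy, collector, collector_env, {}, replay_buffer,
                collect_kwargs=collect_kwargs
            )
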
diff --git a/dizoo/atari/entry/pong_dqn_envpool_main.py b/dizoo/atari/entry/atari_dqn_envpool_main.py similarity index 88% rename from dizoo/atari/entry/pong_dqn_envpool_main.py rename to dizoo/atari/entry/atari_dqn_envpool_main.py index 769fe4f261..bd079b4711 100644 --- a/dizoo/atari/entry/pong_dqn_envpool_main.py +++ b/dizoo/atari/entry/atari_dqn_envpool_main.py @@ -6,14 +6,15 @@ from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer from ding.envs.env_manager.envpool_env_manager import PoolEnvManager from ding.policy import DQNPolicy +from ding.entry import random_collect from ding.model import DQN from ding.utils import set_pkg_seed from ding.rl_utils import get_epsilon_greedy_fn from dizoo.atari.config.serial import pong_dqn_envpool_config -def main(cfg, seed=0, max_iterations=int(1e10)): - cfg.exp_name = 'atari_dqn_envpool' +def main(cfg, seed=2, max_iterations=int(1e10)): + cfg.exp_name = 'atari_dqn_envpool_pong_two_true_seed2' cfg = compile_config( cfg, PoolEnvManager, @@ -32,9 +33,6 @@ def main(cfg, seed=0, max_iterations=int(1e10)): # env wrappers 'episodic_life': True, # collector: True 'reward_clip': True, # collector: True - 'gray_scale': cfg.env.get('gray_scale', True), - 'stack_num': cfg.env.get('stack_num', 4), - 'frame_skip': cfg.env.get('frame_skip', 4), } ) collector_env = PoolEnvManager(collector_env_cfg) @@ -46,9 +44,6 @@ def main(cfg, seed=0, max_iterations=int(1e10)): # env wrappers 'episodic_life': False, # evaluator: False 'reward_clip': False, # evaluator: False - 'gray_scale': cfg.env.get('gray_scale', True), - 'stack_num': cfg.env.get('stack_num', 4), - 'frame_skip': cfg.env.get('frame_skip', 4), } ) evaluator_env = PoolEnvManager(evaluator_env_cfg) @@ -72,6 +67,9 @@ def main(cfg, seed=0, max_iterations=int(1e10)): eps_cfg = cfg.policy.other.eps epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type) + if cfg.policy.random_collect_size > 0: + collect_kwargs = {'eps': epsilon_greedy(collector.envstep)} + random_collect(cfg.policy, policy, collector, collector_env, {}, replay_buffer, collect_kwargs=collect_kwargs) while True: if evaluator.should_eval(learner.train_iter): stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) @@ -85,6 +83,8 @@ def main(cfg, seed=0, max_iterations=int(1e10)): train_data = replay_buffer.sample(batch_size, learner.train_iter) if train_data is not None: learner.train(train_data, collector.envstep) + if collector.envstep >= int(3e6): + break if __name__ == "__main__": diff --git a/dizoo/atari/entry/phoenix_fqf_main.py b/dizoo/atari/entry/phoenix_fqf_main.py deleted file mode 100644 index 7d8ab5a384..0000000000 --- a/dizoo/atari/entry/phoenix_fqf_main.py +++ /dev/null @@ -1,74 +0,0 @@ -import os -import torch -from tensorboardX import SummaryWriter -from ding.config import compile_config -from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer -from ding.policy import FQFPolicy -from ding.model import FQF -from ding.utils import set_pkg_seed -from ding.rl_utils import get_epsilon_greedy_fn -from dizoo.atari.config.serial.phoenix.phoenix_fqf_config import phoenix_fqf_config, create_config -from ding.utils import DistContext -from functools import partial -from ding.envs import get_vec_env_setting, create_env_manager - - -def main(cfg, create_cfg, seed=0): - - cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) - 
- # Create main components: env, policy - env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) - collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) - evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) - - # Set random seed for all package and instance - collector_env.seed(seed) - evaluator_env.seed(seed, dynamic_seed=False) - set_pkg_seed(seed, use_cuda=cfg.policy.cuda) - - # Set up RL Policy - model = FQF(**cfg.policy.model) - policy = FQFPolicy(cfg.policy, model=model) - - # Set up collection, training and evaluation utilities - tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) - learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) - collector = SampleSerialCollector( - cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name - ) - evaluator = InteractionSerialEvaluator( - cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name - ) - replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name) - - # Set up other modules, etc. epsilon greedy - eps_cfg = cfg.policy.other.eps - epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type) - - # Training & Evaluation loop - while True: - # Evaluating at the beginning and with specific frequency - if evaluator.should_eval(learner.train_iter): - stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) - if stop: - break - # Update other modules - eps = epsilon_greedy(collector.envstep) - # Sampling data from environments - new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps}) - replay_buffer.push(new_data, cur_collector_envstep=collector.envstep) - # Training - for i in range(cfg.policy.learn.update_per_collect): - train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter) - if train_data is None: - break - learner.train(train_data, collector.envstep) - - if collector.envstep >= int(1e7): - break - - -if __name__ == "__main__": - # with DistContext(): - main(phoenix_fqf_config, create_config) diff --git a/dizoo/atari/entry/phoenix_iqn_main.py b/dizoo/atari/entry/phoenix_iqn_main.py deleted file mode 100644 index 91f528505d..0000000000 --- a/dizoo/atari/entry/phoenix_iqn_main.py +++ /dev/null @@ -1,74 +0,0 @@ -import os -import torch -from tensorboardX import SummaryWriter -from ding.config import compile_config -from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer -from ding.policy import IQNPolicy -from ding.model import IQN -from ding.utils import set_pkg_seed -from ding.rl_utils import get_epsilon_greedy_fn -from dizoo.atari.config.serial.phoenix.phoenix_iqn_config import phoenix_iqn_config, create_config -from ding.utils import DistContext -from functools import partial -from ding.envs import get_vec_env_setting, create_env_manager - - -def main(cfg, create_cfg, seed=0): - - cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) - - # Create main components: env, policy - env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) - collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) - evaluator_env = 
create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) - - # Set random seed for all package and instance - collector_env.seed(seed) - evaluator_env.seed(seed, dynamic_seed=False) - set_pkg_seed(seed, use_cuda=cfg.policy.cuda) - - # Set up RL Policy - model = IQN(**cfg.policy.model) - policy = IQNPolicy(cfg.policy, model=model) - - # Set up collection, training and evaluation utilities - tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) - learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) - collector = SampleSerialCollector( - cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name - ) - evaluator = InteractionSerialEvaluator( - cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name - ) - replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name) - - # Set up other modules, etc. epsilon greedy - eps_cfg = cfg.policy.other.eps - epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type) - - # Training & Evaluation loop - while True: - # Evaluating at the beginning and with specific frequency - if evaluator.should_eval(learner.train_iter): - stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) - if stop: - break - # Update other modules - eps = epsilon_greedy(collector.envstep) - # Sampling data from environments - new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps}) - replay_buffer.push(new_data, cur_collector_envstep=collector.envstep) - # Training - for i in range(cfg.policy.learn.update_per_collect): - train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter) - if train_data is None: - break - learner.train(train_data, collector.envstep) - - if collector.envstep >= int(1e7): - break - - -if __name__ == "__main__": - # with DistContext(): - main(phoenix_iqn_config, create_config) diff --git a/dizoo/atari/entry/pong_fqf_main.py b/dizoo/atari/entry/pong_fqf_main.py deleted file mode 100644 index 816ec566c4..0000000000 --- a/dizoo/atari/entry/pong_fqf_main.py +++ /dev/null @@ -1,74 +0,0 @@ -import os -import torch -from tensorboardX import SummaryWriter -from ding.config import compile_config -from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer -from ding.policy import FQFPolicy -from ding.model import FQF -from ding.utils import set_pkg_seed -from ding.rl_utils import get_epsilon_greedy_fn -from dizoo.atari.config.serial.pong.pong_fqf_config import pong_fqf_config, create_config -from ding.utils import DistContext -from functools import partial -from ding.envs import get_vec_env_setting, create_env_manager - - -def main(cfg, create_cfg, seed=0): - - cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) - - # Create main components: env, policy - env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) - collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) - evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) - - # Set random seed for all package and instance - collector_env.seed(seed) - evaluator_env.seed(seed, dynamic_seed=False) - set_pkg_seed(seed, use_cuda=cfg.policy.cuda) - - # 
Set up RL Policy - model = FQF(**cfg.policy.model) - policy = FQFPolicy(cfg.policy, model=model) - - # Set up collection, training and evaluation utilities - tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) - learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) - collector = SampleSerialCollector( - cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name - ) - evaluator = InteractionSerialEvaluator( - cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name - ) - replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name) - - # Set up other modules, etc. epsilon greedy - eps_cfg = cfg.policy.other.eps - epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type) - - # Training & Evaluation loop - while True: - # Evaluating at the beginning and with specific frequency - if evaluator.should_eval(learner.train_iter): - stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) - if stop: - break - # Update other modules - eps = epsilon_greedy(collector.envstep) - # Sampling data from environments - new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps}) - replay_buffer.push(new_data, cur_collector_envstep=collector.envstep) - # Training - for i in range(cfg.policy.learn.update_per_collect): - train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter) - if train_data is None: - break - learner.train(train_data, collector.envstep) - - if collector.envstep >= 10000000: - break - - -if __name__ == "__main__": - # with DistContext(): - main(pong_fqf_config, create_config) diff --git a/dizoo/atari/entry/qbert_fqf_main.py b/dizoo/atari/entry/qbert_fqf_main.py deleted file mode 100644 index 6c87c549a8..0000000000 --- a/dizoo/atari/entry/qbert_fqf_main.py +++ /dev/null @@ -1,74 +0,0 @@ -import os -import torch -from tensorboardX import SummaryWriter -from ding.config import compile_config -from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer -from ding.policy import FQFPolicy -from ding.model import FQF -from ding.utils import set_pkg_seed -from ding.rl_utils import get_epsilon_greedy_fn -from dizoo.atari.config.serial.qbert.qbert_fqf_config import qbert_fqf_config, create_config -from ding.utils import DistContext -from functools import partial -from ding.envs import get_vec_env_setting, create_env_manager - - -def main(cfg, create_cfg, seed=0): - - cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) - - # Create main components: env, policy - env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) - collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) - evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) - - # Set random seed for all package and instance - collector_env.seed(seed) - evaluator_env.seed(seed, dynamic_seed=False) - set_pkg_seed(seed, use_cuda=cfg.policy.cuda) - - # Set up RL Policy - model = FQF(**cfg.policy.model) - policy = FQFPolicy(cfg.policy, model=model) - - # Set up collection, training and evaluation utilities - tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) - learner = 
BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) - collector = SampleSerialCollector( - cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name - ) - evaluator = InteractionSerialEvaluator( - cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name - ) - replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name) - - # Set up other modules, etc. epsilon greedy - eps_cfg = cfg.policy.other.eps - epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type) - - # Training & Evaluation loop - while True: - # Evaluating at the beginning and with specific frequency - if evaluator.should_eval(learner.train_iter): - stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) - if stop: - break - # Update other modules - eps = epsilon_greedy(collector.envstep) - # Sampling data from environments - new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps}) - replay_buffer.push(new_data, cur_collector_envstep=collector.envstep) - # Training - for i in range(cfg.policy.learn.update_per_collect): - train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter) - if train_data is None: - break - learner.train(train_data, collector.envstep) - - if collector.envstep >= 10000000: - break - - -if __name__ == "__main__": - # with DistContext(): - main(qbert_fqf_config, create_config) diff --git a/dizoo/atari/entry/spaceinvaders_fqf_main.py b/dizoo/atari/entry/spaceinvaders_fqf_main.py deleted file mode 100644 index 3c46e87019..0000000000 --- a/dizoo/atari/entry/spaceinvaders_fqf_main.py +++ /dev/null @@ -1,74 +0,0 @@ -import os -import torch -from tensorboardX import SummaryWriter -from ding.config import compile_config -from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer -from ding.policy import FQFPolicy -from ding.model import FQF -from ding.utils import set_pkg_seed -from ding.rl_utils import get_epsilon_greedy_fn -from dizoo.atari.config.serial.spaceinvaders.spaceinvaders_fqf_config import spaceinvaders_fqf_config, create_config -from ding.utils import DistContext -from functools import partial -from ding.envs import get_vec_env_setting, create_env_manager - - -def main(cfg, create_cfg, seed=0): - - cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) - - # Create main components: env, policy - env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) - collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) - evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) - - # Set random seed for all package and instance - collector_env.seed(seed) - evaluator_env.seed(seed, dynamic_seed=False) - set_pkg_seed(seed, use_cuda=cfg.policy.cuda) - - # Set up RL Policy - model = FQF(**cfg.policy.model) - policy = FQFPolicy(cfg.policy, model=model) - - # Set up collection, training and evaluation utilities - tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) - learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) - collector = SampleSerialCollector( - cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, 
exp_name=cfg.exp_name - ) - evaluator = InteractionSerialEvaluator( - cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name - ) - replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name) - - # Set up other modules, etc. epsilon greedy - eps_cfg = cfg.policy.other.eps - epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type) - - # Training & Evaluation loop - while True: - # Evaluating at the beginning and with specific frequency - if evaluator.should_eval(learner.train_iter): - stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) - if stop: - break - # Update other modules - eps = epsilon_greedy(collector.envstep) - # Sampling data from environments - new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps}) - replay_buffer.push(new_data, cur_collector_envstep=collector.envstep) - # Training - for i in range(cfg.policy.learn.update_per_collect): - train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter) - if train_data is None: - break - learner.train(train_data, collector.envstep) - - if collector.envstep >= 10000000: - break - - -if __name__ == "__main__": - # with DistContext(): - main(spaceinvaders_fqf_config, create_config)
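
Note on the deletions above: the per-game FQF and IQN entry scripts all repeated the same epsilon-greedy serial training loop, which now survives only in `dizoo/atari/entry/atari_dqn_envpool_main.py`. For reference, a condensed sketch of that shared loop, with all components assumed to be constructed as in the files above and `max_envstep` standing in for the hard-coded budgets (`int(3e6)` in the envpool entry, `int(1e7)` in the deleted scripts).

    # Hedged sketch (not part of the patch): the serial loop shared by the deleted
    # per-game entries and the remaining envpool entry.
    def train_loop(cfg, learner, collector, evaluator, replay_buffer, epsilon_greedy, max_envstep):
        while True:
            # Evaluate first, so a converged checkpoint can stop training early.
            if evaluator.should_eval(learner.train_iter):
                stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
                if stop:
                    break
            # Anneal epsilon by collected env steps, then gather new transitions.
            eps = epsilon_greedy(collector.envstep)
            new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
            replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
            # Several gradient updates per collection round.
            for _ in range(cfg.policy.learn.update_per_collect):
                train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
                if train_data is None:
                    break
                learner.train(train_data, collector.envstep)
            # Stop after a fixed environment-step budget.
            if collector.envstep >= max_envstep:
                break

Keeping a single copy of this loop avoids the drift that had accumulated across the duplicated per-game scripts.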