From 48ee6da58ba460e25f4010149b5d81d1d978539a Mon Sep 17 00:00:00 2001 From: zjowowen Date: Thu, 23 Nov 2023 15:03:28 +0800 Subject: [PATCH] polish code --- ding/example/dqn_nstep_envpool.py | 7 +++---- ding/framework/middleware/collector.py | 8 ++++---- ding/framework/middleware/functional/data_processor.py | 1 - ding/framework/middleware/functional/evaluator.py | 2 +- ding/framework/middleware/learner.py | 10 +++++----- ding/policy/common_utils.py | 4 ++-- 6 files changed, 15 insertions(+), 17 deletions(-) diff --git a/ding/example/dqn_nstep_envpool.py b/ding/example/dqn_nstep_envpool.py index c5695f900f..7ab7a74677 100644 --- a/ding/example/dqn_nstep_envpool.py +++ b/ding/example/dqn_nstep_envpool.py @@ -80,10 +80,9 @@ def main(cfg): cfg, policy.collect_mode, collector_env, - random_collect_size=cfg.policy.random_collect_size \ - if hasattr(cfg.policy, 'random_collect_size') else 0, - ) - ) + random_collect_size=cfg.policy.random_collect_size if hasattr(cfg.policy, 'random_collect_size') else 0, + ) + ) task.use(data_pusher(cfg, buffer_)) task.use(EnvpoolOffPolicyLearner(cfg, policy, buffer_)) task.use(online_logger(train_show_freq=10)) diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py index d6a11bcb61..ed58c80993 100644 --- a/ding/framework/middleware/collector.py +++ b/ding/framework/middleware/collector.py @@ -187,7 +187,7 @@ def __call__(self, ctx: "OnlineRLContext") -> None: ) if self._nsteps > 1: - if done[i] == False and counter < target_size: + if done[i] is False and counter < target_size: reverse_record_position = min(self._nsteps, len(self._trajectory[env_id_receive[i]])) real_reverse_record_position = reverse_record_position @@ -195,7 +195,7 @@ def __call__(self, ctx: "OnlineRLContext") -> None: if j == 1: pass else: - if self._trajectory[env_id_receive[i]][-j]['done'] == True: + if self._trajectory[env_id_receive[i]][-j]['done'] is True: real_reverse_record_position = j - 1 break else: @@ -207,7 +207,7 @@ def __call__(self, ctx: "OnlineRLContext") -> None: self._trajectory[env_id_receive[i]][-real_reverse_record_position][ 'value_gamma'] = self._discount_ratio_list[real_reverse_record_position - 1] - else: # done[i] == True or counter >= target_size + else: # done[i] is True or counter >= target_size reverse_record_position = min(self._nsteps, len(self._trajectory[env_id_receive[i]])) real_reverse_record_position = reverse_record_position @@ -224,7 +224,7 @@ def __call__(self, ctx: "OnlineRLContext") -> None: self._trajectory[env_id_receive[i]][-j]['value_gamma'] = self._discount_ratio_list[j - 1] else: - if self._trajectory[env_id_receive[i]][-j]['done'] == True: + if self._trajectory[env_id_receive[i]][-j]['done'] is True: real_reverse_record_position = j break else: diff --git a/ding/framework/middleware/functional/data_processor.py b/ding/framework/middleware/functional/data_processor.py index 420af3d4fd..e254e4ad3b 100644 --- a/ding/framework/middleware/functional/data_processor.py +++ b/ding/framework/middleware/functional/data_processor.py @@ -284,7 +284,6 @@ def offline_data_fetcher_from_mem(cfg: EasyDict, dataset: Dataset) -> Callable: from threading import Thread from queue import Queue - import time stream = torch.cuda.Stream() def producer(queue, dataset, batch_size, device): diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index 37093a3a67..31b70153a7 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -385,7 +385,7 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]): } ) - if done[i] == True: + if done[i] is True: episode_return_i = 0.0 for item in trajectory[env_id_receive[i]]: episode_return_i += item['reward'][0] diff --git a/ding/framework/middleware/learner.py b/ding/framework/middleware/learner.py index 7182a9d5df..8779ced315 100644 --- a/ding/framework/middleware/learner.py +++ b/ding/framework/middleware/learner.py @@ -36,11 +36,11 @@ def data_process_func( else: output_data = fast_preprocess_learn( data, - use_priority=use_priority, #policy._cfg.priority, - use_priority_IS_weight=use_priority_IS_weight, #policy._cfg.priority_IS_weight, - use_nstep=use_nstep, #policy._cfg.nstep > 1, - cuda=cuda, #policy._cuda, - device=device, #policy._device, + use_priority=use_priority, + use_priority_IS_weight=use_priority_IS_weight, + use_nstep=use_nstep, + cuda=cuda, + device=device, ) data_queue_output.put(output_data) diff --git a/ding/policy/common_utils.py b/ding/policy/common_utils.py index 97a0084b37..a8e095974c 100644 --- a/ding/policy/common_utils.py +++ b/ding/policy/common_utils.py @@ -1,7 +1,6 @@ from typing import List, Any, Dict, Callable import numpy as np import torch -import numpy as np import treetensor.torch as ttorch from ding.utils.data import default_collate from ding.torch_utils import to_tensor, to_ndarray, unsqueeze, squeeze, to_device @@ -82,7 +81,8 @@ def fast_preprocess_learn( Overview: Fast data pre-processing before policy's ``_forward_learn`` method, including stacking batch data, transform \ data to PyTorch Tensor and move data to GPU, etc. This function is faster than ``default_preprocess_learn`` \ - but less flexible. This function abandons calling ``default_collate`` to stack data because ``default_collate`` \ + but less flexible. + This function abandons calling ``default_collate`` to stack data because ``default_collate`` \ is recursive and cumbersome. In this function, we alternatively stack the data and send it to GPU, so that it \ is faster. In addition, this function is usually used in a special data process thread in learner. Arguments: