
Commit

polish code
zjowowen committed Nov 23, 2023
1 parent 4fb85b0 commit 48ee6da
Showing 6 changed files with 15 additions and 17 deletions.
7 changes: 3 additions & 4 deletions ding/example/dqn_nstep_envpool.py
@@ -80,10 +80,9 @@ def main(cfg):
cfg,
policy.collect_mode,
collector_env,
- random_collect_size=cfg.policy.random_collect_size \
- if hasattr(cfg.policy, 'random_collect_size') else 0,
- )
- )
+ random_collect_size=cfg.policy.random_collect_size if hasattr(cfg.policy, 'random_collect_size') else 0,
+ )
+ )
task.use(data_pusher(cfg, buffer_))
task.use(EnvpoolOffPolicyLearner(cfg, policy, buffer_))
task.use(online_logger(train_show_freq=10))
8 changes: 4 additions & 4 deletions ding/framework/middleware/collector.py
@@ -187,15 +187,15 @@ def __call__(self, ctx: "OnlineRLContext") -> None:
)

if self._nsteps > 1:
- if done[i] == False and counter < target_size:
+ if done[i] is False and counter < target_size:
reverse_record_position = min(self._nsteps, len(self._trajectory[env_id_receive[i]]))
real_reverse_record_position = reverse_record_position

for j in range(1, reverse_record_position + 1):
if j == 1:
pass
else:
- if self._trajectory[env_id_receive[i]][-j]['done'] == True:
+ if self._trajectory[env_id_receive[i]][-j]['done'] is True:
real_reverse_record_position = j - 1
break
else:
@@ -207,7 +207,7 @@ def __call__(self, ctx: "OnlineRLContext") -> None:
self._trajectory[env_id_receive[i]][-real_reverse_record_position][
'value_gamma'] = self._discount_ratio_list[real_reverse_record_position - 1]

- else: # done[i] == True or counter >= target_size
+ else: # done[i] is True or counter >= target_size

reverse_record_position = min(self._nsteps, len(self._trajectory[env_id_receive[i]]))
real_reverse_record_position = reverse_record_position
@@ -224,7 +224,7 @@ def __call__(self, ctx: "OnlineRLContext") -> None:
self._trajectory[env_id_receive[i]][-j]['value_gamma'] = self._discount_ratio_list[j -
1]
else:
- if self._trajectory[env_id_receive[i]][-j]['done'] == True:
+ if self._trajectory[env_id_receive[i]][-j]['done'] is True:
real_reverse_record_position = j
break
else:
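
The three hunks above all touch the backward pass that stamps `nstep` and `value_gamma` onto the tail of a per-environment trajectory. The sketch below is a simplified, hypothetical standalone version of that idea (walk back at most `nstep` transitions, stop at an earlier episode boundary, then record each transition's usable horizon); names such as `trajectory`, `nstep` and `discount_ratio_list` stand in for the collector's attributes, and it is not a line-for-line copy of the middleware.

def backfill_nstep(trajectory, nstep, discount_ratio_list):
    """Stamp n-step bookkeeping onto the tail of a trajectory (illustrative only)."""
    limit = min(nstep, len(trajectory))
    real_limit = limit
    # Walk backwards from the newest transition; stop at an earlier episode boundary.
    for j in range(2, limit + 1):
        if trajectory[-j]['done']:
            real_limit = j - 1
            break
    # The newest transition can look ahead 1 step, the one before it 2 steps, and so on.
    for j in range(1, real_limit + 1):
        trajectory[-j]['nstep'] = j
        trajectory[-j]['value_gamma'] = discount_ratio_list[j - 1]  # assumed to hold gamma ** n
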
1 change: 0 additions & 1 deletion ding/framework/middleware/functional/data_processor.py
@@ -284,7 +284,6 @@ def offline_data_fetcher_from_mem(cfg: EasyDict, dataset: Dataset) -> Callable:

from threading import Thread
from queue import Queue
- import time
stream = torch.cuda.Stream()

def producer(queue, dataset, batch_size, device):
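
For context, the surrounding `offline_data_fetcher_from_mem` builds batches on a background thread and a dedicated CUDA stream so that host-to-device copies overlap with training. A rough sketch of that producer pattern follows; names such as `collate_fn` and `num_prefetch` are illustrative, not the middleware's actual signature.

import torch
from queue import Queue
from threading import Thread

def start_prefetcher(dataset, batch_size, device, collate_fn, num_prefetch=2):
    """Illustrative prefetcher: stacks batches and copies them to `device` on a side stream."""
    queue = Queue(maxsize=num_prefetch)
    stream = torch.cuda.Stream()

    def producer():
        for start in range(0, len(dataset) - batch_size + 1, batch_size):
            batch = collate_fn([dataset[i] for i in range(start, start + batch_size)])
            with torch.cuda.stream(stream):  # keep the copies off the default stream
                batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            queue.put(batch)
        queue.put(None)  # sentinel: no more data

    Thread(target=producer, daemon=True).start()
    return queue

The training loop would then just call `queue.get()` until it receives the sentinel.
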
2 changes: 1 addition & 1 deletion ding/framework/middleware/functional/evaluator.py
@@ -385,7 +385,7 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
}
)

- if done[i] == True:
+ if done[i] is True:
episode_return_i = 0.0
for item in trajectory[env_id_receive[i]]:
episode_return_i += item['reward'][0]
10 changes: 5 additions & 5 deletions ding/framework/middleware/learner.py
@@ -36,11 +36,11 @@ def data_process_func(
else:
output_data = fast_preprocess_learn(
data,
- use_priority=use_priority, #policy._cfg.priority,
- use_priority_IS_weight=use_priority_IS_weight, #policy._cfg.priority_IS_weight,
- use_nstep=use_nstep, #policy._cfg.nstep > 1,
- cuda=cuda, #policy._cuda,
- device=device, #policy._device,
+ use_priority=use_priority,
+ use_priority_IS_weight=use_priority_IS_weight,
+ use_nstep=use_nstep,
+ cuda=cuda,
+ device=device,
)
data_queue_output.put(output_data)

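The inline comments removed above recorded where each flag comes from. For reference, a hypothetical call site following those comments might assemble the arguments like this; the attribute names are taken from the deleted comments, not a verified API.

# Hypothetical wiring of the preprocessing flags from a policy object,
# following the comments that this commit removes.
def make_preprocess_kwargs(policy):
    return dict(
        use_priority=policy._cfg.priority,
        use_priority_IS_weight=policy._cfg.priority_IS_weight,
        use_nstep=policy._cfg.nstep > 1,
        cuda=policy._cuda,
        device=policy._device,
    )
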
4 changes: 2 additions & 2 deletions ding/policy/common_utils.py
@@ -1,7 +1,6 @@
from typing import List, Any, Dict, Callable
- import numpy as np
import torch
import numpy as np
import treetensor.torch as ttorch
from ding.utils.data import default_collate
from ding.torch_utils import to_tensor, to_ndarray, unsqueeze, squeeze, to_device
@@ -82,7 +81,8 @@ def fast_preprocess_learn(
Overview:
Fast data pre-processing before policy's ``_forward_learn`` method, including stacking batch data, transform \
data to PyTorch Tensor and move data to GPU, etc. This function is faster than ``default_preprocess_learn`` \
- but less flexible. This function abandons calling ``default_collate`` to stack data because ``default_collate`` \
+ but less flexible.
+ This function abandons calling ``default_collate`` to stack data because ``default_collate`` \
is recursive and cumbersome. In this function, we alternatively stack the data and send it to GPU, so that it \
is faster. In addition, this function is usually used in a special data process thread in learner.
Arguments:
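
The docstring above describes replacing the recursive `default_collate` with direct per-field stacking. A minimal sketch of that idea, assuming each sample is a flat dict of equally shaped tensors (the field names are illustrative):

import torch

def stack_batch(samples, device=None):
    """Illustrative flat-dict stacking: no recursion, one torch.stack per field."""
    batch = {key: torch.stack([sample[key] for sample in samples]) for key in samples[0]}
    if device is not None:
        batch = {key: value.to(device, non_blocking=True) for key, value in batch.items()}
    return batch

# Example: 32 transitions, each with an observation and a reward tensor.
samples = [{'obs': torch.randn(4), 'reward': torch.randn(1)} for _ in range(32)]
batch = stack_batch(samples)  # batch['obs'].shape == torch.Size([32, 4])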
