Skip to content

Commit

Permalink
fix(nyz): fix middleware collector env reset bug (#845)
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Dec 3, 2024
1 parent 548406f commit e93b5a6
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 8 deletions.
2 changes: 2 additions & 0 deletions ding/framework/middleware/functional/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def _rollout(ctx: "OnlineRLContext"):
'step': env_info[timestep.env_id.item()]['step'],
'train_sample': env_info[timestep.env_id.item()]['train_sample'],
}
# reset corresponding env info
env_info[timestep.env_id.item()] = {'time': 0., 'step': 0, 'train_sample': 0}

episode_info.append(info)
policy.reset([timestep.env_id.item()])
Expand Down
18 changes: 15 additions & 3 deletions ding/framework/middleware/tests/mock_for_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Union, Any, List, Callable, Dict, Optional
from collections import namedtuple
import random
import torch
import treetensor.numpy as tnp
from easydict import EasyDict
Expand Down Expand Up @@ -75,6 +76,7 @@ def __init__(self) -> None:
self.obs_dim = obs_dim
self.closed = False
self._reward_grow_indicator = 1
self._steps = [0 for _ in range(self.env_num)]

@property
def ready_obs(self) -> tnp.array:
Expand All @@ -90,16 +92,26 @@ def launch(self, reset_param: Optional[Dict] = None) -> None:
return

def reset(self, reset_param: Optional[Dict] = None) -> None:
    """Reset the mock env manager by zeroing every env's step counter.

    :param reset_param: Optional per-env reset kwargs; accepted for API
        compatibility with the real env manager but ignored by this mock.
    :return: None.
    """
    # NOTE: the diff's pre-image had a bare ``return`` here, which would make
    # the counter reset unreachable; the post-commit behavior is to reset.
    self._steps = [0 for _ in range(self.env_num)]

def step(self, actions: tnp.ndarray) -> List[tnp.ndarray]:
timesteps = []
for i in range(self.env_num):
if self._steps[i] < 5:
done = False
elif self._steps[i] < 10:
done = random.random() > 0.5
else:
done = True
if done:
self._steps[i] = 0
else:
self._steps[i] += 1
timestep = dict(
obs=torch.rand(self.obs_dim),
reward=1.0,
done=True,
info={'eval_episode_return': self._reward_grow_indicator * 1.0},
done=done,
info={'eval_episode_return': self._reward_grow_indicator * 1.0} if done else {},
env_id=i,
)
timesteps.append(tnp.array(timestep))
Expand Down
13 changes: 8 additions & 5 deletions ding/framework/middleware/tests/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,19 @@ def test_inferencer():

@pytest.mark.unittest
def test_rolloutor():
    """Unit test for the rolloutor middleware.

    Drives N inference/rollout iterations against a mocked 2-env manager and
    checks the step/episode counters accumulated in the shared context.
    """
    N = 20
    ctx = OnlineRLContext()
    transitions = TransitionList(2)
    with patch("ding.policy.Policy", MockPolicy), patch("ding.envs.BaseEnvManagerV2", MockEnv):
        policy = MockPolicy()
        env = MockEnv()
        # Build the middleware callables once and reuse them across iterations,
        # so any internal per-env bookkeeping persists over the whole run.
        i = inferencer(0, policy, env)
        r = rolloutor(policy, env, transitions)
        for _ in range(N):
            i(ctx)
            r(ctx)
        # Each iteration steps both envs exactly once.
        assert ctx.env_step == N * 2  # N * env_num
        # Episode termination in MockEnv is stochastic (random done between
        # steps 5 and 10, forced at 10), so each env completes at least one
        # episode per 10 steps; assert the guaranteed lower bound only.
        assert ctx.env_episode >= N // 10 * 2


@pytest.mark.unittest
Expand Down

0 comments on commit e93b5a6

Please sign in to comment.