From 7451f462536b6ac8d7350b74f135c356abd9ea9a Mon Sep 17 00:00:00 2001
From: niuyazhe
Date: Mon, 21 Aug 2023 19:03:05 +0800
Subject: [PATCH] fix(nyz): fix unittest bugs

---
 ...ial_entry_preference_based_irl_onpolicy.py |  6 +--
 ding/envs/env/tests/test_ding_env_wrapper.py  |  2 +-
 ding/framework/tests/test_parallel.py         |  2 +-
 .../tests/test_decision_transformer.py        |  9 ++---
 ding/policy/dt.py                             | 37 +++++++++----------
 5 files changed, 26 insertions(+), 30 deletions(-)

diff --git a/ding/entry/tests/test_serial_entry_preference_based_irl_onpolicy.py b/ding/entry/tests/test_serial_entry_preference_based_irl_onpolicy.py
index ff0e88b0d5..ffc20b9899 100644
--- a/ding/entry/tests/test_serial_entry_preference_based_irl_onpolicy.py
+++ b/ding/entry/tests/test_serial_entry_preference_based_irl_onpolicy.py
@@ -15,16 +15,16 @@
 
 @pytest.mark.unittest
 def test_serial_pipeline_trex_onpolicy():
-    exp_name = 'test_serial_pipeline_trex_onpolicy_expert'
+    exp_name = 'trex_onpolicy_test_serial_pipeline_trex_onpolicy_expert'
     config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
     config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100
     config[0].exp_name = exp_name
     expert_policy = serial_pipeline_onpolicy(config, seed=0)
 
-    exp_name = 'test_serial_pipeline_trex_onpolicy_collect'
+    exp_name = 'trex_onpolicy_test_serial_pipeline_trex_onpolicy_collect'
     config = [deepcopy(cartpole_trex_ppo_onpolicy_config), deepcopy(cartpole_trex_ppo_onpolicy_create_config)]
     config[0].exp_name = exp_name
-    config[0].reward_model.expert_model_path = 'test_serial_pipeline_trex_onpolicy_expert'
+    config[0].reward_model.expert_model_path = 'trex_onpolicy_test_serial_pipeline_trex_onpolicy_expert'
     config[0].reward_model.checkpoint_max = 100
     config[0].reward_model.checkpoint_step = 100
     config[0].reward_model.num_snippets = 100
diff --git a/ding/envs/env/tests/test_ding_env_wrapper.py b/ding/envs/env/tests/test_ding_env_wrapper.py
index a99e4ace23..7d53adbfd3 100644
--- a/ding/envs/env/tests/test_ding_env_wrapper.py
+++ b/ding/envs/env/tests/test_ding_env_wrapper.py
@@ -181,7 +181,7 @@ def test_hybrid(self):
         print('random_action', action)
         assert isinstance(action, dict)
 
-    @pytest.mark.unittest
+    @pytest.mark.envtest
     def test_AllinObsWrapper(self):
         env_cfg = EasyDict(env_id='PongNoFrameskip-v4', env_wrapper='reward_in_obs')
         ding_env_aio = DingEnvWrapper(cfg=env_cfg)
diff --git a/ding/framework/tests/test_parallel.py b/ding/framework/tests/test_parallel.py
index 429072a3fc..8d2cf648c2 100644
--- a/ding/framework/tests/test_parallel.py
+++ b/ding/framework/tests/test_parallel.py
@@ -24,7 +24,7 @@ def test_callback(key):
     time.sleep(0.7)
 
 
-@pytest.mark.unittest
+@pytest.mark.tmp
 def test_parallel_run():
     Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(parallel_main)
     Parallel.runner(n_parallel_workers=2, protocol="tcp", startup_interval=0.1)(parallel_main)
diff --git a/ding/model/template/tests/test_decision_transformer.py b/ding/model/template/tests/test_decision_transformer.py
index 69dfb6738c..0ee054d176 100644
--- a/ding/model/template/tests/test_decision_transformer.py
+++ b/ding/model/template/tests/test_decision_transformer.py
@@ -4,7 +4,7 @@
 import torch.nn.functional as F
 
 from ding.model.template import DecisionTransformer
-from ding.torch_utils import is_differentiable, one_hot
+from ding.torch_utils import is_differentiable
 
 args = ['continuous', 'discrete']
 
@@ -23,6 +23,7 @@ def test_decision_transformer(action_space):
         context_len=T,
         n_heads=2,
         drop_p=0.1,
+        continuous=(action_space == 'continuous')
     )
     is_continuous = True if action_space == 'continuous' else False
 
@@ -40,15 +41,11 @@ def test_decision_transformer(action_space):
     # all ones since no padding
     traj_mask = torch.ones([B, T], dtype=torch.long)  # B x T
 
-    # if discrete
-    if not is_continuous:
-        actions = one_hot(actions.squeeze(-1), num=act_dim)
-
-        assert actions.shape == (B, T, act_dim)
     if is_continuous:
         assert action_target.shape == (B, T, act_dim)
     else:
         assert action_target.shape == (B, T, 1)
+        actions = actions.squeeze(-1)
 
     returns_to_go = returns_to_go.float()
     state_preds, action_preds, return_preds = DT_model.forward(
diff --git a/ding/policy/dt.py b/ding/policy/dt.py
index 771f383bc5..adef441820 100644
--- a/ding/policy/dt.py
+++ b/ding/policy/dt.py
@@ -69,8 +69,10 @@ def _init_learn(self) -> None:
         self.act_dim = self._cfg.model.act_dim
 
         self._learn_model = self._model
+        self._atari_env = 'state_mean' not in self._cfg
+        self._basic_discrete_env = not self._cfg.model.continuous and 'state_mean' in self._cfg
 
-        if 'state_mean' not in self._cfg:
+        if self._atari_env:
             self._optimizer = self._learn_model.configure_optimizers(wt_decay, lr)
         else:
             self._optimizer = torch.optim.AdamW(self._learn_model.parameters(), lr=lr, weight_decay=wt_decay)
@@ -93,22 +95,18 @@ def _forward_learn(self, data: list) -> Dict[str, Any]:
         self._learn_model.train()
 
         timesteps, states, actions, returns_to_go, traj_mask = data
-        if actions.dtype is not torch.long:
-            actions = actions.to(torch.long)
-        action_target = torch.clone(actions).detach().to(self._device)
 
         # The shape of `returns_to_go` may differ with different dataset (B x T or B x T x 1),
         # and we need a 3-dim tensor
         if len(returns_to_go.shape) == 2:
             returns_to_go = returns_to_go.unsqueeze(-1)
 
-        # if discrete
-        if not self._cfg.model.continuous and 'state_mean' in self._cfg:
-            # actions = one_hot(actions.squeeze(-1), num=self.act_dim)
+        if self._basic_discrete_env:
+            actions = actions.to(torch.long)
             actions = actions.squeeze(-1)
         action_target = torch.clone(actions).detach().to(self._device)
 
-        if 'state_mean' not in self._cfg:
+        if self._atari_env:
             state_preds, action_preds, return_preds = self._learn_model.forward(
                 timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go, tar=1
             )
@@ -117,7 +115,7 @@ def _forward_learn(self, data: list) -> Dict[str, Any]:
                 timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go
             )
 
-        if 'state_mean' not in self._cfg:
+        if self._atari_env:
             action_loss = F.cross_entropy(action_preds.reshape(-1, action_preds.size(-1)), action_target.reshape(-1))
         else:
             traj_mask = traj_mask.view(-1, )
@@ -171,7 +169,9 @@ def _init_eval(self) -> None:
             self.actions = torch.zeros(
                 (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device
             )
-        if 'state_mean' not in self._cfg:
+        self._atari_env = 'state_mean' not in self._cfg
+        self._basic_discrete_env = not self._cfg.model.continuous and 'state_mean' in self._cfg
+        if self._atari_env:
             self.states = torch.zeros(
                 (
                     self.eval_batch_size,
@@ -201,7 +201,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
 
         self._eval_model.eval()
         with torch.no_grad():
-            if 'state_mean' not in self._cfg:
+            if self._atari_env:
                 states = torch.zeros(
                     (
                         self.eval_batch_size,
@@ -228,7 +228,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
                     (self.eval_batch_size, self.context_len, 1), dtype=torch.float32, device=self._device
                 )
             for i in data_id:
-                if 'state_mean' not in self._cfg:
+                if self._atari_env:
                     self.states[i, self.t[i]] = data[i]['obs'].to(self._device)
                 else:
                     self.states[i, self.t[i]] = (data[i]['obs'].to(self._device) - self.state_mean) / self.state_std
@@ -236,7 +236,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
                 self.rewards_to_go[i, self.t[i]] = self.running_rtg[i]
 
                 if self.t[i] <= self.context_len:
-                    if 'state_mean' not in self._cfg:
+                    if self._atari_env:
                         timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones(
                             (1, 1), dtype=torch.int64
                         ).to(self._device)
@@ -246,7 +246,7 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
                     actions[i] = self.actions[i, :self.context_len]
                     rewards_to_go[i] = self.rewards_to_go[i, :self.context_len]
                 else:
-                    if 'state_mean' not in self._cfg:
+                    if self._atari_env:
                         timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones(
                             (1, 1), dtype=torch.int64
                         ).to(self._device)
@@ -255,15 +255,14 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
                     states[i] = self.states[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
                     actions[i] = self.actions[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
                     rewards_to_go[i] = self.rewards_to_go[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
-            if not self._cfg.model.continuous and 'state_mean' in self._cfg:
-                # actions = one_hot(actions.squeeze(-1), num=self.act_dim)
+            if self._basic_discrete_env:
                 actions = actions.squeeze(-1)
             _, act_preds, _ = self._eval_model.forward(timesteps, states, actions, rewards_to_go)
             del timesteps, states, actions, rewards_to_go
 
             logits = act_preds[:, -1, :]
             if not self._cfg.model.continuous:
-                if 'state_mean' not in self._cfg:
+                if self._atari_env:
                     probs = F.softmax(logits, dim=-1)
                     act = torch.zeros((self.eval_batch_size, 1), dtype=torch.long, device=self._device)
                     for i in data_id:
@@ -297,7 +296,7 @@ def _reset_eval(self, data_id: List[int] = None) -> None:
                 dtype=torch.float32,
                 device=self._device
             )
-            if 'state_mean' not in self._cfg:
+            if self._atari_env:
                 self.states = torch.zeros(
                     (
                         self.eval_batch_size,
@@ -327,7 +326,7 @@ def _reset_eval(self, data_id: List[int] = None) -> None:
                     self.actions[i] = torch.zeros(
                         (self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device
                     )
-                if 'state_mean' not in self._cfg:
+                if self._atari_env:
                    self.states[i] = torch.zeros(
                         (self.max_eval_ep_len, ) + tuple(self.state_dim), dtype=torch.float32, device=self._device
                     )
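
The ding/policy/dt.py hunks above replace repeated inline checks of "'state_mean' not in self._cfg" with two flags set once in _init_learn and _init_eval. As a minimal standalone sketch of that flag derivation (not part of the patch; the helper name infer_env_flags and the bare cfg argument are illustrative stand-ins for the policy's self._cfg):

    # Illustrative sketch only: mirrors the flag logic the patch adds to dt.py.
    def infer_env_flags(cfg):
        # Atari-style configs in this patch carry no 'state_mean' entry,
        # so its absence selects the Atari code path (pixel states, no normalization).
        atari_env = 'state_mean' not in cfg
        # "Basic" discrete envs: discrete action space plus state normalization stats.
        basic_discrete_env = (not cfg.model.continuous) and ('state_mean' in cfg)
        return atari_env, basic_discrete_env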