Skip to content

Commit

Permalink
feature(xyy):add HPT model and test_hpt
Browse files Browse the repository at this point in the history
  • Loading branch information
luodi-7 committed Dec 4, 2024
1 parent 31f2398 commit 9d30de8
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions ding/model/template/tests/test_hpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from ding.torch_utils import is_differentiable

T, B = 3, 4
obs_shape = [4, (8, ), (4, 64, 64)] # Example observation shapes
act_shape = [3, (6, ), [2, 3, 6]] # Example action shapes
obs_shape = [4, (8, ), (4, 64, 64)]
act_shape = [3, (6, ), [2, 3, 6]]
args = list(product(*[obs_shape, act_shape]))


Expand All @@ -26,16 +26,27 @@ def output_check(self, model, outputs):
def test_hpt(self, obs_shape, act_shape):
if isinstance(obs_shape, int):
inputs = torch.randn(B, obs_shape)
state_dim = obs_shape
else:
inputs = torch.randn(B, *obs_shape)
model = HPT(state_dim=obs_shape, action_dim=act_shape)
state_dim = obs_shape[0]

if isinstance(act_shape, int):
action_dim = act_shape
else:
action_dim = len(act_shape)

model = HPT(state_dim=state_dim, action_dim=action_dim)
outputs = model(inputs)

assert isinstance(outputs, torch.Tensor)

if isinstance(act_shape, int):
assert outputs.shape == (B, act_shape)
elif len(act_shape) == 1:
assert outputs.shape == (B, *act_shape)
else:
for i, s in enumerate(act_shape):
assert outputs[i].shape == (B, s)

self.output_check(model, outputs)

0 comments on commit 9d30de8

Please sign in to comment.