feature(xyy):add HPT model and test_hpt
luodi-7 committed Dec 4, 2024
1 parent 25f1d2f commit 9608131
Showing 4 changed files with 14 additions and 27 deletions.
12 changes: 4 additions & 8 deletions ding/model/template/hpt.py
@@ -97,8 +97,7 @@ def init_cross_attn(self):
         """Initialize cross-attention module and learnable tokens."""
         token_num = 16
         self.tokens = nn.Parameter(torch.randn(1, token_num, 128) * INIT_CONST)
-        self.cross_attention = CrossAttention(
-            128, heads=8, dim_head=64, dropout=0.1)
+        self.cross_attention = CrossAttention(128, heads=8, dim_head=64, dropout=0.1)

     def compute_latent(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -113,12 +112,10 @@ def compute_latent(self, x: torch.Tensor) -> torch.Tensor:
         """
         # Using the Feature Extractor
         stem_feat = self.feature_extractor(x)
-        stem_feat = stem_feat.reshape(
-            stem_feat.shape[0], -1, stem_feat.shape[-1])  # (B, N, 128)
+        stem_feat = stem_feat.reshape(stem_feat.shape[0], -1, stem_feat.shape[-1])  # (B, N, 128)
         # Calculating latent tokens using CrossAttention
         stem_tokens = self.tokens.repeat(len(stem_feat), 1, 1)  # (B, 16, 128)
-        stem_tokens = self.cross_attention(
-            stem_tokens, stem_feat)  # (B, 16, 128)
+        stem_tokens = self.cross_attention(stem_tokens, stem_feat)  # (B, 16, 128)
         return stem_tokens

     def forward(self, x: torch.Tensor) -> torch.Tensor:
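For readers following the tensor shapes in compute_latent, here is a small, self-contained walk-through of the two operations in this hunk, assuming the feature extractor returns a channels-last (B, 7, 7, 128) map (the actual extractor and map size depend on the configured stem):

import torch

B = 4
stem_feat = torch.randn(B, 7, 7, 128)  # assumed channels-last extractor output; real shape depends on the stem
# flatten every non-batch, non-channel dimension into one token axis, exactly as in the hunk above
stem_feat = stem_feat.reshape(stem_feat.shape[0], -1, stem_feat.shape[-1])  # (4, 49, 128)

tokens = torch.randn(1, 16, 128)  # stands in for the learnable self.tokens parameter
stem_tokens = tokens.repeat(len(stem_feat), 1, 1)  # (4, 16, 128): one copy of the 16 latent tokens per sample
print(stem_feat.shape, stem_tokens.shape)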
@@ -198,8 +195,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.Tensor]
         h = self.heads
         q = self.to_q(x)
         k, v = self.to_kv(context).chunk(2, dim=-1)
-        q, k, v = map(lambda t: rearrange(
-            t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
         sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

         if mask is not None:
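The one-line rearrange above folds the head dimension into the batch dimension so a single einsum computes attention for all heads at once. A minimal sketch of the full arithmetic with random stand-in projections (the softmax and head-merging steps are not visible in this hunk but follow the same pattern):

import torch
from einops import rearrange

B, Nq, Nk, h, dim_head = 2, 16, 49, 8, 64
inner = h * dim_head                               # 512
q = torch.randn(B, Nq, inner)                      # stand-in for to_q(x)
k = torch.randn(B, Nk, inner)                      # stand-in for the first chunk of to_kv(context)
v = torch.randn(B, Nk, inner)                      # stand-in for the second chunk
scale = dim_head ** -0.5

# split heads: (B, N, h*d) -> (B*h, N, d)
q, k, v = (rearrange(t, "b n (h d) -> (b h) n d", h=h) for t in (q, k, v))
sim = torch.einsum("b i d, b j d -> b i j", q, k) * scale   # (B*h, Nq, Nk) similarity scores
attn = sim.softmax(dim=-1)
out = torch.einsum("b i j, b j d -> b i d", attn, v)        # (B*h, Nq, dim_head)
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)         # (B, Nq, h*dim_head)
print(out.shape)                                            # torch.Size([2, 16, 512])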
2 changes: 1 addition & 1 deletion dizoo/box2d/lunarlander/config/lunarlander_sqil_config.py
@@ -63,4 +63,4 @@
     from dizoo.box2d.lunarlander.config import lunarlander_dqn_config, lunarlander_dqn_create_config
     expert_main_config = lunarlander_dqn_config
     expert_create_config = lunarlander_dqn_create_config
-    serial_pipeline_sqil([main_config, create_config], [expert_main_config, expert_create_config], seed=0)
+    serial_pipeline_sqil([main_config, create_config], [expert_main_config, expert_create_config], seed=0)
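For context, the call reformatted above is the SQIL entry point: the first config pair drives the imitation learner defined in this file, and the second pair identifies the expert whose pretrained DQN supplies demonstrations. A hedged sketch of the same calling pattern with the arguments annotated (the serial_pipeline_sqil import is assumed to come from ding.entry, as in other DI-engine entry scripts; main_config and create_config are the ones defined earlier in this file):

from ding.entry import serial_pipeline_sqil  # assumed import location
from dizoo.box2d.lunarlander.config import lunarlander_dqn_config, lunarlander_dqn_create_config

serial_pipeline_sqil(
    [main_config, create_config],                              # SQIL learner configs defined in this file
    [lunarlander_dqn_config, lunarlander_dqn_create_config],   # expert (pretrained DQN) configs
    seed=0,
)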
12 changes: 4 additions & 8 deletions dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py
@@ -19,19 +19,16 @@

 def main():
     logging.getLogger().setLevel(logging.INFO)
-    cfg = compile_config(main_config, create_cfg=create_config,
-                         auto=True, save_cfg=task.router.node_id == 0)
+    cfg = compile_config(main_config, create_cfg=create_config, auto=True, save_cfg=task.router.node_id == 0)
     ding_init(cfg)

     with task.start(async_mode=False, ctx=OnlineRLContext()):
         collector_env = SubprocessEnvManagerV2(
-            env_fn=[lambda: DingEnvWrapper(
-                gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)],
+            env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)],
             cfg=cfg.env.manager
         )
         evaluator_env = SubprocessEnvManagerV2(
-            env_fn=[lambda: DingEnvWrapper(
-                gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)],
+            env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)],
             cfg=cfg.env.manager
         )

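The env_fn argument above is a list of zero-argument factories, one per parallel environment; each worker calls its factory to build an independent environment instance. A plain-gym sketch of the same pattern, independent of the DI-engine wrappers (the environment count is a stand-in for cfg.env.collector_env_num):

import gym

collector_env_num = 8  # stand-in for cfg.env.collector_env_num
env_fns = [lambda: gym.make("LunarLander-v2") for _ in range(collector_env_num)]

# each factory builds a fresh, independent environment when invoked
envs = [fn() for fn in env_fns]
print(len(envs), type(envs[0]).__name__)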
@@ -41,8 +38,7 @@ def main():
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         model = DQN(**cfg.policy.model).to(device)

-        buffer_ = DequeBuffer(
-            size=cfg.policy.other.replay_buffer.replay_buffer_size)
+        buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)

         # Pass the model into Policy
         policy = DQNPolicy(cfg.policy, model=model)
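DequeBuffer above is sized from the config's replay_buffer_size; to a first approximation it behaves like a bounded FIFO, which a plain collections.deque illustrates (the real class additionally supports sampling and metadata handling; the size used here is a stand-in for the config value):

from collections import deque

replay_buffer_size = 100000  # stand-in for cfg.policy.other.replay_buffer.replay_buffer_size
buffer_ = deque(maxlen=replay_buffer_size)

for transition in range(replay_buffer_size + 10):
    buffer_.append(transition)  # once maxlen is reached, the oldest entries are evicted
print(len(buffer_))             # 100000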
15 changes: 5 additions & 10 deletions dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py
@@ -21,19 +21,16 @@

 def main():
     logging.getLogger().setLevel(logging.INFO)
-    cfg = compile_config(main_config, create_cfg=create_config,
-                         auto=True, save_cfg=task.router.node_id == 0)
+    cfg = compile_config(main_config, create_cfg=create_config, auto=True, save_cfg=task.router.node_id == 0)
     ding_init(cfg)

     with task.start(async_mode=False, ctx=OnlineRLContext()):
         collector_env = SubprocessEnvManagerV2(
-            env_fn=[lambda: DingEnvWrapper(
-                gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)],
+            env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)],
             cfg=cfg.env.manager
         )
         evaluator_env = SubprocessEnvManagerV2(
-            env_fn=[lambda: DingEnvWrapper(
-                gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)],
+            env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)],
             cfg=cfg.env.manager
         )

@@ -42,10 +39,8 @@ def main():
         # Migrating models to the GPU
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         # HPT introduces a Policy Stem module, which processes the input features using Cross-Attention.
-        model = HPT(cfg.policy.model.obs_shape,
-                    cfg.policy.model.action_shape).to(device)
-        buffer_ = DequeBuffer(
-            size=cfg.policy.other.replay_buffer.replay_buffer_size)
+        model = HPT(cfg.policy.model.obs_shape, cfg.policy.model.action_shape).to(device)
+        buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)

         # Pass the model into Policy
         policy = DQNPolicy(cfg.policy, model=model)
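Putting the pieces from ding/model/template/hpt.py together, the HPT model here replaces the plain DQN network: observations pass through a feature extractor, the learnable latent tokens attend to those features via cross-attention, and a head maps the latent to one Q-value per action, so DQNPolicy can consume it unchanged. A simplified, self-contained sketch with stand-in modules (only the shapes are meant to mirror the real model; LunarLander-v2 sizes obs_shape=8, action_shape=4 are used for illustration):

import torch
import torch.nn as nn

class TinyHPT(nn.Module):
    def __init__(self, obs_shape: int, action_shape: int, token_num: int = 16, dim: int = 128):
        super().__init__()
        self.feature_extractor = nn.Linear(obs_shape, dim)          # stand-in stem feature extractor
        self.tokens = nn.Parameter(torch.randn(1, token_num, dim) * 0.02)
        self.cross_attention = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)  # stand-in for CrossAttention
        self.head = nn.Linear(token_num * dim, action_shape)        # stand-in Q-value head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        stem_feat = self.feature_extractor(x).unsqueeze(1)          # (B, 1, dim): N=1 for vector observations
        stem_tokens = self.tokens.repeat(x.shape[0], 1, 1)          # (B, 16, dim)
        latent, _ = self.cross_attention(stem_tokens, stem_feat, stem_feat)
        return self.head(latent.flatten(1))                         # (B, action_shape) Q-values

model = TinyHPT(obs_shape=8, action_shape=4)                        # LunarLander-v2 sizes
q_values = model(torch.randn(32, 8))
print(q_values.shape)                                               # torch.Size([32, 4])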
