feature(xyy):add HPT model and test_hpt
luodi-7 committed Dec 4, 2024
1 parent 26f4c97 commit 79aa427
Showing 2 changed files with 9 additions and 5 deletions.
12 changes: 8 additions & 4 deletions ding/model/template/hpt.py
@@ -45,7 +45,8 @@ def __init__(self, state_dim: int, action_dim: int):
     def forward(self, x: torch.Tensor):
         """
         Overview:
-            Forward pass of the HPT model. Computes latent tokens from the input state and passes them through the Dueling Head.
+            Forward pass of the HPT model.
+            Computes latent tokens from the input state and passes them through the Dueling Head.
         Arguments:
             - x (:obj:`torch.Tensor`): The input tensor representing the state.
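A minimal, self-contained sketch of the forward pass this docstring describes. The stem here is a plain linear stand-in for the repo's Policy Stem, and the sizes (token_dim, num_tokens) are illustrative assumptions, not values taken from hpt.py; only the dueling Q-value decomposition is spelled out exactly.

# Hedged sketch, not the actual hpt.py implementation.
import torch
import torch.nn as nn

class HPTSketch(nn.Module):
    def __init__(self, state_dim: int, action_dim: int, token_dim: int = 128, num_tokens: int = 4):
        super().__init__()
        # Stand-in for the Policy Stem: map the state to a set of latent tokens.
        self.stem = nn.Linear(state_dim, num_tokens * token_dim)
        self.num_tokens, self.token_dim = num_tokens, token_dim
        # Dueling decomposition: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a).
        self.value = nn.Linear(num_tokens * token_dim, 1)
        self.advantage = nn.Linear(num_tokens * token_dim, action_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tokens = self.stem(x).view(-1, self.num_tokens, self.token_dim)  # latent tokens
        feat = tokens.flatten(1)
        v, a = self.value(feat), self.advantage(feat)
        return v + a - a.mean(dim=1, keepdim=True)  # Q-values, shape (B, action_dim)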
@@ -65,7 +66,8 @@ def forward(self, x: torch.Tensor):
 class PolicyStem(nn.Module):
     """
     Overview:
-        The Policy Stem module is responsible for processing input features and generating latent tokens using a cross-attention mechanism.
+        The Policy Stem module is responsible for processing input features
+        and generating latent tokens using a cross-attention mechanism.
         It extracts features from the input and then applies cross-attention to generate a set of latent tokens.
     Interfaces:
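A hedged sketch of the token-generation scheme the docstring describes: learned latent queries cross-attend to projected input features. The real PolicyStem uses the repo's own CrossAttention class; here nn.MultiheadAttention stands in, and all dimensions are assumptions.

import torch
import torch.nn as nn

class PolicyStemSketch(nn.Module):
    def __init__(self, feat_dim: int, token_dim: int = 128, num_tokens: int = 4, heads: int = 8):
        super().__init__()
        # Learned latent tokens serve as the attention queries.
        self.tokens = nn.Parameter(torch.randn(num_tokens, token_dim))
        self.proj = nn.Linear(feat_dim, token_dim)  # feature-extractor stand-in
        self.attn = nn.MultiheadAttention(token_dim, heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, feat_dim) -> context sequence of one projected feature per sample.
        context = self.proj(x).unsqueeze(1)                            # (B, 1, token_dim)
        queries = self.tokens.unsqueeze(0).expand(x.size(0), -1, -1)   # (B, num_tokens, token_dim)
        out, _ = self.attn(queries, context, context)                  # latent tokens attend to features
        return out                                                     # (B, num_tokens, token_dim)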
@@ -151,7 +153,8 @@ def device(self):
 class CrossAttention(nn.Module):
     """
     Overview:
-        CrossAttention module used in the Perceiver IO model. It computes the attention between the query and context tensors,
+        CrossAttention module used in the Perceiver IO model.
+        It computes the attention between the query and context tensors,
         and returns the output tensor after applying attention.
     Arguments:
@@ -177,7 +180,8 @@ def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout:
     def forward(self, x: torch.Tensor, context: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
         """
         Overview:
-            Forward pass of the CrossAttention module. Computes the attention between the query and context tensors.
+            Forward pass of the CrossAttention module.
+            Computes the attention between the query and context tensors.
         Arguments:
             - x (:obj:`torch.Tensor`): The query input tensor.
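For reference, a Perceiver-IO-style cross-attention matching the signature visible in this diff (query_dim, heads=8, dim_head=64, dropout; forward(x, context, mask)). The context_dim handling, the mask convention, and the internals are assumptions, not the repo's code.

import torch
import torch.nn as nn
from typing import Optional

class CrossAttentionSketch(nn.Module):
    def __init__(self, query_dim: int, heads: int = 8, dim_head: int = 64,
                 dropout: float = 0.0, context_dim: Optional[int] = None):
        super().__init__()
        context_dim = context_dim or query_dim  # assumption: context dim defaults to query dim
        inner = heads * dim_head
        self.heads, self.dim_head, self.scale = heads, dim_head, dim_head ** -0.5
        self.to_q = nn.Linear(query_dim, inner, bias=False)
        self.to_kv = nn.Linear(context_dim, inner * 2, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner, query_dim), nn.Dropout(dropout))

    def forward(self, x: torch.Tensor, context: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        b = x.size(0)
        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        # Split heads: (B, N, H*D) -> (B, H, N, D).
        q, k, v = (t.view(b, -1, self.heads, self.dim_head).transpose(1, 2) for t in (q, k, v))
        sim = torch.matmul(q, k.transpose(-1, -2)) * self.scale   # (B, H, Nq, Nk)
        if mask is not None:
            # mask: (B, Nk), True marking valid context positions (an assumed convention).
            sim = sim.masked_fill(~mask[:, None, None, :], float('-inf'))
        attn = sim.softmax(dim=-1)
        out = torch.matmul(attn, v).transpose(1, 2).reshape(b, -1, self.heads * self.dim_head)
        return self.to_out(out)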
2 changes: 1 addition & 1 deletion dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py
@@ -38,7 +38,7 @@ def main():

     # Migrating models to the GPU
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    # HPT introduces a Policy Stem module, which processes the input features using Cross-Attention and generates a set of latent tokens.
+    # HPT introduces a Policy Stem module, which processes the input features using Cross-Attention.
     model = HPT(cfg.policy.model.obs_shape, cfg.policy.model.action_shape).to(device)
     buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
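A hedged usage sketch following the entry script above. LunarLander-v2 has an 8-dim observation and 4 discrete actions; whether HPT.forward returns a raw Q-value tensor or a dict with a 'logit' key (common for DI-engine heads) is an assumption, so both cases are handled.

import torch
from ding.model.template.hpt import HPT

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HPT(8, 4).to(device)              # LunarLander-v2: 8-dim obs, 4 discrete actions
obs = torch.randn(32, 8, device=device)   # an illustrative batch of 32 observations
out = model(obs)
q = out['logit'] if isinstance(out, dict) else out  # unwrap if the head returns a dict
action = q.argmax(dim=-1)                 # greedy actions, shape (32,)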
