Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(xyy):add HPT model to implement PolicyStem+DuelingHead #841

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

luodi-7
Copy link

@luodi-7 luodi-7 commented Nov 27, 2024

Description

Here are some tensorboard plots from the lunarlander_hpt_example.py run.
hpt_episode_return
hpt_train_q_value
hpt_target_q_value
hpt_train_total_loss

Related Issue

TODO

Check List

  • merge the latest version source branch/repo, and resolve all the conflicts
  • pass style check
  • pass all the tests

@PaParaZz1 PaParaZz1 added the algo Add new algorithm or improve old one label Nov 28, 2024
@@ -24,6 +24,8 @@
from .vae import VanillaVAE
from .decision_transformer import DecisionTransformer
from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS
from .hpt import HPT

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optimize import order



class PolicyStem(nn.Module):
"""policy stem
Copy link
Collaborator

@puyuan1996 puyuan1996 Nov 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reformat the docstring as the DI-engine style

@@ -32,8 +34,13 @@ def main():

set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

model = DQN(**cfg.policy.model)
# # Migrating models to the GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

format comments

create_config = lunarlander_hpt_create_config

if __name__ == "__main__":
# or you can enter `ding -m serial -c lunarlander_dqn_config.py -s 0`
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change the comments

import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
),
env_manager=dict(type='subprocess'),
# env_manager=dict(type='base'),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move unused comments

@MODEL_REGISTRY.register('hpt')
class HPT(nn.Module):

def __init__(self, state_dim, action_dim):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add overview and related introduction

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add unittest like other template in DI-engine

policy=dict(
# Whether to use cuda for network.
cuda=True,
load_path="./lunarlander_hpt_seed0/ckpt/ckpt_best.pth.tar",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove unused part


# Migrating models to the GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HPT(cfg.policy.model.obs_shape, cfg.policy.model.action_shape).to(device)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add comments about the difference of HPT from normal model

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
algo Add new algorithm or improve old one
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants