-
Notifications
You must be signed in to change notification settings - Fork 378
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature(xyy):add HPT model to implement PolicyStem+DuelingHead #841
base: main
Are you sure you want to change the base?
Conversation
@@ -24,6 +24,8 @@ | |||
from .vae import VanillaVAE | |||
from .decision_transformer import DecisionTransformer | |||
from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS | |||
from .hpt import HPT | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
optimize import order
|
||
|
||
class PolicyStem(nn.Module): | ||
"""policy stem |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reformat the docstring as the DI-engine style
@@ -32,8 +34,13 @@ def main(): | |||
|
|||
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) | |||
|
|||
model = DQN(**cfg.policy.model) | |||
# # Migrating models to the GPU | |||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
format comments
create_config = lunarlander_hpt_create_config | ||
|
||
if __name__ == "__main__": | ||
# or you can enter `ding -m serial -c lunarlander_dqn_config.py -s 0` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
change the comments
import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'], | ||
), | ||
env_manager=dict(type='subprocess'), | ||
# env_manager=dict(type='base'), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
move unused comments
@MODEL_REGISTRY.register('hpt') | ||
class HPT(nn.Module): | ||
|
||
def __init__(self, state_dim, action_dim): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add overview and related introduction
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add unittest like other template in DI-engine
policy=dict( | ||
# Whether to use cuda for network. | ||
cuda=True, | ||
load_path="./lunarlander_hpt_seed0/ckpt/ckpt_best.pth.tar", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove unused part
|
||
# Migrating models to the GPU | ||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
model = HPT(cfg.policy.model.obs_shape, cfg.policy.model.action_shape).to(device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add comments about the difference of HPT from normal model
Description
Here are some tensorboard plots from the lunarlander_hpt_example.py run.
Related Issue
TODO
Check List