Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
huangzhengxiang committed Jun 7, 2023
2 parents a90dfac + 381ffc3 commit e7cdc90
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
5 changes: 3 additions & 2 deletions config/PongNoFrameskip-v4/DQN.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mtype: "DQN"
variant: "ddqn"
variant: "dddqn"
cnn: [64,64,32]
final: 512
image: [84,84,1]
Expand All @@ -11,6 +11,7 @@ skip_first_frames: 10
buffer_size: 30000
begin_noise: 0.6
end_noise: 0.1
lr: 0.00025
lr: 0.0005
weight_decay: 0.001
batch_size: 32
device: "cpu"
10 changes: 6 additions & 4 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,11 +424,13 @@ def __init__(self,config,state_dim,action_space) -> None:
self.targetDQNet=DQNet(self.state_dim,self.config["mlp"],self.action_space.shape[0],"relu",None)
self.targetDQNet.load_state_dict(self.DQNet.state_dict().copy())
else:
if "device" in self.config:
self.device=self.config["device"] # default "cuda"
self.buffer=ReplayBuffer(config["buffer_size"], self.state_dim)
self.DQNet=ConvDQNet(self.config["history"],self.config["cnn"],self.config["final"],self.action_space.shape[0],"relu",None,self.variant=="dueling" or self.variant=="dddqn","cuda")
self.targetDQNet=ConvDQNet(self.config["history"],self.config["cnn"],self.config["final"],self.action_space.shape[0],"relu",None,self.variant=="dueling"or self.variant=="dddqn","cuda")
self.DQNet.cuda()
self.targetDQNet.cuda()
self.DQNet=ConvDQNet(self.config["history"],self.config["cnn"],self.config["final"],self.action_space.shape[0],"relu",None,self.variant=="dueling" or self.variant=="dddqn",self.device)
self.targetDQNet=ConvDQNet(self.config["history"],self.config["cnn"],self.config["final"],self.action_space.shape[0],"relu",None,self.variant=="dueling"or self.variant=="dddqn",self.device)
self.DQNet.to(device=self.device)
self.targetDQNet.to(device=self.device)
self.targetDQNet.load_state_dict(self.DQNet.state_dict().copy())
self.mseLoss=nn.MSELoss()
self.DQNOptimizer=torch.optim.Adam(self.DQNet.parameters(),
Expand Down

0 comments on commit e7cdc90

Please sign in to comment.