dqn_and_double_dqn.py (forked from rail-berkeley/rlkit)
"""
Run DQN on grid world.
"""
import gym
import numpy as np
from torch import nn as nn

import rlkit.torch.pytorch_util as ptu
from rlkit.launchers.launcher_util import setup_logger
from rlkit.torch.dqn.dqn import DQN
from rlkit.torch.dqn.double_dqn import DoubleDQN
from rlkit.torch.networks import Mlp


def experiment(variant):
    env = gym.make('CartPole-v0')
    training_env = gym.make('CartPole-v0')

    # Q-network: a small MLP mapping observations to one Q-value per action.
    qf = Mlp(
        hidden_sizes=[32, 32],
        input_size=int(np.prod(env.observation_space.shape)),
        output_size=env.action_space.n,
    )
    qf_criterion = nn.MSELoss()
    # To switch to Double DQN, replace DQN with DoubleDQN below; both
    # accept the same arguments.
    # algorithm = DoubleDQN(
    algorithm = DQN(
        env,
        training_env=training_env,
        qf=qf,
        qf_criterion=qf_criterion,
        **variant['algo_params']
    )
    if ptu.gpu_enabled():
        algorithm.cuda()
    algorithm.train()
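

# The `epsilon` value in `algo_params` below controls epsilon-greedy
# exploration. A minimal sketch of what that means, assuming a discrete
# action space (this helper is hypothetical and illustrative, not rlkit's
# implementation):
def select_action(qf, obs, env, epsilon):
    """With probability `epsilon` act randomly, otherwise act greedily."""
    import random

    import torch
    if random.random() < epsilon:
        return env.action_space.sample()  # explore: uniform random action
    with torch.no_grad():
        q_values = qf(torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0))
    return q_values.argmax(dim=1).item()  # exploit: greedy action w.r.t. Q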
if __name__ == "__main__":
# noinspection PyTypeChecker
variant = dict(
algo_params=dict(
num_epochs=500,
num_steps_per_epoch=1000,
num_steps_per_eval=1000,
batch_size=128,
max_path_length=200,
discount=0.99,
epsilon=0.2,
tau=0.001,
hard_update_period=1000,
save_environment=False, # Can't serialize CartPole for some reason
),
)
setup_logger('name-of-experiment', variant=variant)
experiment(variant)
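
# Why the switch matters: DQN and Double DQN differ only in how the bootstrap
# target is computed. The sketch below is illustrative (hypothetical helpers,
# not rlkit's API), assuming batched tensors `rewards`, `terminals`,
# `next_obs` and Q-networks `qf` / `target_qf`:
#
#   def dqn_target(rewards, terminals, next_obs, target_qf, discount):
#       # Vanilla DQN: the target network both selects and evaluates the
#       # next action, which tends to overestimate Q-values.
#       target_q = target_qf(next_obs).max(dim=1, keepdim=True)[0]
#       return rewards + (1. - terminals) * discount * target_q
#
#   def double_dqn_target(rewards, terminals, next_obs, qf, target_qf, discount):
#       # Double DQN: the online network selects the action and the target
#       # network evaluates it, reducing the overestimation bias.
#       best_actions = qf(next_obs).argmax(dim=1, keepdim=True)
#       target_q = target_qf(next_obs).gather(1, best_actions)
#       return rewards + (1. - terminals) * discount * target_q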