-
Notifications
You must be signed in to change notification settings - Fork 10
/
dqn_play.py
85 lines (77 loc) · 2.98 KB
/
dqn_play.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#!/usr/bin/env python3
import gym
import time
import argparse
import numpy as np
import os
import torch
import utils.dqn_model as dqn_model
import utils.wrappers as wrappers
from utils.actions import ArgmaxActionSelector
import utils.utils as utils
from utils.agent import DQNAgent, TargetNet
import collections
DEFAULT_ENV_NAME = "PongNoFrameskip-v4"
BOXING_ENV_NAME = "BoxingNoFrameskip-v4"
# Upper bound on rendered frames per second while visualizing game play.
FPS = 25
# Checkpoint file the network weights are restored from.
# NOTE: the path is fixed here; there is currently no command-line option
# for selecting a different model file.
CHECKPOINT_PATH = './agent_ckpt/agent_ls_dqn_-boxing.pth'


def build_arg_parser():
    """Build the command-line parser for the playback script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--env``, ``--record`` and
        ``--no-visualize`` (stored inverted as ``visualize``).
    """
    parser = argparse.ArgumentParser()
    # ``%(default)s`` keeps the help text in sync with the real default; the
    # previous text advertised DEFAULT_ENV_NAME while defaulting to Boxing.
    parser.add_argument("-e", "--env", default=BOXING_ENV_NAME,
                        help="Environment name to use, default=%(default)s")
    parser.add_argument("-r", "--record",
                        help="Directory to store video recording")
    parser.add_argument("--no-visualize", default=True, action='store_false',
                        dest='visualize',
                        help="Disable visualization of the game play")
    return parser


def load_network(env, device, checkpoint_path=CHECKPOINT_PATH,
                 use_dueling=True):
    """Create the Q-network for ``env`` and restore its weights.

    Args:
        env: environment whose ``observation_space.shape`` and
            ``action_space.n`` size the network.
        device: ``torch.device`` the network is moved to and the checkpoint
            tensors are mapped onto.
        checkpoint_path: file containing a dict with a ``'model_state_dict'``
            entry.
        use_dueling: build ``DuelingLSDQN`` when true, ``LSDQN`` otherwise.

    Returns:
        The network with the checkpoint weights loaded.

    Raises:
        SystemExit: if ``checkpoint_path`` is not an existing file.
    """
    # Fail before building the model when there is nothing to load.
    if not os.path.isfile(checkpoint_path):
        raise SystemExit("Checkpoint File Not Found")

    model_cls = dqn_model.DuelingLSDQN if use_dueling else dqn_model.LSDQN
    net = model_cls(env.observation_space.shape, env.action_space.n).to(device)

    # Mapping straight onto the target device covers both the CPU-only and
    # the CUDA case with a single call.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])
    print("Loaded checkpoint from ", checkpoint_path)
    return net


def play_episode(env, agent, visualize=True, fps=FPS):
    """Play one episode, always taking the action chosen by ``agent``.

    Args:
        env: environment providing ``reset()``, ``step(action)`` returning a
            ``(state, reward, done, info)`` 4-tuple, and ``render()``.
        agent: callable taking a list of states and returning a pair whose
            first item is an indexable batch of actions.
        visualize: render every frame and throttle the loop to ``fps``.
        fps: maximum frames per second when ``visualize`` is true.

    Returns:
        tuple: ``(total_reward, action_counts)`` where ``action_counts`` is a
        ``collections.Counter`` keyed by the action taken.
    """
    frame_time = 1.0 / fps
    total_reward = 0.0
    action_counts = collections.Counter()
    state = env.reset()
    done = False
    while not done:
        start_ts = time.time()
        if visualize:
            env.render()
        actions, _ = agent([state])
        # The agent is queried with a batch of one state; step the
        # environment with that single action rather than the whole batch.
        action = actions[0]
        action_counts[action] += 1
        state, reward, done, _ = env.step(action)
        total_reward += reward
        if visualize and not done:
            # Sleep away whatever is left of this frame's time budget.
            delta = frame_time - (time.time() - start_ts)
            if delta > 0:
                time.sleep(delta)
    return total_reward, action_counts


def main(argv=None):
    """Run a single greedy episode and print its statistics.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_arg_parser().parse_args(argv)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    env = wrappers.wrap_dqn(gym.make(args.env))
    if args.record:
        env = gym.wrappers.Monitor(env, args.record)

    try:
        net = load_network(env, device)
        agent = DQNAgent(net, ArgmaxActionSelector(), device=device)
        total_reward, action_counts = play_episode(
            env, agent, visualize=args.visualize)
    finally:
        # Close exactly once, also on errors or Ctrl-C.
        # NOTE(review): relies on the outermost wrapper's close() delegating
        # to the wrapped environments, which made the former extra
        # env.env.close() after recording redundant - confirm for the gym
        # version in use.
        env.close()

    print("Total reward: %.2f" % total_reward)
    print("Action counts:", action_counts)


if __name__ == "__main__":
    main()