-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
69 lines (60 loc) · 1.88 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from argparse import ArgumentParser
from itertools import count
import pygame
from src.agent import QAgent
from src.constants import *
from src.game import Game
def run(
    episodes,
    save_model,
    epsilon,
    print_every,
    save_plot,
    suffix,
    save_rewards,
    min_epsilon,
    epsilon_decay,
):
    """Train a ``QAgent`` on a ``Game`` and optionally persist the results.

    A ``Game`` is created with ``randomize_start_pos=False`` and a fresh
    ``QAgent`` is trained on it. Afterwards the model, the training plot
    and the reward statistics are each written to disk if the matching
    flag is truthy.

    Args:
        episodes: Forwarded to ``QAgent.train``; also embedded in the
            output file names.
        save_model: When truthy, call ``agent.save_model`` with a path
            under ``models/``.
        epsilon: Forwarded to ``QAgent.train``; also embedded in the
            output file names. (Presumably the initial exploration rate
            -- confirm against ``QAgent``.)
        print_every: Forwarded to ``QAgent.train``.
        save_plot: When truthy, call ``agent.plot_train_stats`` with a
            path under ``plots/``.
        suffix: Free-form tag appended to every output file name.
        save_rewards: When truthy, call ``agent.save_stats`` with a path
            under ``rewards/``.
        min_epsilon: Forwarded to ``QAgent.train``.
        epsilon_decay: Forwarded to ``QAgent.train``.
    """
    # Local import so this function does not depend on the file's import block.
    import os

    env = Game(randomize_start_pos=False)
    agent = QAgent()
    agent.train(env, episodes, epsilon, print_every, min_epsilon, epsilon_decay)

    # Every artefact of one run shares this stem so they can be matched later.
    run_tag = f"e{episodes}_epsilon{epsilon}_{suffix}"

    # The agent's save helpers are not visible from here, so the target
    # directories are created defensively: a missing folder must not throw
    # away a finished training run.
    if save_model:
        model_path = f"models/Q_model_{run_tag}.pkl"
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        agent.save_model(model_path)

    if save_plot:
        plot_path = f"plots/Q_model_{run_tag}.png"
        os.makedirs(os.path.dirname(plot_path), exist_ok=True)
        agent.plot_train_stats(plot_path)

    if save_rewards:
        rewards_path = f"rewards/Q_{run_tag}.csv"
        os.makedirs(os.path.dirname(rewards_path), exist_ok=True)
        agent.save_stats(rewards_path)
if __name__ == "__main__":
    # Command-line entry point: parse the training options and hand them
    # to run().
    cli = ArgumentParser()
    cli.add_argument("--episodes", type=int, default=100)
    cli.add_argument("--epsilon", type=float, default=0.001)
    cli.add_argument("--print_every", type=int, default=10)
    cli.add_argument("--suffix", default="")
    # Boolean switches: absent -> False, present -> True.
    for switch in ("--save_model", "--save_plot", "--save_rewards"):
        cli.add_argument(switch, action="store_true")
    cli.add_argument("--min_epsilon", type=float, default=0)
    cli.add_argument("--epsilon_decay", type=float, default=0.995)
    options = cli.parse_args()

    # Keyword arguments keep the call independent of run()'s parameter order.
    run(
        episodes=options.episodes,
        save_model=options.save_model,
        epsilon=options.epsilon,
        print_every=options.print_every,
        save_plot=options.save_plot,
        suffix=options.suffix,
        save_rewards=options.save_rewards,
        min_epsilon=options.min_epsilon,
        epsilon_decay=options.epsilon_decay,
    )