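"""Train a DDQNPER agent for target control of a gym_PBN environment, with
checkpointing, Weights & Biases logging, and a final steady-state
distribution (SSD) evaluation."""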
import argparse
import ast
import random
from pathlib import Path

import gymnasium as gym
import gym_PBN  # noqa: F401 -- registers the PBN environments with gymnasium
import numpy as np
import torch
import wandb
from gym_PBN.utils.eval import compute_ssd_hist

from ddqn_per import DDQNPER

model_cls = DDQNPER  # Double DQN with prioritized experience replay
model_name = "DDQNPER"
# Parse settings
parser = argparse.ArgumentParser(description="Train an RL model for target control.")
parser.add_argument(
    "--time-steps", metavar="N", type=int, help="total number of training time steps."
)
parser.add_argument(
    "--learning-starts",
    type=int,
    metavar="LS",
    help="number of time steps to collect before learning starts.",
)
parser.add_argument(
    "--seed", type=int, default=42, metavar="S", help="random seed (default: 42)."
)
parser.add_argument("--env", type=str, help="the environment to run.")
parser.add_argument(
    "--resume-training",
    action="store_true",
    help="resume training from the latest checkpoint.",
)
parser.add_argument("--checkpoint-dir", default="models", help="path to save models.")
parser.add_argument(
    "--no-cuda", action="store_true", default=False, help="disable CUDA training."
)
parser.add_argument(
    "--eval-only", action="store_true", default=False, help="evaluate only."
)
parser.add_argument(
    "--exp-name", type=str, default="ddqn", metavar="E", help="the experiment name."
)
parser.add_argument("--log-dir", default="logs", help="path to save logs.")
parser.add_argument(
    "--hyperparams",
    type=str,
    help="extra hyperparameters for the model, as comma-separated name=value pairs.",
)
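
# Example invocation (the environment id below is illustrative; any id
# registered by gym_PBN works):
#   python train_ddqn.py --env gym-PBN/PBN-v0 --time-steps 150000 \
#       --learning-starts 1000 --exp-name ddqn-per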
args = parser.parse_args()

use_cuda = not args.no_cuda and torch.cuda.is_available()
DEVICE = torch.device("cuda" if use_cuda else "cpu")
print(f"Training on {DEVICE}")
# Reproducibility
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

# Load env
env = gym.make(args.env)

# Set up logs
TOP_LEVEL_LOG_DIR = Path(args.log_dir)
TOP_LEVEL_LOG_DIR.mkdir(parents=True, exist_ok=True)
RUN_NAME = f"{args.env.split('/')[-1]}_{args.exp_name}_{args.seed}"

# Checkpoints
checkpoint_path = Path(args.checkpoint_dir) / RUN_NAME
checkpoint_path.mkdir(parents=True, exist_ok=True)

def get_latest_checkpoint():
    """Return the most recently created checkpoint in checkpoint_path, or None."""
    files = list(checkpoint_path.glob("*.pt"))
    if len(files) > 0:
        return max(files, key=lambda x: x.stat().st_ctime)
    else:
        return None

# Model
total_time_steps = args.time_steps
resume_steps = 0

hyperparams = {}
if args.hyperparams:
    # Parse comma-separated name=value pairs; values are parsed as Python
    # literals (safer than eval for untrusted command-line input).
    hyperparams = {
        param.split("=")[0]: ast.literal_eval(param.split("=")[1])
        for param in args.hyperparams.split(",")
    }
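# Example: --hyperparams "gamma=0.99,batch_size=128" (illustrative names; the
# accepted keys are whatever DDQNPER's constructor takes).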

model = model_cls(env, DEVICE, **hyperparams)
config = model.get_config()
config["learning_starts"] = args.learning_starts

run = wandb.init(
    project="pbn-rl",
    entity="uos-plccn",
    sync_tensorboard=True,
    monitor_gym=True,
    config=config,
    name=RUN_NAME,
    save_code=True,
)

if args.resume_training:
    model_path = get_latest_checkpoint()
    if model_path:
        print(f"Loading model {model_path}.")
        model = model_cls.load(model_path, env, device=DEVICE)
        # Time steps the loaded model has already completed, so the progress
        # message below reports the remaining steps correctly.
        resume_steps = model.num_timesteps

if not args.eval_only:
    print(f"Training for {total_time_steps - resume_steps} time steps...")
    model.learn(
        total_time_steps,
        learning_starts=args.learning_starts,
        checkpoint_freq=10_000,
        checkpoint_path=checkpoint_path,
        resume_steps=resume_steps,
        log_dir=TOP_LEVEL_LOG_DIR,
        log_name=RUN_NAME,
        log=True,
        run=run,
    )
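
# Estimate the steady-state distribution (SSD) of the controlled network over
# repeated rollouts (300 resets of 100,000 iterations each) and log the
# resulting histogram to W&B.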
print(f"Evaluating...")
ssd, plot = compute_ssd_hist(env, model, resets=300, iters=100_000, multiprocess=False)
run.log({"SSD": plot})
env.close()
run.finish()