trainer.py
from collections import deque
from random import sample

import numpy as np
from tqdm import trange


class Trainer:
    """Class to handle training."""

    def __init__(self, eps=500, pre=0, add=0, prev=0):
        """Initialize the trainer.

        Args:
            eps (int, optional): Number of episodes to train for. Defaults to 500.
            pre (int, optional): Number of episodes to pretrain the hindsight estimator for. Defaults to 0.
            add (int, optional): Number of additional hindsight-training episodes between full agent updates. Defaults to 0.
            prev (int, optional): Number of past episodes to sample from for the additional training. If 0, new episodes are sampled instead. Defaults to 0.
        """
        self.eps = eps
        self.pre = pre
        self.add = add
        self.prev_episodes = deque(maxlen=prev)
    def run_episode(self, agent, env) -> tuple:
        """Run one episode with the given agent in the given environment.

        Args:
            agent: Agent to run the episode with
            env: Environment to run the episode in

        Returns:
            tuple: Observations, selected actions, and rewards, each as a numpy array
        """
        states = []
        actions = []
        rewards = []
        done = False
        s, _ = env.reset()
        while not done:
            a = agent.act(s)
            ns, r, done, _ = env.step(a)
            states.append(s)
            actions.append(a)
            rewards.append(r)
            s = ns
        return np.array(states), np.array(actions), np.array(rewards)
    def additional_episodes(self, agent, env) -> list:
        """Sample episodes for the additional training of the hindsight estimator.

        Args:
            agent: Agent to use if new episodes are sampled
            env: Environment to use if new episodes are sampled

        Returns:
            list: Sampled episodes
        """
        if len(self.prev_episodes) >= self.add:
            # Enough stored episodes: sample from the buffer of past episodes
            return sample(self.prev_episodes, k=self.add)
        elif not self.prev_episodes:
            # No stored episodes (e.g. prev=0): roll out fresh episodes instead
            return [self.run_episode(agent, env) for _ in range(self.add)]
        else:
            # Buffer in use but not yet large enough: skip the additional training
            return []
    def fit(self, agent, env) -> list:
        """Train the given agent on the given environment.

        Args:
            agent: Agent to train
            env: Environment to train on

        Returns:
            list: Achieved returns
        """
        returns = []
        # Pretrain hindsight
        for _ in range(self.pre):
            states, actions, rewards = self.run_episode(agent, env)
            agent.update(states, actions, rewards, train_actor=False)
        # Train actor and hindsight
        for ep in trange(self.eps, leave=False):
            states, actions, rewards = self.run_episode(agent, env)
            agent.update(states, actions, rewards)
            returns.append(rewards[-1] if env.spec.id == 'DelayedEffect-v0' else sum(rewards))
            # Additional training for hindsight
            self.prev_episodes.append((states, actions, rewards))
            for states, actions, rewards in self.additional_episodes(agent, env):
                agent.update(states, actions, rewards, train_actor=False)
        return returns
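

if __name__ == '__main__':
    # Minimal smoke test, a sketch rather than part of the original project:
    # _StubEnv and _StubAgent are hypothetical stand-ins that only mimic the
    # interfaces Trainer expects (reset() -> (obs, info), step() -> 4-tuple,
    # act(state), update(states, actions, rewards, train_actor=...)); swap in
    # the repo's real agent and environment for actual experiments.
    from types import SimpleNamespace

    class _StubEnv:
        spec = SimpleNamespace(id='Stub-v0')

        def __init__(self, horizon=5):
            self.horizon = horizon
            self.t = 0

        def reset(self):
            self.t = 0
            return np.zeros(3), {}

        def step(self, action):
            self.t += 1
            return np.zeros(3), 1.0, self.t >= self.horizon, {}

    class _StubAgent:
        def act(self, state):
            return 0

        def update(self, states, actions, rewards, train_actor=True):
            pass

    trainer = Trainer(eps=3, pre=1, add=1, prev=2)
    # Prints the per-episode returns, 5.0 for each of the 3 episodes
    print(trainer.fit(_StubAgent(), _StubEnv()))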