tdlambda.py
'''
Implementation of the standard TD(lambda) algorithm (accumulating eligibility
traces, linear function approximation) for the main value function.
'''
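# The updates implemented by _update below are the standard accumulating-trace
# linear TD(lambda) rules, where x_t is the feature vector of the state being
# left (last_x_w) and x_{t+1} that of the newly observed state (x_w):
#   delta_t = r_{t+1} + gamma * dot(w, x_{t+1}) - dot(w, x_t)
#   z      <- gamma * lambda * z + x_t
#   w      <- w + alpha * delta_t * z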
import jax
import jax.numpy as jnp
class TDLambda:
    def __init__(self,
                 first_x_w,
                 discount=0.0,
                 alpha_w=1.e-4,
                 lambda_w=0.95):
        self.last_x_w = first_x_w
        self.alpha_w = alpha_w
        self.lambda_w = lambda_w
        # Value weights and eligibility trace, shaped like the feature vector.
        self.w = jnp.zeros_like(self.last_x_w)
        self.z = jnp.zeros_like(self.last_x_w)
        self.discount = discount
        self.td_error = None

    def predict(self, x_w):
        '''
        Inputs:
            x_w: state representation
        Outputs:
            Main value function evaluated at x_w
        '''
        return jnp.dot(self.w, x_w)

    def update(self, x_w, reward):
        '''
        Inputs:
            x_w: state representation of the newly observed state
            reward: reward received on the transition into x_w
        What the function does:
            Updates the main value weights, the eligibility trace vector and
            the TD error, then stores x_w as the current features.
        Outputs:
            N/A
        '''
        self.w, self.z, self.td_error = _update(last_x_w=self.last_x_w,
                                                reward=reward,
                                                x_w=x_w,
                                                w=self.w,
                                                z=self.z,
                                                discount=self.discount,
                                                alpha_w=self.alpha_w,
                                                lambda_w=self.lambda_w)
        self.last_x_w = x_w
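# Keeping the update math in a free function lets jax.jit trace a pure
# function of array arguments, with no `self` in the traced signature.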
@jax.jit
def _update(last_x_w, reward, x_w, w, z, discount, alpha_w, lambda_w):
    # One accumulating-trace TD(lambda) step; returns the new weights, the
    # new trace and the TD error for the transition last_x_w -> x_w.
    td_error = reward + discount * jnp.dot(w, x_w) - jnp.dot(w, last_x_w)
    z = discount * lambda_w * z + last_x_w
    w = w + alpha_w * td_error * z
    return w, z, td_error
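# A minimal usage sketch, assuming a toy stream of random binary feature
# vectors and a synthetic reward. The feature size, reward definition and
# hyperparameters below are illustrative assumptions, not values required
# by this module.
if __name__ == '__main__':
    import numpy as np

    rng = np.random.default_rng(0)
    n_features = 8

    # The first observed feature vector fixes the shape of w and z.
    first_x = jnp.asarray(rng.integers(0, 2, size=n_features), dtype=jnp.float32)
    agent = TDLambda(first_x_w=first_x, discount=0.9, alpha_w=1e-2, lambda_w=0.9)

    for _ in range(1000):
        x = jnp.asarray(rng.integers(0, 2, size=n_features), dtype=jnp.float32)
        reward = float(jnp.sum(x))  # toy reward: number of active features
        agent.update(x_w=x, reward=reward)

    print('value estimate:', float(agent.predict(x)))
    print('last TD error:', float(agent.td_error))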