q.py

# coding=utf-8
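"""Tabular Q-learning on the Maze environment.

The agent keeps a pandas DataFrame as its Q-table (one row per visited
state, one column per action) and learns with an epsilon-greedy policy.
"""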
import pandas as pd
import numpy as np
from base.maze import Maze


class QLearning(object):
    def __init__(self, actions, env, alpha=0.01, gamma=0.9, epsilon=0.9):
        self.actions = actions    # available action indices
        self.env = env            # environment to interact with
        self.alpha = alpha        # learning rate
        self.gamma = gamma        # discount factor
        self.epsilon = epsilon    # probability of acting greedily
        # Q-table: rows are states (added lazily), columns are actions.
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)
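
    # Epsilon-greedy selection: with probability epsilon exploit the current
    # Q estimates; otherwise pick a uniformly random action. The index
    # permutation below shuffles the actions first, so ties between equal
    # Q-values are broken at random instead of always favouring the first
    # column.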
    def get_action(self, state):
        self.check_if_state_exist(state)
        if np.random.uniform() < self.epsilon:
            target_actions = self.q_table.loc[state, :]
            target_actions = target_actions.reindex(np.random.permutation(target_actions.index))
            target_action = target_actions.idxmax()
        else:
            target_action = np.random.choice(self.actions)
        return target_action
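
    # Q-learning update (off-policy TD control):
    #   Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))
    # The target bootstraps from the greedy value of the next state,
    # regardless of which action the policy actually takes there.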
    def update_q_value(self, state, action, reward, state_next):
        self.check_if_state_exist(state_next)
        q_value_predict = self.q_table.loc[state, action]
        if state_next != 'done':
            q_value_real = reward + self.gamma * self.q_table.loc[state_next, :].max()
        else:
            # Terminal state ('done'): there is no future value to bootstrap.
            q_value_real = reward
        self.q_table.loc[state, action] += self.alpha * (q_value_real - q_value_predict)
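
    # States are added to the table lazily, with all-zero Q-values, the
    # first time the agent encounters them.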
    def check_if_state_exist(self, state):
        if state not in self.q_table.index:
            # DataFrame.append was removed in pandas 2.0; setting a new row
            # via .loc achieves the same thing and works on older versions too.
            self.q_table.loc[state] = [0.0] * len(self.actions)

    def train(self):
        for episode in range(100):
            print('Episode: {}'.format(episode))
            state = self.env.reset()
            while True:
                self.env.render()
                # Get next action.
                action = self.get_action(str(state))
                # Get next state.
                state_next, reward, done = self.env.step(action)
                # Update Q table.
                self.update_q_value(str(state), action, reward, str(state_next))
                state = state_next
                if done:
                    break
        # Close the environment window once all episodes have run.
        self.env.destroy()


if __name__ == '__main__':
    env = Maze()
    model = QLearning(list(range(env.n_actions)), env)
    env.after(100, model.train)
    env.mainloop()
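
# Note: Maze (from base.maze) is assumed to expose n_actions, reset(),
# render(), step(action) -> (state_next, reward, done), destroy(), and
# Tkinter-style after()/mainloop() scheduling, exactly as used above.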