-
Notifications
You must be signed in to change notification settings - Fork 0
/
mcts.py
129 lines (95 loc) · 3.87 KB
/
mcts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from player import Player
from cache import Cache1, Cache2
from board import Board, Result, Cell
from math import sqrt, inf, log
from tqdm import trange
from time import sleep
# Multiplier applied to the exploration term in Mcts.Node.upper_confidence_bound.
SQUARE_ROOT_2 = sqrt(2)
# Default number of playouts performed by Mcts.train.
PLAYOUTS = 5000
class Mcts(Player):
    """Monte Carlo Tree Search player.

    Keeps one ``Node`` of statistics per board in a ``Cache1`` and picks
    moves by comparing the upper confidence bound of the boards reachable
    from the current one. ``train`` fills the statistics by repeatedly
    playing games against itself.
    """

    class Node:
        """Win/draw/loss/visit counters for a single board.

        A node also remembers every (board, node) pair it was reached from,
        because the exploration term needs the visit count of its parents.
        """

        def __init__(self):
            super().__init__()
            self.wins = 0
            self.draws = 0
            self.losses = 0
            self.visits = 0
            # Parent nodes, stored under the parent board they belong to.
            self.parents = Cache1()

        def upper_confidence_bound(self):
            """Return the selection score of this node.

            Returns:
                ``inf`` for a node that has never been visited, otherwise
                ``(wins - losses) / visits`` plus
                ``SQUARE_ROOT_2 * sqrt(log(parent_visits) / visits)``.
            """
            if self.visits == 0:
                return inf

            exploitation_component = (self.wins - self.losses) / self.visits
            # A visited node can have parents that were never visited (e.g. it
            # was first explored as the starting board of a playout and is
            # later reached from a freshly created parent). log(0) raises
            # ValueError, so clamp to 1, which yields a zero exploration term.
            parent_visits = max(1, self.parent_visits())
            exploration_component = SQUARE_ROOT_2 * sqrt(
                log(parent_visits) / self.visits
            )
            return exploitation_component + exploration_component

        def parent_visits(self):
            """Return the summed visit count of all registered parents."""
            return sum(parent.visits for parent in self.parents.boards.values())

        def register_parent(self, parent_board, parent_node):
            """Record ``parent_node`` under ``parent_board``.

            Does nothing when an entry for ``parent_board`` already exists.
            """
            _, found = self.parents.get(parent_board)
            if not found:
                self.parents.set(parent_board, parent_node)

    def __init__(self):
        super().__init__("Monte Carlo Tree Search")
        # One Node per board encountered so far.
        self.nodes = Cache1()

    def get_best_move(self, board):
        """Return the valid move whose resulting board scores highest.

        Creates nodes for ``board`` and for each reachable board when they
        are missing, and registers ``board`` as a parent of every child.
        When several moves share the highest score, the first one in
        ``board.get_valid_moves()`` order is returned.

        Raises:
            IndexError: if ``board`` has no valid moves.
        """
        current_node, found = self.get_node(board)
        if not found:
            current_node = self.create_node(board)

        move_child_node_pairs = self.get_move_child_node_pairs(board)
        if not move_child_node_pairs:
            # Same exception type as indexing an empty list, clearer message.
            raise IndexError("No valid moves available for this board.")

        # Forward propagation, create tree structure
        best_move = None
        best_score = None
        for move, node in move_child_node_pairs:
            # Register before scoring: the score depends on parent visits.
            node.register_parent(board, current_node)
            score = node.upper_confidence_bound()
            # Strict comparison keeps the earliest move on ties.
            if best_score is None or score > best_score:
                best_move, best_score = move, score

        return best_move

    def get_node(self, board):
        """Return the ``(node, found)`` pair stored for ``board``."""
        return self.nodes.get(board)

    def create_node(self, board):
        """Create an empty node, store it under ``board`` and return it."""
        new_node = Mcts.Node()
        self.nodes.set(board, new_node)
        return new_node

    def get_move_child_node_pairs(self, board):
        """Return a list of ``(move, child_node)`` for every valid move."""
        return [
            (move, self.get_child_node(move, board))
            for move in board.get_valid_moves()
        ]

    def get_child_node(self, move, board):
        """Return the node of the board reached by playing ``move``.

        The node is taken from the cache when present, otherwise created.
        """
        new_board = board.simulate_turn(move)
        cached_node, found = self.get_node(new_board)
        if found:
            return cached_node
        return self.create_node(new_board)

    def train(self, board=None, playouts=PLAYOUTS):
        """Run ``playouts`` playouts starting from ``board``.

        Args:
            board: Starting board. ``None`` (the default) means a new
                ``Board()`` built at call time, so no instance is shared
                between calls through the default value.
            playouts: Number of playouts to perform.
        """
        if board is None:
            board = Board()

        print(f"Performing {playouts} playouts.")
        # NOTE(review): presumably gives the message time to appear before
        # the progress bar starts drawing - confirm.
        sleep(0.05)
        for _ in trange(playouts):
            self.playout(board)

    def playout(self, board):
        """Play one game from ``board`` to the end and record its result.

        Every move is chosen with ``get_best_move``; all boards seen,
        including the starting one, are passed to ``backpropagate``.
        """
        history = [board]
        while not board.is_game_over():
            move = self.get_best_move(board)
            board = board.simulate_turn(move)
            history.append(board)

        result = board.get_game_result()
        self.backpropagate(history, result)

    def backpropagate(self, history, game_result):
        """Update the counters of every board in ``history``.

        Each node gets one more visit and one more win, loss or draw,
        judged with ``is_win`` / ``is_loss`` for the side to move on that
        board.

        Raises:
            AssertionError: if a board in ``history`` has no node.
            ValueError: if ``game_result`` is neither a win, a loss nor
                ``Result.Draw`` for the board being updated.
        """
        for board in history:
            node, found = self.get_node(board)
            assert found is True, "Node must exist"

            node.visits += 1
            turn = board.whose_turn()
            if self.is_win(turn, game_result):
                node.wins += 1
            elif self.is_loss(turn, game_result):
                node.losses += 1
            elif game_result == Result.Draw:
                node.draws += 1
            else:
                raise ValueError("Illegal game state.")

    def is_win(self, turn, result):
        """Return True when the side that is NOT to move won the game.

        NOTE(review): this looks deliberate - a node seems to be scored
        for the player whose move produced the board - confirm.
        """
        return (turn == Cell.X and result == Result.O_Wins) or (
            turn == Cell.O and result == Result.X_Wins
        )

    def is_loss(self, turn, result):
        """Return True when the side that is to move won the game."""
        return (turn == Cell.X and result == Result.X_Wins) or (
            turn == Cell.O and result == Result.O_Wins
        )