-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlearn.py
95 lines (85 loc) · 3.86 KB
/
learn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from bridge import BridgePosition
import MCTS
import numpy as np
import tensorflow as tf
import DNN
import time
# Length of the flat vector BridgePosition.to_tensor() produces; it is the
# shape of the single numeric input column "x" fed to the network.
feature_size = len(BridgePosition().to_tensor())
# Number of policy classes: one per card of a 52-card deck (presumably the
# range of BridgePosition.move_to_int -- confirm).
policy_num = 52
feature_columns = [
    tf.feature_column.numeric_column("x", shape=[feature_size]),
]
# Estimator around DNN.model_fn. hidden_units is presumably three hidden layers
# of width 1024, and value_weight presumably scales the value-head loss against
# the policy loss -- both interpreted inside DNN.model_fn, confirm there.
classifier = tf.estimator.Estimator(
    model_fn=DNN.model_fn,
    params=dict(
        feature_columns=feature_columns,
        hidden_units=[1024] * 3,
        n_classes=policy_num,
        value_weight=0.1,
    ),
)
def self_play(visualize=False, num_simulations=100):
    """Play one deal from a fresh BridgePosition by MCTS self-play and record samples.

    At each position, ``num_simulations`` MCTS iterations are run on the current
    search tree. The child with the largest ``node[1]`` statistic (presumably
    the visit count -- confirm in MCTS) is then played, and the tree is
    re-rooted at that child.

    Args:
        visualize: When true, draw the initial position and print every card
            played. A new line starts when ``cards_in_trick`` is 3 or more
            before the card is played (presumably the card completing a trick).
        num_simulations: MCTS iterations per move. Defaults to 100, the value
            that was previously hard-coded.

    Returns:
        ``(final_score, [x, p, v])`` with one row per move played:

        - ``x``: the stacked ``to_tensor()`` encodings of the positions.
        - ``p``: shape ``(N, 1)``, ``move_to_int`` of the move chosen at each
          position.
        - ``v``: shape ``(N, 1)``, ``final_score - score - tricks_left / 2``
          for each position, negated when ``current_player`` is odd.

    Raises:
        ValueError: if the search tree has no children at a position that
            still has legal moves. The previous hand-written scan silently
            reused a stale move in that case.
    """
    tree = MCTS.init_tree(BridgePosition())
    data_x, data_p, data_v, data_v_sign = [], [], [], []
    if visualize:
        tree[0].visualize()
    # tree[0] is the position at this node; tree[-1] maps each move to its child subtree.
    while len(tree[0].get_moves()) > 0:
        for _ in range(num_simulations):
            MCTS.MCTS_iteration(tree)
        position = tree[0]
        # max() returns the first maximal child, matching the strict '>' scan it replaces.
        move_max = max(tree[-1].items(), key=lambda item: item[1][1])[0]
        data_x.append(position.to_tensor())
        data_p.append([position.move_to_int(move_max)])
        # Store -(score so far) - half the remaining tricks; the final score is
        # added below once the deal has finished.
        data_v.append([-position.score - position.tricks_left() / 2])
        # Express the value target from the mover's side (presumably the even
        # vs odd seats form the two partnerships).
        data_v_sign.append([1 if position.current_player % 2 == 0 else -1])
        if visualize:
            card_str = position.move_to_str(move_max)
            print(card_str, end=' ' if position.cards_in_trick < 3 else '\n', flush=True)
        tree = tree[-1][move_max]
    final_score = tree[0].score
    return final_score, [np.array(data_x), np.array(data_p),
                         (np.array(data_v) + final_score) * np.array(data_v_sign)]
import pickle
import os
# Load previously generated self-play samples so runs can accumulate data;
# start from empty arrays when no saved file exists yet.
# NOTE: pickle.load executes arbitrary code from the file -- only load a
# data.pickle this script produced itself.
if not os.path.exists('data.pickle'):
    data = [
        np.zeros((0, feature_size)),   # x: position features
        np.zeros((0, 1), dtype=int),   # p: chosen-move index
        np.zeros((0, 1)),              # v: value target
    ]
else:
    with open('data.pickle', 'rb') as fin:
        data = pickle.load(fin)
# Self-play data generation: play 1000 games, drawing every 100th one, and
# append each game's [x, p, v] arrays row-wise to `data`. Then pickle `data`
# to 'data.pickle', which the load step earlier in the script reads on the next run.
# NOTE(review): the indentation of this script was lost in this copy. Confirm
# whether the pickle.dump line runs once after the loop or after every game
# (as a checkpoint).
# NOTE(review): `score` is never used. np.vstack also re-copies the whole
# dataset on every game, which is quadratic in the number of games; collecting
# per-game arrays in lists and stacking once would avoid that.
for i in range(1000):
t0 = time.process_time()
score, data_game = self_play(visualize = i % 100 == 0)
print("Game #%d time:" % i, time.process_time() - t0)
for k in range(3): data[k] = np.vstack([data[k], data_game[k]])
with open('data.pickle', 'wb') as fout: pickle.dump(data, fout)
# Train on the first 9 * (n // 10) rows (about 90%) and hold out the rest.
# Rows are appended game by game and are not shuffled before the split, so the
# held-out rows are the most recently generated ones.
train_num = 9 * (len(data[0]) // 10)
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'x': data[0][:train_num]},
    y={'p': data[1][:train_num], 'v': data[2][:train_num]},
    num_epochs=None,  # cycle indefinitely; each train() call is bounded by `steps`
    shuffle=True,
)
test_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'x': data[0][train_num:]},
    y={'p': data[1][train_num:], 'v': data[2][train_num:]},
    num_epochs=1,  # a single ordered pass for evaluation
    shuffle=False,
)
# Train the estimator in rounds of 200 steps (10 rounds in the loop below) and
# evaluate it on the held-out split, printing process CPU time
# (time.process_time) for each phase. `metrics` is read further down the script.
# NOTE(review): indentation was lost in this copy; confirm whether the
# evaluate() call sits inside the 10-round loop or after it.
for k in range(10):
t0 = time.process_time()
classifier.train(input_fn = train_input_fn, steps = 200)
print("Train time:", time.process_time() - t0)
t0 = time.process_time()
metrics = classifier.evaluate(input_fn = test_input_fn)
print("Test time:", time.process_time() - t0)
# Spot-check: for the first 52 held-out rows, draw the position and print two
# lines. The first is the recorded target (move, value + tricks_left / 2) and
# the second is the model's prediction in the same form, with per-sample and
# total CPU time.
# Adding tricks_left / 2 undoes the offset subtracted from the value target in
# self_play. The result is presumably the number of remaining tricks won by the
# side to move -- confirm against BridgePosition.score.
# Estimator.predict returns a lazy generator, so the inference work happens in
# next(pred); that call is what the "Time:" print measures.
# NOTE(review): this raises IndexError when the held-out split has fewer than
# 52 rows (e.g. a small data.pickle); bounding the range by data[0].shape[0]
# would avoid that. The loop also reuses `i` from the self-play loop.
# NOTE(review): a separate predict() call per sample presumably reloads the
# model from its checkpoint each time; one batched predict() would be far
# cheaper if per-sample timing is not needed.
# NOTE(review): the 'class_ids'/'val' prediction keys and the
# 'accuracy'/'mse_value' metric keys must match what DNN.model_fn emits.
t0 = time.process_time()
for i in range(train_num, train_num + 52):
pos = BridgePosition.from_tensor(data[0][i])
policy, value = data[1][i][0], data[2][i][0]
pos.visualize(verbose = True)
print(pos.move_to_str(pos.int_to_move(policy)), value + pos.tricks_left() / 2)
pred = classifier.predict(input_fn = tf.estimator.inputs.numpy_input_fn(
x = {'x': data[0][i:i+1]}, y = None,
num_epochs = 1, shuffle = False, batch_size = 1)
)
t1 = time.process_time()
pred = next(pred)
print("Time:", time.process_time() - t1)
pred_policy, pred_value = pred['class_ids'][0], pred['val'][0]
print(pos.move_to_str(pos.int_to_move(pred_policy)), pred_value + pos.tricks_left() / 2)
print()
print("Prediction time:", time.process_time() - t0)
accuracy_score, mse_value = metrics['accuracy'], metrics['mse_value']
print(accuracy_score, mse_value)