eval_vlog_mahjong.py
import argparse
import gc
import logging
import os
import sys
import time
import warnings

import numpy as np
import scipy.io as sio
import torch

import MahjongPy as mp
from models import VLOG
from env_mahjong import *
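
# Evaluation script for VLOG agents on the Mahjong environment: four saved agents
# (one per seat) play a batch of games, and per-agent payoffs, win counts, and
# deal-in counts are collected and written to a .mat file under ./data/.
# Note: VLOG is imported above so that the agent class is available when
# torch.load restores the pickled models.
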
def test_performance(agents_test, env_test, times=80, use_posterior=None, rotate_player=False):
    """Play `times` games with the four given agents and return per-agent payoffs,
    the number of environment steps taken, the number of games, and per-game
    win / deal-in indicators."""
    if use_posterior is None:
        use_posterior = [False, False, False, False]

    if times % 4 != 0:
        warnings.warn("For fairness, times should be a multiple of 4 so that each player acts as Oya the same number of times!")

    assert len(agents_test) == 4

    steps_taken = 0
    max_steps = 1000

    # -------------- Return vs. pretrained player -------------
    game = 0
    trials = 0
    payoffs_array = []
    winning_counts = []
    deal_in_counts = []

    while game < times:
        trials += 1

        if rotate_player:
            env_test.reset(0, 'east')
        else:
            env_test.reset(game % 4, 'east')

        payoffs = np.zeros([4])

        for tt in range(max_steps):
            curr_pid = env_test.get_curr_player_id()
            valid_actions = env_test.get_valid_actions(nhot=False)

            # With rotation, the agent controlling each seat changes every game,
            # so all agents play every seat equally often.
            if rotate_player:
                agent_id = (curr_pid + game) % 4
            else:
                agent_id = curr_pid

            action_mask = env_test.get_valid_actions(nhot=True)
            a = agents_test[agent_id].select(env_test.get_obs(curr_pid), env_test.get_oracle_obs(curr_pid),
                                             action_mask, greedy=True,
                                             use_posterior=use_posterior[agent_id])

            sp, r, done, _ = env_test.step(curr_pid, a)
            steps_taken += 1

            if env_test.has_done():
                if rotate_player:
                    # Map seat payoffs back to agent indices so results stay per-agent.
                    id_shift = np.array([(0 - game) % 4, (1 - game) % 4, (2 - game) % 4, (3 - game) % 4], dtype=int)
                    payoffs = payoffs + np.array(env_test.get_payoffs())[id_shift]
                else:
                    payoffs = payoffs + np.array(env_test.get_payoffs())

                print("payoffs of game {} :".format(game), payoffs)

                curr_wins = np.zeros([4])
                curr_deal_ins = np.zeros([4])

                if env_test.t.get_result().result_type == mp.ResultType.RonAgari:
                    for ii in range(4):  # consider multiple players Agari
                        if payoffs[ii] > 0:
                            curr_wins[ii] = 1
                    curr_deal_ins[np.argmin(payoffs)] = 1
                elif env_test.t.get_result().result_type == mp.ResultType.TsumoAgari:
                    curr_wins[np.argmax(payoffs)] = 1

                payoffs_array.append(payoffs)
                winning_counts.append(curr_wins)
                deal_in_counts.append(curr_deal_ins)

                game += 1
                break

    print("Tested {} games".format(times))

    return np.array(payoffs_array), steps_taken, times, np.array(winning_counts), np.array(deal_in_counts)
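
# ---------------- Main script: parse arguments, load agents, evaluate ----------------
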
logging.basicConfig(level=logging.INFO)

parser = argparse.ArgumentParser()
parser.add_argument('--env', type=int, default=1, help="ID of Environment")
parser.add_argument('--model_dir_0', type=str, help="Directory of the model for player 0 on the mahjong table")
parser.add_argument('--model_dir_1', type=str, help="Directory of the model for player 1 on the mahjong table")
parser.add_argument('--model_dir_2', type=str, help="Directory of the model for player 2 on the mahjong table")
parser.add_argument('--model_dir_3', type=str, help="Directory of the model for player 3 on the mahjong table")
parser.add_argument('--seed', type=int, default=0, help="Random seed")
parser.add_argument('--times', type=int, default=100, help="Number of games to test")
args = parser.parse_args()

savepath = './data/'
if os.path.exists(savepath):
    logging.info('{} already exists (it may already contain data).'.format(savepath))
else:
    os.makedirs(savepath)

seed = args.seed
np.random.seed(seed)
torch.manual_seed(seed)
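
# With more than one GPU available, the seed picks the CUDA device, so evaluation
# runs launched with different seeds can be spread across GPUs.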
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:{}".format(seed % torch.cuda.device_count()) if torch.cuda.is_available() else "cpu")
    torch.cuda.set_device(seed % torch.cuda.device_count())
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

rotate_player = True

model_dirs = [args.model_dir_0,
              args.model_dir_1,
              args.model_dir_2,
              args.model_dir_3]

test_model_dirs = [model_dirs[0],
                   model_dirs[1],
                   model_dirs[2],
                   model_dirs[3]]
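
# Each saved model is a complete pickled VLOG agent; torch.load restores the object
# and the agent is then pointed at the evaluation device chosen above.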
agents_test = [torch.load(model_dir, map_location=torch.device(device)) for model_dir in test_model_dirs]
for ag in agents_test:
    ag.device = device

env_name = "mahjong"
env_test = EnvMahjong()

results_array, total_steps, n_games, winnings, dealins = test_performance(
    agents_test, env_test, times=args.times, rotate_player=rotate_player)

data = {"results": results_array,
        "winnings": winnings,
        "dealins": dealins,
        "total_steps": total_steps,
        "n_games": n_games,
        "model_dirs": model_dirs}

sio.savemat(savepath + "{}_evaluation_{}.mat".format(env_name, seed), data, long_field_names=True)
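
# Example invocation (illustrative paths; the real checkpoint locations depend on
# how the models were trained and saved):
#   python eval_vlog_mahjong.py --seed 0 --times 400 \
#       --model_dir_0 ./trained/agent_0.model --model_dir_1 ./trained/agent_1.model \
#       --model_dir_2 ./trained/agent_2.model --model_dir_3 ./trained/agent_3.model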