-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
456 lines (391 loc) · 15.2 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
# Copyright 2023 The Pgx Authors. All Rights Reserved.
# Copyright 2024 Samir Rangwalla
import datetime
import os
import pickle
import time
from functools import partial
from typing import NamedTuple
import mctx
import optax
import pgx
from pgx.experimental import auto_reset
from pydantic import BaseModel
# We referred to Haiku's ResNet implementation:
# https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/nets/resnet.py
import haiku as hk
import jax
import jax.numpy as jnp
from tqdm import tqdm
class Config(BaseModel):
    """Hyper-parameters for AlphaZero-style self-play training on a pgx environment.

    Instantiated once at import time as the module-level ``config`` below;
    every other part of this file reads its settings from that instance.
    """

    # NOTE: ``input()`` is evaluated once, when this class body executes (i.e.
    # at import time), and the typed string becomes the default env id that is
    # later passed to ``pgx.make``. NOTE(review): presumably pydantic does not
    # validate this default against ``pgx.EnvId`` — confirm whether an invalid
    # id should be rejected earlier.
    env_id: pgx.EnvId = input("envId:")
    # Seed for the top-level PRNG key of the training loop.
    seed: int = 0
    # Number of outer (self-play -> train) iterations.
    max_num_iters: int = 1
    # network params
    # Channel width of every convolution in the residual tower.
    num_channels: int = 128
    # Number of residual blocks (passed to AZNet as ``num_blocks``).
    num_layers: int = 6
    # NOTE(review): not read anywhere in this file; AZNet always uses BlockV2.
    resnet_v2: bool = True
    # selfplay params
    #Testing
    # Total parallel self-play games; split across devices with floor division.
    # NOTE(review): 1028 is not a power of two (1024 intended?) — confirm.
    selfplay_batch_size: int = 1028
    # MCTS simulations per move in gumbel_muzero_policy.
    num_simulations: int = 32
    # Number of moves per self-play rollout (length of the lax.scan).
    max_num_steps: int = 256
    # training params
    # Samples per optimizer step, summed over all devices.
    training_batch_size: int = 4096
    learning_rate: float = 0.001
    # eval params
    # Evaluate and checkpoint every ``eval_interval`` iterations.
    eval_interval: int = 5

    class Config:
        # Reject unknown fields when constructing the model.
        extra = "forbid"


config: Config = Config()
print(config)
class BlockV2(hk.Module):
    """Pre-activation residual block.

    Applies two identical stages of BatchNorm -> ReLU -> 3x3 Conv2D with
    ``num_channels`` outputs, then adds the untouched input back (identity
    shortcut), so the input must already have ``num_channels`` channels.
    """

    def __init__(self, num_channels, name="BlockV2"):
        super().__init__(name=name)
        self.num_channels = num_channels

    def __call__(self, x, is_training, test_local_stats):
        shortcut = x
        h = x
        # Submodules are created in the same order every call
        # (batch_norm, conv2_d, batch_norm_1, conv2_d_1), keeping Haiku names stable.
        for _ in range(2):
            h = hk.BatchNorm(True, True, 0.9)(h, is_training, test_local_stats)
            h = jax.nn.relu(h)
            h = hk.Conv2D(self.num_channels, kernel_shape=3)(h)
        return shortcut + h
class AZNet(hk.Module):
    """AlphaZero NN architecture.

    A convolutional stem plus ``num_blocks`` residual blocks of width
    ``num_channels``, followed by two heads:

    * policy head: ``num_actions`` unnormalised logits per example;
    * value head: one scalar per example, squashed into [-1, 1] by ``tanh``.
    """

    def __init__(
        self,
        num_actions,
        num_channels: int = 64,
        num_blocks: int = 5,
        name="az_net",
    ):
        super().__init__(name=name)
        self.num_actions = num_actions
        self.num_channels = num_channels
        self.num_blocks = num_blocks
        self.resnet_cls = BlockV2

    def __call__(self, x, is_training, test_local_stats):
        """Return ``(logits, value)`` for a batch of observations ``x``.

        Args:
            x: batched observations. hk.Conv2D treats them as channel-last
                (NHWC) images; kuhn_poker / leduc_holdem provide flat
                ``(batch, features)`` vectors, which are reshaped below.
            is_training: forwarded to every ``hk.BatchNorm`` as ``is_training``.
            test_local_stats: forwarded to every ``hk.BatchNorm``.

        Returns:
            ``logits`` with shape ``(batch, num_actions)`` and ``value`` with
            shape ``(batch,)`` in [-1, 1].
        """
        if config.env_id == "kuhn_poker" or config.env_id == "leduc_holdem":
            # Flat (batch, features) observations: add a width axis of size 1
            # and a channel axis so Conv2D receives a rank-4 *batched* image
            # (batch, features, 1, 1). A rank-3 (batch, features, 1) array is
            # interpreted by hk.Conv2D as ONE unbatched image whose height is
            # the batch axis, which made each game's output depend on its
            # neighbours in the batch.
            # NOTE(review): BatchNorm param/state ranks presumably follow the
            # input rank, so checkpoints from the old rank-3 layout may not load.
            x = x.reshape((x.shape[0], x.shape[1], 1, 1))
        x = x.astype(jnp.float32)
        x = hk.Conv2D(self.num_channels, kernel_shape=2)(x)

        for i in range(self.num_blocks):
            x = self.resnet_cls(self.num_channels, name=f"block_{i}")(
                x, is_training, test_local_stats
            )
        # Final pre-activation of the tower output (BlockV2 ends with a conv).
        x = hk.BatchNorm(True, True, 0.9)(x, is_training, test_local_stats)
        x = jax.nn.relu(x)

        # policy head
        logits = hk.Conv2D(output_channels=2, kernel_shape=1)(x)
        logits = hk.BatchNorm(True, True, 0.9)(logits, is_training, test_local_stats)
        logits = jax.nn.relu(logits)
        logits = hk.Flatten()(logits)
        logits = hk.Linear(self.num_actions)(logits)

        # value head
        v = hk.Conv2D(output_channels=1, kernel_shape=1)(x)
        v = hk.BatchNorm(True, True, 0.9)(v, is_training, test_local_stats)
        v = jax.nn.relu(v)
        v = hk.Flatten()(v)
        v = hk.Linear(self.num_channels)(v)
        v = jax.nn.relu(v)
        v = hk.Linear(1)(v)
        v = jnp.tanh(v)
        v = v.reshape((-1,))

        return logits, v
# All local accelerators; the pmap-ed functions below split their batches
# across them via ``num_devices``.
devices = jax.local_devices()
num_devices = len(devices)
# Single environment object whose init/step are vmapped throughout the file.
env = pgx.make(config.env_id)
# NOTE(review): the pgx baseline is disabled, so ``evaluate`` pits the current
# model against itself; the "eval/vs_baseline/*" log keys are misleading.
#baseline = pgx.make_baseline_model(config.env_id + "_v0")
def forward_fn(x, is_eval=False):
    """Build AZNet from ``config`` and apply it to the observation batch ``x``.

    ``is_training`` is passed to the network as ``not is_eval`` (local
    BatchNorm test statistics are always disabled). Returns the
    ``(policy_logits, value)`` pair produced by AZNet.
    """
    network = AZNet(
        num_actions=env.num_actions,
        num_channels=config.num_channels,
        num_blocks=config.num_layers,
    )
    return network(x, is_training=not is_eval, test_local_stats=False)
# Pure Haiku functions built from forward_fn, used as
#   forward.init(rng, x)                      -> (params, state)
#   forward.apply(params, state, x, is_eval=) -> ((logits, value), new_state)
# ``without_apply_rng`` removes the rng argument from ``apply``.
forward = hk.without_apply_rng(hk.transform_with_state(forward_fn))
# Adam optimizer shared by init (in __main__) and the pmap-ed train step.
optimizer = optax.adam(learning_rate=config.learning_rate)
def recurrent_fn(model, rng_key: jnp.ndarray, action: jnp.ndarray, state: pgx.State):
    """MCTS transition function for ``mctx``.

    The search "embedding" is the batched ``pgx.State`` itself, so each
    simulated move is a real environment step followed by a network evaluation.

    Args:
        model: ``(params, state)`` tuple of the Haiku network.
        rng_key: PRNG key; only consumed for the stochastic envs listed below.
        action: batch of actions selected by the search.
        state: batched pgx state (the search embedding).

    Returns:
        ``(mctx.RecurrentFnOutput, next_state)``.
    """
    # Envs whose ``step`` takes a PRNG key.
    stochastic_env_ids = (
        "minatar-asterix",
        "minatar-breakout",
        "minatar-freeway",
        "minatar-seaquest",
        "minatar-space_invaders",
        "2048",
    )
    model_params, model_state = model
    current_player = state.current_player

    if config.env_id in stochastic_env_ids:
        keys = jax.random.split(rng_key, state.observation.shape[0])
        state = jax.vmap(env.step)(state, action, keys)
    else:
        state = jax.vmap(env.step)(state, action)

    (logits, value), _ = forward.apply(
        model_params, model_state, state.observation, is_eval=True
    )
    # mask invalid actions (shift by the row max first for numerical stability)
    logits = logits - jnp.max(logits, axis=-1, keepdims=True)
    logits = jnp.where(state.legal_action_mask, logits, jnp.finfo(logits.dtype).min)

    # Reward earned by the player who just moved.
    reward = state.rewards[jnp.arange(state.rewards.shape[0]), current_player]
    value = jnp.where(state.terminated, 0.0, value)
    # Two-player zero-sum: the child is evaluated from the opponent's
    # perspective, so its value is negated (-1). Single-agent envs (MinAtar,
    # 2048) must keep the sign (+1). Terminal states contribute nothing.
    discount_sign = -1.0 if env.num_players > 1 else 1.0
    discount = discount_sign * jnp.ones_like(value)
    discount = jnp.where(state.terminated, 0.0, discount)

    recurrent_fn_output = mctx.RecurrentFnOutput(
        reward=reward,
        discount=discount,
        prior_logits=logits,
        value=value,
    )
    return recurrent_fn_output, state
class SelfplayOutput(NamedTuple):
    """Per-move self-play data; ``lax.scan`` in ``selfplay`` stacks it over time.

    Per device each field has leading axes ``(max_num_steps, batch)``.
    """

    # Observation the move was chosen from.
    obs: jnp.ndarray
    # Reward received by the player who made the move.
    reward: jnp.ndarray
    # Whether the game ended with this move.
    terminated: jnp.ndarray
    # Search policy from gumbel_muzero_policy (the policy training target).
    action_weights: jnp.ndarray
    # Multiplier applied to the next step's value target; 0 once terminated.
    discount: jnp.ndarray
@jax.pmap
def selfplay(model, rng_key: jnp.ndarray) -> SelfplayOutput:
    """Play ``max_num_steps`` moves of batched self-play on one device.

    Each device runs ``selfplay_batch_size // num_devices`` games in parallel,
    picks moves with Gumbel MuZero search, and auto-resets finished games so
    the rollout length is fixed.

    Args:
        model: ``(params, state)`` of the network (replicated per device).
        rng_key: per-device PRNG key.

    Returns:
        ``SelfplayOutput`` whose fields have leading axes
        ``(max_num_steps, batch_size)``.
    """
    model_params, model_state = model
    batch_size = config.selfplay_batch_size // num_devices
    # Sign applied to the next state's value when building value targets:
    # two-player zero-sum games see the next state from the opponent's side
    # (-1); single-agent envs (MinAtar, 2048) must not flip the sign (+1).
    discount_sign = -1.0 if env.num_players > 1 else 1.0

    def step_fn(state, key):
        # lax.scan body: returns (next_state, SelfplayOutput for this move).
        key1, key2 = jax.random.split(key)
        observation = state.observation
        (logits, value), _ = forward.apply(
            model_params, model_state, observation, is_eval=True
        )
        root = mctx.RootFnOutput(prior_logits=logits, value=value, embedding=state)
        policy_output = mctx.gumbel_muzero_policy(
            params=model,
            rng_key=key1,
            root=root,
            recurrent_fn=recurrent_fn,
            num_simulations=config.num_simulations,
            invalid_actions=~state.legal_action_mask,
            qtransform=mctx.qtransform_completed_by_mix_value,
            gumbel_scale=1.0,
        )
        actor = state.current_player
        keys = jax.random.split(key2, batch_size)
        state = jax.vmap(auto_reset(env.step, env.init))(state, policy_output.action, keys)
        discount = discount_sign * jnp.ones_like(value)
        discount = jnp.where(state.terminated, 0.0, discount)
        return state, SelfplayOutput(
            obs=observation,
            action_weights=policy_output.action_weights,
            reward=state.rewards[jnp.arange(state.rewards.shape[0]), actor],
            terminated=state.terminated,
            discount=discount,
        )

    # Run selfplay for max_num_steps by batch
    rng_key, sub_key = jax.random.split(rng_key)
    keys = jax.random.split(sub_key, batch_size)
    state = jax.vmap(env.init)(keys)
    key_seq = jax.random.split(rng_key, config.max_num_steps)
    _, data = jax.lax.scan(step_fn, state, key_seq)
    return data
class Sample(NamedTuple):
    """Training inputs and targets produced by ``compute_loss_input``."""

    # Network input observations.
    obs: jnp.ndarray
    # Search policy (``action_weights``), the cross-entropy target.
    policy_tgt: jnp.ndarray
    # Backed-up return, the value regression target.
    value_tgt: jnp.ndarray
    # True where a termination occurs at or after this step within the rollout;
    # False on the trailing, unfinished part whose value target is incomplete.
    mask: jnp.ndarray
@jax.pmap
def compute_loss_input(data: SelfplayOutput) -> Sample:
    """Convert one device's self-play trajectory into training targets.

    The value target at step t is ``reward[t] + discount[t] * target[t + 1]``,
    accumulated backwards from the end of the rollout starting from zero.
    Steps after the last termination of a game belong to an episode that the
    rollout cut off, so they get ``mask == False`` and no value loss.
    """
    batch_size = config.selfplay_batch_size // num_devices

    # Reverse-time running count of terminations: positive iff the game
    # terminates at this step or later inside the rollout.
    terminations_ahead = jnp.flip(
        jnp.cumsum(jnp.flip(data.terminated, axis=0), axis=0), axis=0
    )
    value_mask = terminations_ahead >= 1

    def backup(next_value, step):
        reward, discount = step
        current = reward + discount * next_value
        return current, current

    # reverse=True walks time backwards but stores outputs in forward order.
    _, value_tgt = jax.lax.scan(
        backup,
        jnp.zeros(batch_size),
        (data.reward, data.discount),
        reverse=True,
    )

    return Sample(
        obs=data.obs,
        policy_tgt=data.action_weights,
        value_tgt=value_tgt,
        mask=value_mask,
    )
def loss_fn(model_params, model_state, samples: Sample):
    """AlphaZero loss: policy cross-entropy plus masked L2 value loss.

    Returns ``(total, (new_model_state, policy_loss, value_loss))`` for use
    with ``jax.grad(..., has_aux=True)``. ``new_model_state`` is the state
    returned by the training-mode ``forward.apply``.
    """
    outputs, new_model_state = forward.apply(
        model_params, model_state, samples.obs, is_eval=False
    )
    logits, value = outputs
    policy_loss = optax.softmax_cross_entropy(logits, samples.policy_tgt).mean()
    # Zero out steps whose value target is unreliable (truncated episode); the
    # mean still divides by the total number of samples.
    value_loss = (optax.l2_loss(value, samples.value_tgt) * samples.mask).mean()
    total = policy_loss + value_loss
    return total, (new_model_state, policy_loss, value_loss)
@partial(jax.pmap, axis_name="i")
def train(model, opt_state, data: Sample):
    """Perform one data-parallel optimizer step.

    Gradients are averaged over devices with ``pmean`` on axis ``"i"``. The
    returned network state and losses are per-device values.

    Returns ``(new_model, new_opt_state, policy_loss, value_loss)``.
    """
    params, net_state = model
    grad_fn = jax.grad(loss_fn, has_aux=True)
    grads, (net_state, policy_loss, value_loss) = grad_fn(params, net_state, data)
    synced_grads = jax.lax.pmean(grads, axis_name="i")
    updates, new_opt_state = optimizer.update(synced_grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return (new_params, net_state), new_opt_state, policy_loss, value_loss
@jax.pmap
def evaluate(rng_key, my_model):
    """A simplified evaluation by sampling. Only for debugging.
    Please use MCTS and run tournaments for serious evaluation.

    Plays ``selfplay_batch_size // num_devices`` games per device, sampling
    moves from the network's legal-action-masked policy, until every game has
    terminated or been truncated.

    NOTE: the pgx baseline is disabled, so the opponent is this same model.

    Returns:
        Total reward collected by player 0 in each game, shape ``(batch_size,)``.
    """
    my_player = 0
    my_model_params, my_model_state = my_model
    # Envs whose ``step`` takes a PRNG key.
    stochastic_env_ids = (
        "minatar-asterix",
        "minatar-breakout",
        "minatar-freeway",
        "minatar-seaquest",
        "minatar-space_invaders",
        "2048",
    )
    is_stochastic = config.env_id in stochastic_env_ids

    key, subkey = jax.random.split(rng_key)
    batch_size = config.selfplay_batch_size // num_devices
    keys = jax.random.split(subkey, batch_size)
    state = jax.vmap(env.init)(keys)

    def body_fn(val):
        key, state, R = val
        (my_logits, _), _ = forward.apply(
            my_model_params, my_model_state, state.observation, is_eval=True
        )
        #opp_logits, _ = baseline(state.observation)
        # Without a baseline the opponent is the same network; its eval-mode
        # output is identical, so reuse it instead of recomputing.
        opp_logits = my_logits
        is_my_turn = (state.current_player == my_player).reshape((-1, 1))
        logits = jnp.where(is_my_turn, my_logits, opp_logits)
        # Mask illegal actions (as recurrent_fn does) so they are never sampled.
        logits = jnp.where(state.legal_action_mask, logits, jnp.finfo(logits.dtype).min)
        # Independent keys for action sampling and environment randomness.
        key, action_key, step_key = jax.random.split(key, 3)
        action = jax.random.categorical(action_key, logits, axis=-1)
        if is_stochastic:
            step_keys = jax.random.split(step_key, state.observation.shape[0])
            state = jax.vmap(env.step)(state, action, step_keys)
        else:
            state = jax.vmap(env.step)(state, action)
        R = R + state.rewards[jnp.arange(batch_size), my_player]
        return (key, state, R)

    def not_finished(val):
        # Stop once every game is over; truncated games count as over, otherwise
        # a truncated-but-not-terminated game would keep the loop running forever.
        _, state, _ = val
        return ~(state.terminated | state.truncated).all()

    _, _, R = jax.lax.while_loop(
        not_finished, body_fn, (key, state, jnp.zeros(batch_size))
    )
    return R
if __name__ == "__main__":
    # Initialize model and opt_state from a dummy batch of two initial states.
    dummy_state = jax.vmap(env.init)(jax.random.split(jax.random.PRNGKey(0), 2))
    dummy_input = dummy_state.observation
    model = forward.init(jax.random.PRNGKey(0), dummy_input)  # (params, state)
    opt_state = optimizer.init(params=model[0])
    # replicates to all devices
    model, opt_state = jax.device_put_replicated((model, opt_state), devices)

    # Prepare checkpoint dir (timestamp taken in UTC+9)
    now = datetime.datetime.now(datetime.timezone(datetime.timedelta(hours=9)))
    now = now.strftime("%Y%m%d%H%M%S")
    ckpt_dir = os.path.join("checkpoints", f"{config.env_id}_{now}")
    os.makedirs(ckpt_dir, exist_ok=True)

    def save_checkpoint(step: int) -> None:
        """Pickle device 0's model/optimizer replica plus run metadata as ``{step:06d}.ckpt``.

        Reads the current module-level ``model``, ``opt_state``, ``rng_key``,
        ``frames`` and ``hours`` at call time.
        """
        model_0, opt_state_0 = jax.tree_util.tree_map(lambda x: x[0], (model, opt_state))
        with open(os.path.join(ckpt_dir, f"{step:06d}.ckpt"), "wb") as f:
            dic = {
                "config": config,
                "rng_key": rng_key,
                "model": jax.device_get(model_0),
                "opt_state": jax.device_get(opt_state_0),
                "iteration": step,
                "frames": frames,
                "hours": hours,
                "pgx.__version__": pgx.__version__,
                "env_id": env.id,
                "env_version": env.version,
            }
            pickle.dump(dic, f)

    # Initialize logging dict
    iteration: int = 0
    hours: float = 0.0
    frames: int = 0
    log = {"iteration": iteration, "hours": hours, "frames": frames}

    rng_key = jax.random.PRNGKey(config.seed)
    for iteration in tqdm(range(config.max_num_iters)):
        if iteration % config.eval_interval == 0:
            # Evaluation
            rng_key, subkey = jax.random.split(rng_key)
            keys = jax.random.split(subkey, num_devices)
            R = evaluate(keys, model)
            log.update(
                {
                    "eval/vs_baseline/avg_R": R.mean().item(),
                    "eval/vs_baseline/win_rate": ((R == 1).sum() / R.size).item(),
                    "eval/vs_baseline/draw_rate": ((R == 0).sum() / R.size).item(),
                    "eval/vs_baseline/lose_rate": ((R == -1).sum() / R.size).item(),
                }
            )
            # Store checkpoints
            save_checkpoint(iteration)
            print(log)
            log = {"iteration": iteration}

        st = time.time()

        # Selfplay
        rng_key, subkey = jax.random.split(rng_key)
        keys = jax.random.split(subkey, num_devices)
        data: SelfplayOutput = selfplay(model, keys)
        samples: Sample = compute_loss_input(data)

        # Shuffle samples and make minibatches
        samples = jax.device_get(samples)  # (#devices, max_num_steps, batch, ...)
        frames += samples.obs.shape[0] * samples.obs.shape[1] * samples.obs.shape[2]
        samples = jax.tree_util.tree_map(lambda x: x.reshape((-1, *x.shape[3:])), samples)
        rng_key, subkey = jax.random.split(rng_key)
        ixs = jax.random.permutation(subkey, jnp.arange(samples.obs.shape[0]))
        samples = jax.tree_util.tree_map(lambda x: x[ixs], samples)  # shuffle
        num_samples = samples.obs.shape[0]
        num_updates = num_samples // config.training_batch_size
        if num_updates == 0:
            raise ValueError(
                f"self-play produced {num_samples} samples per iteration, fewer than "
                f"training_batch_size={config.training_batch_size}"
            )
        # Drop the shuffled remainder so every minibatch holds exactly
        # training_batch_size samples; otherwise the reshape fails whenever the
        # sample count is not divisible by num_updates * num_devices.
        num_used = num_updates * config.training_batch_size
        minibatches = jax.tree_util.tree_map(
            lambda x: x[:num_used].reshape((num_updates, num_devices, -1) + x.shape[1:]),
            samples,
        )

        # Training
        policy_losses, value_losses = [], []
        for i in range(num_updates):
            minibatch: Sample = jax.tree_util.tree_map(lambda x: x[i], minibatches)
            model, opt_state, policy_loss, value_loss = train(model, opt_state, minibatch)
            policy_losses.append(policy_loss.mean().item())
            value_losses.append(value_loss.mean().item())
        policy_loss = sum(policy_losses) / len(policy_losses)
        value_loss = sum(value_losses) / len(value_losses)

        et = time.time()
        hours += (et - st) / 3600
        log.update(
            {
                "train/policy_loss": policy_loss,
                "train/value_loss": value_loss,
                "hours": hours,
                "frames": frames,
            }
        )

    # Checkpoints above are only written on evaluation iterations, so persist
    # the final trained model too (e.g. max_num_iters=1 would otherwise keep
    # only the untrained initial model).
    save_checkpoint(config.max_num_iters)
    print(log)