test.py
from __future__ import print_function

import os
import numpy as np
from absl import app
from absl import flags
import tensorflow as tf

from env import Environment
from game import CFRRL_Game
from model import Network
from config import get_config

FLAGS = flags.FLAGS
flags.DEFINE_string('ckpt', '', 'apply a specific checkpoint')
flags.DEFINE_boolean('eval_delay', False, 'evaluate delay or not')


def sim(config, network, game):
    # Evaluate the trained policy on every traffic matrix in the test set.
    for tm_idx in game.tm_indexes:
        state = game.get_state(tm_idx)
        if config.method == 'actor_critic':
            policy = network.actor_predict(np.expand_dims(state, 0)).numpy()[0]
        elif config.method == 'pure_policy':
            policy = network.policy_predict(np.expand_dims(state, 0)).numpy()[0]
        # Pick the max_moves actions with the highest policy scores.
        actions = policy.argsort()[-game.max_moves:]
        game.evaluate(tm_idx, actions, eval_delay=FLAGS.eval_delay)


def main(_):
    # Use CPU only for testing.
    tf.config.experimental.set_visible_devices([], 'GPU')
    tf.get_logger().setLevel('INFO')

    config = get_config(FLAGS) or FLAGS
    env = Environment(config, is_training=False)
    game = CFRRL_Game(config, env)

    # Build the network and restore weights from the checkpoint selected by --ckpt.
    network = Network(config, game.state_dims, game.action_dim, game.max_moves)
    step = network.restore_ckpt(FLAGS.ckpt)
    if config.method == 'actor_critic':
        learning_rate = network.lr_schedule(network.actor_optimizer.iterations.numpy()).numpy()
    elif config.method == 'pure_policy':
        learning_rate = network.lr_schedule(network.optimizer.iterations.numpy()).numpy()
    print('\nstep %d, learning rate: %f\n' % (step, learning_rate))

    sim(config, network, game)


if __name__ == '__main__':
    app.run(main)
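
A minimal usage sketch, assuming a checkpoint has already been produced by the repository's training script and that config.py sets method to 'actor_critic' or 'pure_policy'; the checkpoint name below is hypothetical:

    # Evaluate using the checkpoint selection behavior of restore_ckpt (empty --ckpt by default):
    python test.py

    # Evaluate a specific (hypothetical) checkpoint and also report delay:
    python test.py --ckpt=ckpt-50 --eval_delay=True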