Skip to content

Commit

Permalink
Linted efr_test
Browse files Browse the repository at this point in the history
  • Loading branch information
Jamesflynn1 committed Nov 15, 2023
1 parent 65d82f2 commit 3654a8d
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions open_spiel/python/algorithms/efr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,41 +29,43 @@
class EFRTest(parameterized.TestCase, absltest.TestCase):

def setUp(self):
# Test fixture: loads the games referenced by the tests in this class (Kuhn
# poker, Leduc poker, 3-player Kuhn poker and Sheriff) with pyspiel.load_game
# and builds a policy.TabularPolicy for the Kuhn and Leduc games.
#
# NOTE(review): this text is a flattened commit-diff view (indentation and
# +/- markers are lost), so every assignment below appears twice. Presumably
# the `_UPPER_CASE` attribute is the removed line and the snake_case attribute
# is its replacement -- confirm against the real file before copying.
self._KUHN_GAME = pyspiel.load_game("kuhn_poker")
self._LEDUC_GAME = pyspiel.load_game("leduc_poker")
self._KUHN_3P_GAME = pyspiel.load_game("kuhn_poker(players=3)")
self._SHERIFF_GAME = pyspiel.load_game("sheriff")
self.kuhn_game = pyspiel.load_game("kuhn_poker")
self.leduc_game = pyspiel.load_game("leduc_poker")
self.kuhn_3p_game = pyspiel.load_game("kuhn_poker(players=3)")
self.sheriff_game = pyspiel.load_game("sheriff")

# Tabular policies used as the "uniform" reference in the tests (presumably
# TabularPolicy starts out uniform over legal actions -- TODO confirm).
self._KUHN_UNIFORM_POLICY = policy.TabularPolicy(self._KUHN_GAME)
self._LEDUC_UNIFORM_POLICY = policy.TabularPolicy(self._LEDUC_GAME)
self.kuhn_uniform_policy = policy.TabularPolicy(self.kuhn_game)
self.leduc_uniform_policy = policy.TabularPolicy(self.leduc_game)

@parameterized.parameters(["blind action", "informed action", "blind cf", "informed cf", "bps", "cfps", "csps", "tips", "bhv"])
@parameterized.parameters(["blind action", "informed action", "blind cf",
"informed cf","bps", "cfps", "csps",
"tips", "bhv"])
def test_policy_zero_is_uniform(self, deviations_name):
# Checks that a freshly constructed EFRSolver, before any update step, has
# both its current policy and its average policy exactly equal (array
# equality on action_probability_array) to the Leduc tabular policy built in
# setUp.
#
# Args:
#   deviations_name: deviation-set name forwarded to efr.EFRSolver; supplied
#     by the @parameterized.parameters decorator on this test.
#
# NOTE(review): flattened diff view -- the `game=` argument and the first
# argument of each assert appear twice (presumably the old `_LEDUC_*` line
# followed by its snake_case replacement).
# We use Leduc and not Kuhn, because Leduc has illegal actions and Kuhn does
# not.
cfr_solver = efr.EFRSolver(
game=self._LEDUC_GAME,
game=self.leduc_game,
deviations_name=deviations_name
)
# Current policy must match the reference policy exactly.
np.testing.assert_array_equal(
self._LEDUC_UNIFORM_POLICY.action_probability_array,
self.leduc_uniform_policy.action_probability_array,
cfr_solver.current_policy().action_probability_array)
# Average policy must match the reference policy exactly as well.
np.testing.assert_array_equal(
self._LEDUC_UNIFORM_POLICY.action_probability_array,
self.leduc_uniform_policy.action_probability_array,
cfr_solver.average_policy().action_probability_array)

@parameterized.parameters(
["blind cf", "informed cf", "bps", "cfps", "csps", "tips", "bhv"])
def test_efr_kuhn_poker(self, deviations_name):
# Runs EFR on two-player Kuhn poker for 300 iterations and checks that the
# average policy, played by both players, yields values within 1e-3
# (absolute) of [-1/18, 1/18].
#
# Args:
#   deviations_name: deviation-set name forwarded to efr.EFRSolver; supplied
#     by the @parameterized.parameters decorator on this test.
#
# NOTE(review): flattened diff view -- lines mentioning the Kuhn game appear
# twice (presumably old `_KUHN_GAME` line followed by its `kuhn_game`
# replacement).
efr_solver = efr.EFRSolver(
game=self._KUHN_GAME,
game=self.kuhn_game,
deviations_name=deviations_name
)
for _ in range(300):
efr_solver.evaluate_and_update_policy()
average_policy = efr_solver.average_policy()
# Evaluate from the initial state with the same average policy in both seats.
average_policy_values = expected_game_score.policy_value(
self._KUHN_GAME.new_initial_state(), [average_policy] * 2)
self.kuhn_game.new_initial_state(), [average_policy] * 2)
# 1/18 is the Nash value. See https://en.wikipedia.org/wiki/Kuhn_poker
np.testing.assert_allclose(
average_policy_values, [-1 / 18, 1 / 18], atol=1e-3)
Expand All @@ -72,7 +74,7 @@ def test_efr_kuhn_poker(self, deviations_name):
["blind cf", "informed cf", "bps", "cfps", "csps", "tips", "bhv"])
def test_efr_kuhn_poker_3p(self, deviations_name):
efr_solver = efr.EFRSolver(
game=self._KUHN_3P_GAME,
game=self.kuhn_3p_game,
deviations_name=deviations_name
)
strategies = []
Expand All @@ -84,26 +86,25 @@ def test_efr_kuhn_poker_3p(self, deviations_name):
strategies.append(policy.python_policy_to_pyspiel_policy(
efr_solver.current_policy()))
corr_dev = pyspiel.uniform_correlation_device(strategies)
cce_dist_info = pyspiel.cce_dist(self._KUHN_3P_GAME, corr_dev)
cce_dist_info = pyspiel.cce_dist(self.kuhn_3p_game, corr_dev)
corr_dist_values.append(cce_dist_info.dist_value)
self.assertLess(corr_dist_values[-1], corr_dist_values[0])


@parameterized.parameters(
["blind cf", "bps", "tips"])
def test_efr_cce_dist_sheriff(self, deviations_name):
# Runs EFR on the Sheriff game for 5 iterations, recording a CCE distance
# after each one, and asserts that the last recorded distance is smaller than
# the first.
#
# Args:
#   deviations_name: deviation-set name forwarded to efr.EFRSolver; supplied
#     by the @parameterized.parameters decorator on this test.
#
# NOTE(review): flattened diff view -- indentation is lost, so the extent of
# the loop body is inferred (the final assert needs several recorded values,
# so the append presumably happens once per iteration). Some lines appear
# twice: renamed attributes, plus whitespace-only changes such as the closing
# parenthesis and the `efr_solver.current_policy()))` continuation line.
efr_solver = efr.EFRSolver(
game=self._SHERIFF_GAME,
game=self.sheriff_game,
deviations_name=deviations_name
)
)
strategies = []
corr_dist_values = []
for _ in range(5):
efr_solver.evaluate_and_update_policy()
# Keep every iterate's current policy, converted to a pyspiel policy.
strategies.append(policy.python_policy_to_pyspiel_policy(
efr_solver.current_policy()))
efr_solver.current_policy()))
# Uniform correlation device over the strategies collected so far, then the
# CCE distance of that device in the Sheriff game.
corr_dev = pyspiel.uniform_correlation_device(strategies)
cce_dist_info = pyspiel.cce_dist(self._SHERIFF_GAME, corr_dev)
cce_dist_info = pyspiel.cce_dist(self.sheriff_game, corr_dev)
corr_dist_values.append(cce_dist_info.dist_value)
# The distance after the last iteration must be below the first one.
self.assertLess(corr_dist_values[-1], corr_dist_values[0])
if __name__ == "__main__":
Expand Down

0 comments on commit 3654a8d

Please sign in to comment.