From 3654a8dca74c86821fb9537a21f731c455a0230e Mon Sep 17 00:00:00 2001
From: jameswflynn
Date: Wed, 15 Nov 2023 01:20:59 +0000
Subject: [PATCH] Linted efr_test

---
Review notes (this section is ignored by `git am`):
- The line structure of this patch was restored from a whitespace-collapsed
  copy. Indentation is reconstructed assuming the file's 2-space style; the
  two whitespace-only changes in test_efr_cce_dist_sheriff (closing paren and
  continuation line) are approximate -- confirm against the original commit
  before applying.
- Added the missing space after the comma in `"informed cf", "bps"` so that
  this lint change does not itself introduce a whitespace warning. The
  post-image blob id on the `index` line may therefore no longer match.

 open_spiel/python/algorithms/efr_test.py | 39 ++++++++++++------------
 1 file changed, 20 insertions(+), 19 deletions(-)

diff --git a/open_spiel/python/algorithms/efr_test.py b/open_spiel/python/algorithms/efr_test.py
index 2f2cbb4296..9ef99bd455 100644
--- a/open_spiel/python/algorithms/efr_test.py
+++ b/open_spiel/python/algorithms/efr_test.py
@@ -29,41 +29,43 @@
 class EFRTest(parameterized.TestCase, absltest.TestCase):
 
   def setUp(self):
-    self._KUHN_GAME = pyspiel.load_game("kuhn_poker")
-    self._LEDUC_GAME = pyspiel.load_game("leduc_poker")
-    self._KUHN_3P_GAME = pyspiel.load_game("kuhn_poker(players=3)")
-    self._SHERIFF_GAME = pyspiel.load_game("sheriff")
+    self.kuhn_game = pyspiel.load_game("kuhn_poker")
+    self.leduc_game = pyspiel.load_game("leduc_poker")
+    self.kuhn_3p_game = pyspiel.load_game("kuhn_poker(players=3)")
+    self.sheriff_game = pyspiel.load_game("sheriff")
 
-    self._KUHN_UNIFORM_POLICY = policy.TabularPolicy(self._KUHN_GAME)
-    self._LEDUC_UNIFORM_POLICY = policy.TabularPolicy(self._LEDUC_GAME)
+    self.kuhn_uniform_policy = policy.TabularPolicy(self.kuhn_game)
+    self.leduc_uniform_policy = policy.TabularPolicy(self.leduc_game)
 
-  @parameterized.parameters(["blind action", "informed action", "blind cf", "informed cf", "bps", "cfps", "csps", "tips", "bhv"])
+  @parameterized.parameters(["blind action", "informed action", "blind cf",
+                             "informed cf", "bps", "cfps", "csps",
+                             "tips", "bhv"])
   def test_policy_zero_is_uniform(self, deviations_name):
     # We use Leduc and not Kuhn, because Leduc has illegal actions and Kuhn does
     # not.
     cfr_solver = efr.EFRSolver(
-        game=self._LEDUC_GAME,
+        game=self.leduc_game,
         deviations_name=deviations_name
     )
     np.testing.assert_array_equal(
-        self._LEDUC_UNIFORM_POLICY.action_probability_array,
+        self.leduc_uniform_policy.action_probability_array,
         cfr_solver.current_policy().action_probability_array)
     np.testing.assert_array_equal(
-        self._LEDUC_UNIFORM_POLICY.action_probability_array,
+        self.leduc_uniform_policy.action_probability_array,
         cfr_solver.average_policy().action_probability_array)
 
   @parameterized.parameters(
       ["blind cf", "informed cf", "bps", "cfps", "csps", "tips", "bhv"])
   def test_efr_kuhn_poker(self, deviations_name):
     efr_solver = efr.EFRSolver(
-        game=self._KUHN_GAME,
+        game=self.kuhn_game,
         deviations_name=deviations_name
     )
     for _ in range(300):
       efr_solver.evaluate_and_update_policy()
     average_policy = efr_solver.average_policy()
     average_policy_values = expected_game_score.policy_value(
-        self._KUHN_GAME.new_initial_state(), [average_policy] * 2)
+        self.kuhn_game.new_initial_state(), [average_policy] * 2)
     # 1/18 is the Nash value. See https://en.wikipedia.org/wiki/Kuhn_poker
     np.testing.assert_allclose(
         average_policy_values, [-1 / 18, 1 / 18], atol=1e-3)
@@ -72,7 +74,7 @@ def test_efr_kuhn_poker(self, deviations_name):
       ["blind cf", "informed cf", "bps", "cfps", "csps", "tips", "bhv"])
   def test_efr_kuhn_poker_3p(self, deviations_name):
     efr_solver = efr.EFRSolver(
-        game=self._KUHN_3P_GAME,
+        game=self.kuhn_3p_game,
         deviations_name=deviations_name
     )
     strategies = []
@@ -84,26 +86,25 @@ def test_efr_kuhn_poker_3p(self, deviations_name):
       strategies.append(policy.python_policy_to_pyspiel_policy(
           efr_solver.current_policy()))
       corr_dev = pyspiel.uniform_correlation_device(strategies)
-      cce_dist_info = pyspiel.cce_dist(self._KUHN_3P_GAME, corr_dev)
+      cce_dist_info = pyspiel.cce_dist(self.kuhn_3p_game, corr_dev)
       corr_dist_values.append(cce_dist_info.dist_value)
     self.assertLess(corr_dist_values[-1], corr_dist_values[0])
 
-
   @parameterized.parameters(
       ["blind cf", "bps", "tips"])
   def test_efr_cce_dist_sheriff(self, deviations_name):
     efr_solver = efr.EFRSolver(
-        game=self._SHERIFF_GAME,
+        game=self.sheriff_game,
         deviations_name=deviations_name
-        )
+    )
     strategies = []
     corr_dist_values = []
     for _ in range(5):
       efr_solver.evaluate_and_update_policy()
       strategies.append(policy.python_policy_to_pyspiel_policy(
-            efr_solver.current_policy()))
+          efr_solver.current_policy()))
       corr_dev = pyspiel.uniform_correlation_device(strategies)
-      cce_dist_info = pyspiel.cce_dist(self._SHERIFF_GAME, corr_dev)
+      cce_dist_info = pyspiel.cce_dist(self.sheriff_game, corr_dev)
       corr_dist_values.append(cce_dist_info.dist_value)
     self.assertLess(corr_dist_values[-1], corr_dist_values[0])
 if __name__ == "__main__":