From 2cb017013d4509638698e39f6f8d8424be7d30d3 Mon Sep 17 00:00:00 2001 From: Ayah-Saleh <116742207+Ayah-Saleh@users.noreply.github.com> Date: Thu, 26 Sep 2024 14:37:15 -0400 Subject: [PATCH] Fix #1254: Convert pyspiel game state to dict --- open_spiel/spiel.h | 11 +++++++++++ open_spiel/tests/spiel_test.cc | 24 ++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/open_spiel/spiel.h b/open_spiel/spiel.h index cdfe2edd2f..4fd8371861 100644 --- a/open_spiel/spiel.h +++ b/open_spiel/spiel.h @@ -211,6 +211,17 @@ class State { public: virtual ~State() = default; + // convert from current state to a dictionary of array-likes + virtual std::unordered_map> ToDict() const { + SpielFatalError("ToDict is not implemented for this game state."); + } + + // method to restore the state from a dictionary of array-likes. + virtual void FromDict(const std::unordered_map>& dict) { + SpielFatalError("FromDict is not implemented for this game state."); + } + + // Derived classes must call one of these constructors. Note that a state must // be passed a pointer to the game which created it. Some methods in some // games rely on this and so it must correspond to a valid game object. diff --git a/open_spiel/tests/spiel_test.cc b/open_spiel/tests/spiel_test.cc index 45516fc0e5..aabe981fac 100644 --- a/open_spiel/tests/spiel_test.cc +++ b/open_spiel/tests/spiel_test.cc @@ -339,6 +339,28 @@ void PolicySerializationTest() { } // namespace testing } // namespace open_spiel +void TestStateToDictAndFromDict() { + // Load Tic-Tac-Toe + std::shared_ptr game = LoadGame("tic_tac_toe"); + std::unique_ptr state = game->NewInitialState(); + + // apply some moves to change the state + state->ApplyAction(0); // Player 1 places an 'X' in the top-left corner + state->ApplyAction(4); // Player 2 places an 'O' in the center + + // convert the state to a dictionary using ToDict() + std::unordered_map> state_dict = state->ToDict(); + + // create a new initial state and restore it using FromDict() + std::unique_ptr new_state = game->NewInitialState(); + new_state->FromDict(state_dict); + + // check that the original state and the restored state are equivalent + SPIEL_CHECK_EQ(state->ToString(), new_state->ToString()); + +} + + int main(int argc, char** argv) { open_spiel::testing::GeneralTests(); open_spiel::testing::KuhnTests(); @@ -349,4 +371,6 @@ int main(int argc, char** argv) { open_spiel::testing::LeducPokerDeserializeTest(); open_spiel::testing::GameParametersTest(); open_spiel::testing::PolicySerializationTest(); + // new test function + open_spiel::testing::TestStateToDictAndFromDict(); }