Skip to content

Commit

Permalink
Add helper method to count repetitions of a specific state.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 668444060
Change-Id: I3acc21ff109c4b7ee58f43ed36795cb29c3762f1
  • Loading branch information
lanctot committed Sep 23, 2024
1 parent 77a03df commit 9435c48
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 0 deletions.
11 changes: 11 additions & 0 deletions open_spiel/games/chess/chess.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "open_spiel/games/chess/chess.h"
#include <sys/types.h>

#include <cstdint>
#include <iterator>
Expand Down Expand Up @@ -474,6 +475,16 @@ bool ChessState::IsRepetitionDraw() const {
return entry->second >= kNumRepetitionsToDraw;
}

int ChessState::NumRepetitions(const ChessState& state) const {
uint64_t state_hash_value = state.Board().HashValue();
const auto entry = repetitions_.find(state_hash_value);
if (entry == repetitions_.end()) {
return 0;
} else {
return entry->second;
}
}

absl::optional<std::vector<double>> ChessState::MaybeFinalReturns() const {
if (!Board().HasSufficientMaterial()) {
return std::vector<double>{DrawUtility(), DrawUtility()};
Expand Down
4 changes: 4 additions & 0 deletions open_spiel/games/chess/chess.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ class ChessState : public State {
// board position has already appeared twice in the history).
bool IsRepetitionDraw() const;

// Returns the number of times the specified state has appeared in the
// history.
int NumRepetitions(const ChessState& state) const;

const ChessGame* ParentGame() const {
return down_cast<const ChessGame*>(GetGame().get());
}
Expand Down
2 changes: 2 additions & 0 deletions open_spiel/python/pybind11/games_chess.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ void open_spiel::init_pyspiel_games_chess(py::module& m) {
.def("debug_string", &ChessState::DebugString)
.def("is_repetition_draw", &ChessState::IsRepetitionDraw)
.def("moves_history", py::overload_cast<>(&ChessState::MovesHistory))
// num_repetitions(state: ChessState) -> int
.def("num_repetitions", &ChessState::NumRepetitions)
.def("parse_move_to_action", &ChessState::ParseMoveToAction)
// Pickle support
.def(py::pickle(
Expand Down
1 change: 1 addition & 0 deletions open_spiel/python/tests/games_chess_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def test_state_from_fen(self):
fen_string = "8/k1P5/8/1K6/8/8/8/8 w - - 0 1"
state = game.new_initial_state(fen_string)
self.assertEqual(state.board().to_fen(), fen_string)
self.assertEqual(state.num_repetitions(state), 1)

@parameterized.parameters(
"bbqnnrkr/pppppppp/8/8/8/8/PPPPPPPP/BBQNNRKR w KQkq - 0 1",
Expand Down

0 comments on commit 9435c48

Please sign in to comment.