// Review header. Everything from the `diff --git` line onward is the patch itself.
//
// What the patch does
//   include/mjx/internal/state.cpp, include/mjx/internal/state.h
//     - State::State(const mjxproto::State&) no longer inlines the replay logic. It calls
//       SetInitState, EventsToActions and UpdateByActions, then asserts via
//       MessageDifferencer::Equals that the rebuilt proto equals the input.
//     - SetInitState copies player ids, initial scores, wall and game seed from the proto,
//       adds the first dora / ura-dora indicators, builds the four players and their initial
//       hands, sets game_id, performs the dealer's initial draw and syncs every curr_hand.
//     - EventsToActions walks the public events and queues the action returned by
//       Action::FromEvent whenever it yields one. For EVENT_TYPE_ABORTIVE_DRAW_THREE_RONS it
//       instead queues a Ron for each of the three seats other than the seat of the most
//       recent discard / tsumogiri / added-kan event.
//     - UpdateByActions loops while the proto has more events than the state: it creates
//       observations, hands the queue's front action to a player whose legal actions contain
//       it, fills in ACTION_TYPE_NO for every remaining player, records each
//       (observation, action) pair, and calls Update.
//     - New static State::GeneratePastDecisions runs the same three steps on a fresh State
//       and returns the recorded pairs.
//   include/mjx/state.cpp, include/mjx/state.h
//     - New mjx::State::past_decisions() wraps the pairs from GeneratePastDecisions in
//       Observation / Action objects.
//   tests_cpp
//     - internal_state.GeneratePastDecisions expects exactly 3 ACTION_TYPE_RON decisions for
//       "upd-aft-ron3.json"; state.past_decisions expects 87 pairs for sample_json.
//
// NOTE(review): this copy of the patch appears to have lost its line breaks and every
//   `<...>` span (include names, template arguments, e.g. `std::vector()` and
//   `std::queue actions`). Presumably it must be regenerated from the commit before it can
//   be applied; confirm against the original.
// NOTE(review): GeneratePastDecisions and past_decisions are declared noexcept yet call
//   Assert and build containers. Confirm that terminating on failure is intended.
// NOTE(review): has_next_action stores a std::count_if result in a bool; std::any_of would
//   state the intent directly.
// NOTE(review): set_game_id is called inside the per-player loop of SetInitState with the
//   same value on every iteration (carried over from the removed constructor body).
// NOTE(review): EventsToActions uses lowercase assert while the surrounding code uses Assert.
// NOTE(review): the commented-out debug loop in the new internal_state test iterates
//   `GeneratePastDecisions` rather than the local `past_decisions`.
// Comments on added lines are in English. The comment on the unchanged context line above
//   SyncCurrHand in state.h is left untouched so the hunk still matches its target file; it
//   reads "synchronizes proto's curr_hand".
diff --git a/include/mjx/internal/state.cpp b/include/mjx/internal/state.cpp index 498ae9cde..9a5b7d365 100644 --- a/include/mjx/internal/state.cpp +++ b/include/mjx/internal/state.cpp @@ -3,8 +3,6 @@ #include #include -#include - #include "mjx/internal/utils.h" namespace mjx::internal { @@ -260,134 +258,9 @@ mjxproto::State State::LoadJson(const std::string &json_str) { State::State(const std::string &json_str) : State(LoadJson(json_str)) {} State::State(const mjxproto::State &state) { - // Set player ids - state_.mutable_public_observation()->mutable_player_ids()->CopyFrom( - state.public_observation().player_ids()); - // Set scores - state_.mutable_public_observation()->mutable_init_score()->CopyFrom( - state.public_observation().init_score()); - curr_score_.CopyFrom(state.public_observation().init_score()); - // Set walls - auto wall_tiles = std::vector(); - for (auto tile_id : state.hidden_state().wall()) - wall_tiles.emplace_back(Tile(tile_id)); - wall_ = Wall(round(), wall_tiles); - state_.mutable_hidden_state()->mutable_wall()->CopyFrom( - state.hidden_state().wall()); - // Set seed - state_.mutable_hidden_state()->set_game_seed( - state.hidden_state().game_seed()); - // Set dora - state_.mutable_public_observation()->add_dora_indicators( - wall_.dora_indicators().front().Id()); - state_.mutable_hidden_state()->add_ura_dora_indicators( - wall_.ura_dora_indicators().front().Id()); - // Set init hands - for (int i = 0; i < 4; ++i) { - players_[i] = - Player{state_.public_observation().player_ids(i), AbsolutePos(i), - Hand(wall_.initial_hand_tiles(AbsolutePos(i)))}; - state_.mutable_private_observations()->Add(); - state_.mutable_private_observations(i)->set_who(i); - for (auto t : wall_.initial_hand_tiles(AbsolutePos(i))) { - state_.mutable_private_observations(i) - ->mutable_init_hand() - ->mutable_closed_tiles() - ->Add(t.Id()); - } - // set game_id - state_.mutable_public_observation()->set_game_id( - state.public_observation().game_id()); - } - - // 
Initial draw from dealer - Draw(dealer()); - - // sync curr_hand - for (int i = 0; i < 4; ++i) SyncCurrHand(AbsolutePos(i)); - - // Update by events - std::queue actions; - int last_ron_target = -1; - int last_ron_target_tile = -1; - for (const auto &event : state.public_observation().events()) { - if (Any(event.type(), - {mjxproto::EVENT_TYPE_DISCARD, mjxproto::EVENT_TYPE_TSUMOGIRI, - mjxproto::EVENT_TYPE_ADDED_KAN})) { - last_ron_target = event.who(); - last_ron_target_tile = event.tile(); - } - if (event.type() == mjxproto::EVENT_TYPE_ABORTIVE_DRAW_THREE_RONS) { - assert(last_ron_target != -1); - assert(last_ron_target_tile != -1); - for (int i = 0; i < 4; ++i) { - if (i == last_ron_target) continue; - mjxproto::Action ron = - Action::CreateRon(AbsolutePos(i), Tile(last_ron_target_tile), - state_.public_observation().game_id()); - actions.push(ron); - } - continue; - } - std::optional action = Action::FromEvent(event); - if (action) actions.push(action.value()); - } - - while (state.public_observation().events_size() > - state_.public_observation().events_size()) { - auto observations = CreateObservations(); - std::unordered_set is_action_set; - std::vector action_candidates; - - // set action from next_action - while (true) { - if (actions.empty()) break; - mjxproto::Action next_action = actions.front(); - bool should_continue = false; - for (const auto &[player_id, obs] : observations) { - if (is_action_set.count(player_id)) continue; - std::vector legal_actions = obs.legal_actions(); - bool has_next_action = - std::count_if(legal_actions.begin(), legal_actions.end(), - [&next_action](const mjxproto::Action &x) { - return Action::Equal(x, next_action); - }); - if (has_next_action) { - action_candidates.push_back(next_action); - is_action_set.insert(player_id); - actions.pop(); - should_continue = true; - break; - } - } - if (!should_continue) break; - } - - // set no actions - for (const auto &[player_id, obs] : observations) { - if 
(is_action_set.count(player_id)) continue; - std::vector legal_actions = obs.legal_actions(); - auto itr = std::find_if(legal_actions.begin(), legal_actions.end(), - [](const mjxproto::Action &x) { - return x.type() == mjxproto::ACTION_TYPE_NO; - }); - Assert(itr != legal_actions.end(), - "Legal actions should have No Action.\nExpected:\n" + - ProtoToJson(state) + "\nActual:\n" + ToJson()); - auto action_no = *itr; - action_candidates.push_back(action_no); - } - - Assert(action_candidates.size() == observations.size(), - "Expected:\n" + ProtoToJson(state) + "\nActual:\n" + ToJson() + - "action_candidates.size():\n" + - std::to_string(action_candidates.size()) + - "\nobservations.size():\n" + - std::to_string(observations.size())); - - Update(std::move(action_candidates)); - } - + SetInitState(state, *this); + std::queue actions = EventsToActions(state); + UpdateByActions(state, actions, *this); Assert(google::protobuf::util::MessageDifferencer::Equals(state, proto()), "Expected:\n" + ProtoToJson(state) + "\nActual:\n" + ToJson()); } @@ -1752,4 +1625,160 @@ std::string State::ProtoToJson(const mjxproto::State &proto) { Assert(status.ok()); return serialized; } + +std::vector> +State::GeneratePastDecisions(const mjxproto::State &proto) noexcept { + State st; + SetInitState(proto, st); + std::queue actions = EventsToActions(proto); + auto decisions = UpdateByActions(proto, actions, st); + Assert(google::protobuf::util::MessageDifferencer::Equals(proto, st.proto()), + "Expected:\n" + ProtoToJson(proto) + "\nActual:\n" + st.ToJson()); + return decisions; +} + +void State::SetInitState(const mjxproto::State &proto, State &state) { + // Set player ids + state.state_.mutable_public_observation()->mutable_player_ids()->CopyFrom( + proto.public_observation().player_ids()); + // Set scores + state.state_.mutable_public_observation()->mutable_init_score()->CopyFrom( + proto.public_observation().init_score()); + 
state.curr_score_.CopyFrom(proto.public_observation().init_score()); + // Set walls + auto wall_tiles = std::vector(); + for (auto tile_id : proto.hidden_state().wall()) + wall_tiles.emplace_back(Tile(tile_id)); + state.wall_ = Wall(state.round(), wall_tiles); + state.state_.mutable_hidden_state()->mutable_wall()->CopyFrom( + proto.hidden_state().wall()); + // Set seed + state.state_.mutable_hidden_state()->set_game_seed( + proto.hidden_state().game_seed()); + // Set dora + state.state_.mutable_public_observation()->add_dora_indicators( + state.wall_.dora_indicators().front().Id()); + state.state_.mutable_hidden_state()->add_ura_dora_indicators( + state.wall_.ura_dora_indicators().front().Id()); + // Set init hands + for (int i = 0; i < 4; ++i) { + state.players_[i] = + Player{state.state_.public_observation().player_ids(i), AbsolutePos(i), + Hand(state.wall_.initial_hand_tiles(AbsolutePos(i)))}; + state.state_.mutable_private_observations()->Add(); + state.state_.mutable_private_observations(i)->set_who(i); + for (auto t : state.wall_.initial_hand_tiles(AbsolutePos(i))) { + state.state_.mutable_private_observations(i) + ->mutable_init_hand() + ->mutable_closed_tiles() + ->Add(t.Id()); + } + // set game_id + state.state_.mutable_public_observation()->set_game_id( + proto.public_observation().game_id()); + } + + // Initial draw from dealer + state.Draw(state.dealer()); + + // sync curr_hand + for (int i = 0; i < 4; ++i) state.SyncCurrHand(AbsolutePos(i)); +} + +std::queue State::EventsToActions( + const mjxproto::State &proto) { + std::queue actions; + int last_ron_target = -1; + int last_ron_target_tile = -1; + for (const auto &event : proto.public_observation().events()) { + if (Any(event.type(), + {mjxproto::EVENT_TYPE_DISCARD, mjxproto::EVENT_TYPE_TSUMOGIRI, + mjxproto::EVENT_TYPE_ADDED_KAN})) { + last_ron_target = event.who(); + last_ron_target_tile = event.tile(); + } + if (event.type() == mjxproto::EVENT_TYPE_ABORTIVE_DRAW_THREE_RONS) { + 
assert(last_ron_target != -1); + assert(last_ron_target_tile != -1); + for (int i = 0; i < 4; ++i) { + if (i == last_ron_target) continue; + mjxproto::Action ron = + Action::CreateRon(AbsolutePos(i), Tile(last_ron_target_tile), + proto.public_observation().game_id()); + actions.push(ron); + } + continue; + } + std::optional action = Action::FromEvent(event); + if (action) actions.push(action.value()); + } + return actions; +} + +std::vector> +State::UpdateByActions(const mjxproto::State &proto, + std::queue &actions, State &state) { + std::vector> hist; + + while (proto.public_observation().events_size() > + state.state_.public_observation().events_size()) { + auto observations = state.CreateObservations(); + std::unordered_map action_candidates; + + // set action from next_action + while (true) { + if (actions.empty()) break; + mjxproto::Action next_action = actions.front(); + bool should_continue = false; + for (const auto &[player_id, obs] : observations) { + if (action_candidates.count(player_id)) continue; + std::vector legal_actions = obs.legal_actions(); + bool has_next_action = + std::count_if(legal_actions.begin(), legal_actions.end(), + [&next_action](const mjxproto::Action &x) { + return Action::Equal(x, next_action); + }); + if (has_next_action) { + action_candidates[player_id] = next_action; + actions.pop(); + should_continue = true; + break; + } + } + if (!should_continue) break; + } + + // set no actions + for (const auto &[player_id, obs] : observations) { + if (action_candidates.count(player_id)) continue; + std::vector legal_actions = obs.legal_actions(); + auto itr = std::find_if(legal_actions.begin(), legal_actions.end(), + [](const mjxproto::Action &x) { + return x.type() == mjxproto::ACTION_TYPE_NO; + }); + Assert(itr != legal_actions.end(), + "Legal actions should have No Action.\nExpected:\n" + + ProtoToJson(proto) + "\nActual:\n" + state.ToJson()); + auto action_no = *itr; + action_candidates[player_id] = action_no; + } + + 
Assert(action_candidates.size() == observations.size(), + "Expected:\n" + ProtoToJson(proto) + "\nActual:\n" + state.ToJson() + + "action_candidates.size():\n" + + std::to_string(action_candidates.size()) + + "\nobservations.size():\n" + + std::to_string(observations.size())); + + std::vector action_vec; + for (const auto &[player_id, obs] : observations) { + auto action = action_candidates[player_id]; + hist.emplace_back(obs.proto(), action); + action_vec.push_back(action); + } + state.Update(std::move(action_vec)); + } + + return hist; +} } // namespace mjx::internal diff --git a/include/mjx/internal/state.h b/include/mjx/internal/state.h index 2b41fd84d..d7b20f776 100644 --- a/include/mjx/internal/state.h +++ b/include/mjx/internal/state.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -52,6 +53,8 @@ class State { GameResult result() const; State::ScoreInfo Next() const; mjxproto::Observation observation(const PlayerId& player_id) const; + static std::vector> + GeneratePastDecisions(const mjxproto::State& proto) noexcept; static std::vector ShufflePlayerIds( std::uint64_t game_seed, const std::vector& player_ids); @@ -163,6 +166,17 @@ class State { // protoのcurr_handを同期する。 void SyncCurrHand(AbsolutePos who); + + // Extracts the initial state (right after the dealer's first draw) from the protobuf and sets it on state + static void SetInitState(const mjxproto::State& proto, State& state); + // Converts the events visible in proto's event sequence into an action sequence and returns it (No actions are not included; a triple ron appears as three consecutive Ron actions) + static std::queue EventsToActions( + const mjxproto::State& proto); + // Updates state by popping actions from actions until state matches proto (actions contains no No actions, so they are filled in here) + // Returns the resulting (Observation, Action) pairs + static std::vector> + UpdateByActions(const mjxproto::State& proto, + std::queue& actions, State& state); }; 
} // namespace mjx::internal diff --git a/include/mjx/state.cpp b/include/mjx/state.cpp index 15cfcc3fb..5c64d5ec3 100644 --- a/include/mjx/state.cpp +++ b/include/mjx/state.cpp @@ -5,6 +5,8 @@ #include +#include 
"mjx/internal/state.h" + namespace mjx { mjx::State::State(mjxproto::State proto) : proto_(std::move(proto)) {} @@ -31,4 +33,14 @@ bool State::operator==(const State& other) const noexcept { bool State::operator!=(const State& other) const noexcept { return !(*this == other); } + +std::vector> State::past_decisions() + const noexcept { + std::vector> decisions; + auto proto_decisions = internal::State::GeneratePastDecisions(proto()); + for (const auto& [obs, action] : proto_decisions) { + decisions.emplace_back(Observation(obs), Action(action)); + } + return decisions; +} } // namespace mjx diff --git a/include/mjx/state.h b/include/mjx/state.h index 04427bd8e..77e1b11b7 100644 --- a/include/mjx/state.h +++ b/include/mjx/state.h @@ -1,7 +1,9 @@ #ifndef MJX_PROJECT_STATE_H #define MJX_PROJECT_STATE_H +#include "mjx/action.h" #include "mjx/internal/mjx.grpc.pb.h" +#include "mjx/observation.h" namespace mjx { using PlayerId = std::string; // identical over different games @@ -19,6 +21,7 @@ class State { // accessors const mjxproto::State& proto() const noexcept; + std::vector> past_decisions() const noexcept; private: mjxproto::State proto_{}; diff --git a/tests_cpp/internal_state_test.cpp b/tests_cpp/internal_state_test.cpp index e35a969d8..1b19667e3 100644 --- a/tests_cpp/internal_state_test.cpp +++ b/tests_cpp/internal_state_test.cpp @@ -1145,3 +1145,19 @@ TEST(internal_state, GameId) { EXPECT_EQ(state1.proto().public_observation().game_id(), state3.proto().public_observation().game_id()); } + +TEST(internal_state, GeneratePastDecisions) { + auto json = GetLastJsonLine("upd-aft-ron3.json"); + State state(json); + auto past_decisions = state.GeneratePastDecisions(state.proto()); + // for (const auto& [obs, action]: GeneratePastDecisions) { + // std::cerr << Observation(obs).ToJson() << "\t" << + // Action::ProtoToJson(action) << std::endl; + // } + EXPECT_EQ(std::count_if(past_decisions.begin(), past_decisions.end(), + [](const auto &x) { + mjxproto::Action action = 
x.second; + return action.type() == mjxproto::ACTION_TYPE_RON; + }), + 3); +} diff --git a/tests_cpp/state_test.cpp b/tests_cpp/state_test.cpp index 9bc5bf7ec..6cb1e1b27 100644 --- a/tests_cpp/state_test.cpp +++ b/tests_cpp/state_test.cpp @@ -30,3 +30,10 @@ TEST(state, op) { EXPECT_EQ(state1, state2); EXPECT_NE(state1, mjx::State()); } + +TEST(state, past_decisions) { + auto state = mjx::State(sample_json); + auto past_decisions = state.past_decisions(); + EXPECT_EQ(past_decisions.size(), + 87); // TODO: whether 87 is the correct value has not been verified +}