Skip to content

Commit

Permalink
SL用に状態から過去の意志決定の履歴を生成する (#991)
Browse files Browse the repository at this point in the history
* init pr

* init method

* extract SetInitState

* extract EventsToActions

* extract UpdateByEvents

* Apply formatter

* implement past_decisions

* add test for past decisions

* Apply formatter

* make static

* implement mjx::State::past_decisions

* add test for mjx::State::past_decisions

* Apply formatter

Co-authored-by: GitHub Actions <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
sotetsuk and github-actions[bot] authored Oct 19, 2021
1 parent 1b890f6 commit 225694d
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 130 deletions.
289 changes: 159 additions & 130 deletions include/mjx/internal/state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
#include <google/protobuf/util/json_util.h>
#include <google/protobuf/util/message_differencer.h>

#include <queue>

#include "mjx/internal/utils.h"

namespace mjx::internal {
Expand Down Expand Up @@ -260,134 +258,9 @@ mjxproto::State State::LoadJson(const std::string &json_str) {
State::State(const std::string &json_str) : State(LoadJson(json_str)) {}

State::State(const mjxproto::State &state) {
// Set player ids
state_.mutable_public_observation()->mutable_player_ids()->CopyFrom(
state.public_observation().player_ids());
// Set scores
state_.mutable_public_observation()->mutable_init_score()->CopyFrom(
state.public_observation().init_score());
curr_score_.CopyFrom(state.public_observation().init_score());
// Set walls
auto wall_tiles = std::vector<Tile>();
for (auto tile_id : state.hidden_state().wall())
wall_tiles.emplace_back(Tile(tile_id));
wall_ = Wall(round(), wall_tiles);
state_.mutable_hidden_state()->mutable_wall()->CopyFrom(
state.hidden_state().wall());
// Set seed
state_.mutable_hidden_state()->set_game_seed(
state.hidden_state().game_seed());
// Set dora
state_.mutable_public_observation()->add_dora_indicators(
wall_.dora_indicators().front().Id());
state_.mutable_hidden_state()->add_ura_dora_indicators(
wall_.ura_dora_indicators().front().Id());
// Set init hands
for (int i = 0; i < 4; ++i) {
players_[i] =
Player{state_.public_observation().player_ids(i), AbsolutePos(i),
Hand(wall_.initial_hand_tiles(AbsolutePos(i)))};
state_.mutable_private_observations()->Add();
state_.mutable_private_observations(i)->set_who(i);
for (auto t : wall_.initial_hand_tiles(AbsolutePos(i))) {
state_.mutable_private_observations(i)
->mutable_init_hand()
->mutable_closed_tiles()
->Add(t.Id());
}
// set game_id
state_.mutable_public_observation()->set_game_id(
state.public_observation().game_id());
}

// Initial draw from dealer
Draw(dealer());

// sync curr_hand
for (int i = 0; i < 4; ++i) SyncCurrHand(AbsolutePos(i));

// Update by events
std::queue<mjxproto::Action> actions;
int last_ron_target = -1;
int last_ron_target_tile = -1;
for (const auto &event : state.public_observation().events()) {
if (Any(event.type(),
{mjxproto::EVENT_TYPE_DISCARD, mjxproto::EVENT_TYPE_TSUMOGIRI,
mjxproto::EVENT_TYPE_ADDED_KAN})) {
last_ron_target = event.who();
last_ron_target_tile = event.tile();
}
if (event.type() == mjxproto::EVENT_TYPE_ABORTIVE_DRAW_THREE_RONS) {
assert(last_ron_target != -1);
assert(last_ron_target_tile != -1);
for (int i = 0; i < 4; ++i) {
if (i == last_ron_target) continue;
mjxproto::Action ron =
Action::CreateRon(AbsolutePos(i), Tile(last_ron_target_tile),
state_.public_observation().game_id());
actions.push(ron);
}
continue;
}
std::optional<mjxproto::Action> action = Action::FromEvent(event);
if (action) actions.push(action.value());
}

while (state.public_observation().events_size() >
state_.public_observation().events_size()) {
auto observations = CreateObservations();
std::unordered_set<PlayerId> is_action_set;
std::vector<mjxproto::Action> action_candidates;

// set action from next_action
while (true) {
if (actions.empty()) break;
mjxproto::Action next_action = actions.front();
bool should_continue = false;
for (const auto &[player_id, obs] : observations) {
if (is_action_set.count(player_id)) continue;
std::vector<mjxproto::Action> legal_actions = obs.legal_actions();
bool has_next_action =
std::count_if(legal_actions.begin(), legal_actions.end(),
[&next_action](const mjxproto::Action &x) {
return Action::Equal(x, next_action);
});
if (has_next_action) {
action_candidates.push_back(next_action);
is_action_set.insert(player_id);
actions.pop();
should_continue = true;
break;
}
}
if (!should_continue) break;
}

// set no actions
for (const auto &[player_id, obs] : observations) {
if (is_action_set.count(player_id)) continue;
std::vector<mjxproto::Action> legal_actions = obs.legal_actions();
auto itr = std::find_if(legal_actions.begin(), legal_actions.end(),
[](const mjxproto::Action &x) {
return x.type() == mjxproto::ACTION_TYPE_NO;
});
Assert(itr != legal_actions.end(),
"Legal actions should have No Action.\nExpected:\n" +
ProtoToJson(state) + "\nActual:\n" + ToJson());
auto action_no = *itr;
action_candidates.push_back(action_no);
}

Assert(action_candidates.size() == observations.size(),
"Expected:\n" + ProtoToJson(state) + "\nActual:\n" + ToJson() +
"action_candidates.size():\n" +
std::to_string(action_candidates.size()) +
"\nobservations.size():\n" +
std::to_string(observations.size()));

Update(std::move(action_candidates));
}

SetInitState(state, *this);
std::queue<mjxproto::Action> actions = EventsToActions(state);
UpdateByActions(state, actions, *this);
Assert(google::protobuf::util::MessageDifferencer::Equals(state, proto()),
"Expected:\n" + ProtoToJson(state) + "\nActual:\n" + ToJson());
}
Expand Down Expand Up @@ -1752,4 +1625,160 @@ std::string State::ProtoToJson(const mjxproto::State &proto) {
Assert(status.ok());
return serialized;
}

std::vector<std::pair<mjxproto::Observation, mjxproto::Action>>
State::GeneratePastDecisions(const mjxproto::State &proto) noexcept {
State st;
SetInitState(proto, st);
std::queue<mjxproto::Action> actions = EventsToActions(proto);
auto decisions = UpdateByActions(proto, actions, st);
Assert(google::protobuf::util::MessageDifferencer::Equals(proto, st.proto()),
"Expected:\n" + ProtoToJson(proto) + "\nActual:\n" + st.ToJson());
return decisions;
}

void State::SetInitState(const mjxproto::State &proto, State &state) {
// Set player ids
state.state_.mutable_public_observation()->mutable_player_ids()->CopyFrom(
proto.public_observation().player_ids());
// Set scores
state.state_.mutable_public_observation()->mutable_init_score()->CopyFrom(
proto.public_observation().init_score());
state.curr_score_.CopyFrom(proto.public_observation().init_score());
// Set walls
auto wall_tiles = std::vector<Tile>();
for (auto tile_id : proto.hidden_state().wall())
wall_tiles.emplace_back(Tile(tile_id));
state.wall_ = Wall(state.round(), wall_tiles);
state.state_.mutable_hidden_state()->mutable_wall()->CopyFrom(
proto.hidden_state().wall());
// Set seed
state.state_.mutable_hidden_state()->set_game_seed(
proto.hidden_state().game_seed());
// Set dora
state.state_.mutable_public_observation()->add_dora_indicators(
state.wall_.dora_indicators().front().Id());
state.state_.mutable_hidden_state()->add_ura_dora_indicators(
state.wall_.ura_dora_indicators().front().Id());
// Set init hands
for (int i = 0; i < 4; ++i) {
state.players_[i] =
Player{state.state_.public_observation().player_ids(i), AbsolutePos(i),
Hand(state.wall_.initial_hand_tiles(AbsolutePos(i)))};
state.state_.mutable_private_observations()->Add();
state.state_.mutable_private_observations(i)->set_who(i);
for (auto t : state.wall_.initial_hand_tiles(AbsolutePos(i))) {
state.state_.mutable_private_observations(i)
->mutable_init_hand()
->mutable_closed_tiles()
->Add(t.Id());
}
// set game_id
state.state_.mutable_public_observation()->set_game_id(
proto.public_observation().game_id());
}

// Initial draw from dealer
state.Draw(state.dealer());

// sync curr_hand
for (int i = 0; i < 4; ++i) state.SyncCurrHand(AbsolutePos(i));
}

std::queue<mjxproto::Action> State::EventsToActions(
const mjxproto::State &proto) {
std::queue<mjxproto::Action> actions;
int last_ron_target = -1;
int last_ron_target_tile = -1;
for (const auto &event : proto.public_observation().events()) {
if (Any(event.type(),
{mjxproto::EVENT_TYPE_DISCARD, mjxproto::EVENT_TYPE_TSUMOGIRI,
mjxproto::EVENT_TYPE_ADDED_KAN})) {
last_ron_target = event.who();
last_ron_target_tile = event.tile();
}
if (event.type() == mjxproto::EVENT_TYPE_ABORTIVE_DRAW_THREE_RONS) {
assert(last_ron_target != -1);
assert(last_ron_target_tile != -1);
for (int i = 0; i < 4; ++i) {
if (i == last_ron_target) continue;
mjxproto::Action ron =
Action::CreateRon(AbsolutePos(i), Tile(last_ron_target_tile),
proto.public_observation().game_id());
actions.push(ron);
}
continue;
}
std::optional<mjxproto::Action> action = Action::FromEvent(event);
if (action) actions.push(action.value());
}
return actions;
}

std::vector<std::pair<mjxproto::Observation, mjxproto::Action>>
State::UpdateByActions(const mjxproto::State &proto,
std::queue<mjxproto::Action> &actions, State &state) {
std::vector<std::pair<mjxproto::Observation, mjxproto::Action>> hist;

while (proto.public_observation().events_size() >
state.state_.public_observation().events_size()) {
auto observations = state.CreateObservations();
std::unordered_map<PlayerId, mjxproto::Action> action_candidates;

// set action from next_action
while (true) {
if (actions.empty()) break;
mjxproto::Action next_action = actions.front();
bool should_continue = false;
for (const auto &[player_id, obs] : observations) {
if (action_candidates.count(player_id)) continue;
std::vector<mjxproto::Action> legal_actions = obs.legal_actions();
bool has_next_action =
std::count_if(legal_actions.begin(), legal_actions.end(),
[&next_action](const mjxproto::Action &x) {
return Action::Equal(x, next_action);
});
if (has_next_action) {
action_candidates[player_id] = next_action;
actions.pop();
should_continue = true;
break;
}
}
if (!should_continue) break;
}

// set no actions
for (const auto &[player_id, obs] : observations) {
if (action_candidates.count(player_id)) continue;
std::vector<mjxproto::Action> legal_actions = obs.legal_actions();
auto itr = std::find_if(legal_actions.begin(), legal_actions.end(),
[](const mjxproto::Action &x) {
return x.type() == mjxproto::ACTION_TYPE_NO;
});
Assert(itr != legal_actions.end(),
"Legal actions should have No Action.\nExpected:\n" +
ProtoToJson(proto) + "\nActual:\n" + state.ToJson());
auto action_no = *itr;
action_candidates[player_id] = action_no;
}

Assert(action_candidates.size() == observations.size(),
"Expected:\n" + ProtoToJson(proto) + "\nActual:\n" + state.ToJson() +
"action_candidates.size():\n" +
std::to_string(action_candidates.size()) +
"\nobservations.size():\n" +
std::to_string(observations.size()));

std::vector<mjxproto::Action> action_vec;
for (const auto &[player_id, obs] : observations) {
auto action = action_candidates[player_id];
hist.emplace_back(obs.proto(), action);
action_vec.push_back(action);
}
state.Update(std::move(action_vec));
}

return hist;
}
} // namespace mjx::internal
14 changes: 14 additions & 0 deletions include/mjx/internal/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <boost/uuid/uuid.hpp>
#include <boost/uuid/uuid_generators.hpp>
#include <boost/uuid/uuid_io.hpp>
#include <queue>
#include <random>
#include <string>
#include <utility>
Expand Down Expand Up @@ -52,6 +53,8 @@ class State {
GameResult result() const;
State::ScoreInfo Next() const;
mjxproto::Observation observation(const PlayerId& player_id) const;
static std::vector<std::pair<mjxproto::Observation, mjxproto::Action>>
GeneratePastDecisions(const mjxproto::State& proto) noexcept;

static std::vector<PlayerId> ShufflePlayerIds(
std::uint64_t game_seed, const std::vector<PlayerId>& player_ids);
Expand Down Expand Up @@ -163,6 +166,17 @@ class State {

// protoのcurr_handを同期する。
void SyncCurrHand(AbsolutePos who);

// protobufから初期状態(親のツモの直後)を抽出して、stateへセットする
static void SetInitState(const mjxproto::State& proto, State& state);
// protoのEvent系列で見えているイベントをAction系列へ変換して返す(Noは含まない。三家和了はロンが3つ連なる)
static std::queue<mjxproto::Action> EventsToActions(
const mjxproto::State& proto);
// stateがprotoと同じにものになるように、actionsからactionをpopしながらstateを更新する(actionsにはNoがないので、それらを補完する)
// 結果として現れたObservation, Actionのペアが返される
static std::vector<std::pair<mjxproto::Observation, mjxproto::Action>>
UpdateByActions(const mjxproto::State& proto,
std::queue<mjxproto::Action>& actions, State& state);
};
} // namespace mjx::internal

Expand Down
12 changes: 12 additions & 0 deletions include/mjx/state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

#include <utility>

#include "mjx/internal/state.h"

namespace mjx {
mjx::State::State(mjxproto::State proto) : proto_(std::move(proto)) {}

Expand All @@ -31,4 +33,14 @@ bool State::operator==(const State& other) const noexcept {
bool State::operator!=(const State& other) const noexcept {
return !(*this == other);
}

std::vector<std::pair<Observation, Action>> State::past_decisions()
const noexcept {
std::vector<std::pair<Observation, Action>> decisions;
auto proto_decisions = internal::State::GeneratePastDecisions(proto());
for (const auto& [obs, action] : proto_decisions) {
decisions.emplace_back(Observation(obs), Action(action));
}
return decisions;
}
} // namespace mjx
3 changes: 3 additions & 0 deletions include/mjx/state.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#ifndef MJX_PROJECT_STATE_H
#define MJX_PROJECT_STATE_H

#include "mjx/action.h"
#include "mjx/internal/mjx.grpc.pb.h"
#include "mjx/observation.h"

namespace mjx {
using PlayerId = std::string; // identical over different games
Expand All @@ -19,6 +21,7 @@ class State {

// accessors
const mjxproto::State& proto() const noexcept;
std::vector<std::pair<Observation, Action>> past_decisions() const noexcept;

private:
mjxproto::State proto_{};
Expand Down
Loading

0 comments on commit 225694d

Please sign in to comment.