Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update han22-v0 features #1123

Merged
merged 6 commits into from
Aug 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 216 additions & 0 deletions include/mjx/internal/observation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,222 @@ std::vector<std::vector<int>> Observation::ToFeaturesSmallV0() const {
return feature;
}

std::vector<std::vector<int>> Observation::ToFeaturesHan22V0() const {
const int num_row = 93;
const int num_tile_type = 34;
std::vector<std::vector<int>> feature(num_row,
std::vector<int>(num_tile_type));

const int obs_who = proto_.who();

// closed hand
{
for (auto t : proto_.private_observation().curr_hand().closed_tiles()) {
int tile_type = Tile(t).TypeUint();
for (int i = 0; i < 4; i++) {
if (feature[i][tile_type] == 0) {
feature[i][tile_type] = 1;
break;
}
}

if (Tile(t).IsRedFive()) {
feature[5][tile_type] = 1;
}
}
}

// events
{
for (int event_index = 0;
event_index < proto_.public_observation().events().size();
event_index++) {
const auto &event = proto_.public_observation().events()[event_index];

bool event_is_action = false;

// opens
if (event.type() == mjxproto::EVENT_TYPE_ADDED_KAN ||
event.type() == mjxproto::EVENT_TYPE_CHI ||
event.type() == mjxproto::EVENT_TYPE_CLOSED_KAN ||
event.type() == mjxproto::EVENT_TYPE_OPEN_KAN ||
event.type() == mjxproto::EVENT_TYPE_PON) {
event_is_action = true;
const int opens_offset = 6 + 6 * ((event.who() - obs_who + 4) % 4);

for (auto t : Open(event.open()).Tiles()) {
int tile_type = Tile(t).TypeUint();
for (int i = 0; i < 4; i++) {
if (feature[opens_offset + i][tile_type] == 0) {
feature[opens_offset + i][tile_type] = 1;
break;
}
}

if (Tile(t).IsRedFive()) {
feature[opens_offset + 5][tile_type] = 1;
}
}

if (Open(event.open()).Type() != OpenType::kKanClosed) {
int stolen_tile_type = Open(event.open()).StolenTile().TypeUint();
feature[opens_offset + 4][stolen_tile_type]++;
}
}

// discards
else if (event.type() == mjxproto::EVENT_TYPE_DISCARD ||
event.type() == mjxproto::EVENT_TYPE_TSUMOGIRI) {
event_is_action = true;
const int discard_offset = 30 + 10 * ((event.who() - obs_who + 4) % 4);
int tile_type = Tile(event.tile()).TypeUint();

for (int i = 0; i < 4; i++) {
if (feature[discard_offset + i][tile_type] == 0) {
feature[discard_offset + i][tile_type] = 1;
if (event.type() == mjxproto::EVENT_TYPE_DISCARD) {
feature[discard_offset + i + 4][tile_type] = 1;
}
break;
}
}

if (Tile(event.tile()).IsRedFive()) {
feature[discard_offset + 8][tile_type] = 1;
}

if (event_index > 0) {
const auto &event_before =
proto_.public_observation().events()[event_index - 1];
if (event_before.type() == mjxproto::EVENT_TYPE_RIICHI) {
feature[discard_offset + 9][tile_type] = 1;
}
}

if (event_index == proto_.public_observation().events().size() - 1 &&
event_is_action) {
int latest_event_offset = 80;
feature[latest_event_offset][tile_type] = 1;
}
}
}
}

// dora
{
const int dora_offset = 70;
int dora_index = 0;
for (auto dora_indicator : proto_.public_observation().dora_indicators()) {
int dora_indicator_tile_type = dora_indicator / 4;
int dora_tile_type = -1;
if (dora_indicator_tile_type < 9) {
dora_tile_type = (dora_indicator_tile_type + 1) % 9;
} else if (dora_indicator_tile_type < 18) {
dora_tile_type = (dora_indicator_tile_type - 9 + 1) % 9 + 9;
} else if (dora_indicator_tile_type < 27) {
dora_tile_type = (dora_indicator_tile_type - 18 + 1) % 9 + 18;
} else if (dora_indicator_tile_type < 31) {
dora_tile_type = (dora_indicator_tile_type - 27 + 1) % 4 + 27;
} else {
dora_tile_type = (dora_indicator_tile_type - 31 + 1) % 3 + 31;
}

feature[dora_offset + dora_index][dora_indicator_tile_type] = 1;
feature[dora_offset + dora_index + 4][dora_tile_type] = 1;

dora_index++;
if (dora_index > 3) break;
}
}

// wind
{
const int wind_offset = 78;
feature[wind_offset][27 + proto_.public_observation().init_score().round() /
4] = 1; // 27: EW

int self_wind =
(obs_who + proto_.public_observation().init_score().round()) / 4;
feature[wind_offset + 1][27 + self_wind] = 1;
}

// legal_actions
{
const int legal_action_offset = 81;
for (auto action : proto_.legal_actions()) {
if (action.type() == mjxproto::ACTION_TYPE_DISCARD) {
int tile_type = Tile(action.tile()).TypeUint();
feature[legal_action_offset][tile_type] = 1;
} else if (action.type() == mjxproto::ACTION_TYPE_CHI) {
auto open = Open(action.open());
int center_tile_type = open.Tiles()[1].TypeUint();
int stolen_tile_type = open.StolenTile().TypeUint();
feature[legal_action_offset + 2 + (stolen_tile_type - center_tile_type)]
[stolen_tile_type] = 1;
} else if (action.type() == mjxproto::ACTION_TYPE_PON) {
auto open = Open(action.open());
int stolen_tile_type = open.StolenTile().TypeUint();
feature[legal_action_offset + 4][stolen_tile_type] = 1;
} else if (action.type() == mjxproto::ACTION_TYPE_CLOSED_KAN) {
auto open = Open(action.open());
int stolen_tile_type = open.StolenTile().TypeUint();
feature[legal_action_offset + 5][stolen_tile_type] = 1;
} else if (action.type() == mjxproto::ACTION_TYPE_OPEN_KAN) {
auto open = Open(action.open());
int stolen_tile_type = open.StolenTile().TypeUint();
feature[legal_action_offset + 6][stolen_tile_type] = 1;
} else if (action.type() == mjxproto::ACTION_TYPE_ADDED_KAN) {
auto open = Open(action.open());
int stolen_tile_type = open.StolenTile().TypeUint();
feature[legal_action_offset + 7][stolen_tile_type] = 1;
} else if (action.type() == mjxproto::ACTION_TYPE_RIICHI) {
mjxproto::Observation next_proto;
mjxproto::Event riichi_event;
riichi_event.set_who(action.who());
riichi_event.set_type(mjxproto::EVENT_TYPE_RIICHI);

next_proto.CopyFrom(proto_);
next_proto.mutable_public_observation()->mutable_events()->Add(
std::move(riichi_event));
next_proto.clear_legal_actions();
auto with_legal_a = Observation(next_proto);
with_legal_a.add_legal_actions(
with_legal_a.GenerateLegalActions(std::move(next_proto)));

for (auto legal_action : with_legal_a.legal_actions()) {
int tile_type = Tile(legal_action.tile()).TypeUint();
feature[legal_action_offset + 8][tile_type] = 1;
}
} else if (action.type() == mjxproto::ACTION_TYPE_RON) {
int tile_type = Tile(action.tile()).TypeUint();
feature[legal_action_offset + 9][tile_type] = 1;
} else if (action.type() == mjxproto::ACTION_TYPE_TSUMO) {
int tile_type = Tile(action.tile()).TypeUint();
feature[legal_action_offset + 10][tile_type] = 1;
} else if (action.type() ==
mjxproto::ACTION_TYPE_ABORTIVE_DRAW_NINE_TERMINALS) {
for (auto t : proto_.private_observation().curr_hand().closed_tiles()) {
if (Tile(t).Is(TileSetType::kYaocyu)) {
int tile_type = Tile(t).TypeUint();
feature[legal_action_offset + 11][tile_type] = 1;
}
}
}
}
}

// index 4
// この特徴量は、リーチの仕様がmjxとpymahjongで異なるため無意味なものになっている
// ひとまずindex 4とindex 30は同一のものとしておく
{
for (int i = 0; i < num_tile_type; i++) {
feature[4][i] = feature[30][i];
}
}

return feature;
}

std::vector<mjxproto::Action> Observation::GenerateLegalActions(
const mjxproto::Observation &observation) {
auto obs = Observation(observation);
Expand Down
1 change: 1 addition & 0 deletions include/mjx/internal/observation.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class Observation {
const mjxproto::Observation& observation);

[[nodiscard]] std::vector<std::vector<int>> ToFeaturesSmallV0() const;
[[nodiscard]] std::vector<std::vector<int>> ToFeaturesHan22V0() const;

private:
// TODO: remove friends and use proto()
Expand Down
7 changes: 5 additions & 2 deletions include/mjx/observation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,11 @@ bool Observation::operator!=(const Observation& other) const noexcept {
std::vector<std::vector<int>> Observation::ToFeatures2D(
const std::string& version) const noexcept {
auto obs = internal::Observation(proto_);
assert(version == "mjx-small-v0");
if (version == "mjx-small-v0") return obs.ToFeaturesSmallV0();
assert(version == "mjx-small-v0" || version == "han22-v0");
if (version == "mjx-small-v0")
return obs.ToFeaturesSmallV0();
else if (version == "han22-v0")
return obs.ToFeaturesHan22V0();
}

std::vector<Action> Observation::legal_actions() const noexcept {
Expand Down
Loading