From 8d7004f6f1dfd6df482674c84093945a916ed045 Mon Sep 17 00:00:00 2001 From: OkanoShinri Date: Wed, 24 Aug 2022 00:06:57 +0000 Subject: [PATCH 1/6] add ToFeaturesHan22V0() --- include/mjx/internal/observation.cpp | 216 +++++++++++++++++++++++++++ include/mjx/internal/observation.h | 1 + 2 files changed, 217 insertions(+) diff --git a/include/mjx/internal/observation.cpp b/include/mjx/internal/observation.cpp index 47220040..df799bec 100644 --- a/include/mjx/internal/observation.cpp +++ b/include/mjx/internal/observation.cpp @@ -231,6 +231,222 @@ std::vector> Observation::ToFeaturesSmallV0() const { return feature; } +std::vector> Observation::ToFeaturesHan22V0() const { + const int num_row = 93; + const int num_tile_type = 34; + std::vector> feature(num_row, + std::vector(num_tile_type)); + + const int obs_who = proto_.who(); + + // closed hand + { + for (auto t : proto_.private_observation().curr_hand().closed_tiles()) { + int tile_type = Tile(t).TypeUint(); + for (int i = 0; i < 4; i++) { + if (feature[i][tile_type] == 0) { + feature[i][tile_type] = 1; + break; + } + } + + if (Tile(t).IsRedFive()) { + feature[5][tile_type] = 1; + } + } + } + + // events + { + for (int event_index = 0; + event_index < proto_.public_observation().events().size(); + event_index++) { + const auto &event = proto_.public_observation().events()[event_index]; + + bool event_is_action = false; + + // opens + if (event.type() == mjxproto::EVENT_TYPE_ADDED_KAN || + event.type() == mjxproto::EVENT_TYPE_CHI || + event.type() == mjxproto::EVENT_TYPE_CLOSED_KAN || + event.type() == mjxproto::EVENT_TYPE_OPEN_KAN || + event.type() == mjxproto::EVENT_TYPE_PON) { + event_is_action = true; + const int opens_offset = 6 + 6 * ((event.who() - obs_who + 4) % 4); + + for (auto t : Open(event.open()).Tiles()) { + int tile_type = Tile(t).TypeUint(); + for (int i = 0; i < 4; i++) { + if (feature[opens_offset + i][tile_type] == 0) { + feature[opens_offset + i][tile_type] = 1; + break; + } + } + + if (Tile(t).IsRedFive()) { + feature[opens_offset + 5][tile_type] = 1; + } + } + + if (Open(event.open()).Type() != OpenType::kKanClosed) { + int stolen_tile_type = Open(event.open()).StolenTile().TypeUint(); + feature[opens_offset + 4][stolen_tile_type]++; + } + } + + // discards + else if (event.type() == mjxproto::EVENT_TYPE_DISCARD || + event.type() == mjxproto::EVENT_TYPE_TSUMOGIRI) { + event_is_action = true; + const int discard_offset = 30 + 10 * ((event.who() - obs_who + 4) % 4); + int tile_type = Tile(event.tile()).TypeUint(); + + for (int i = 0; i < 4; i++) { + if (feature[discard_offset + i][tile_type] == 0) { + feature[discard_offset + i][tile_type] = 1; + if (event.type() == mjxproto::EVENT_TYPE_DISCARD) { + feature[discard_offset + i + 4][tile_type] = 1; + } + break; + } + } + + if (Tile(event.tile()).IsRedFive()) { + feature[discard_offset + 8][tile_type] = 1; + } + + if (event_index > 0) { + const auto &event_before = + proto_.public_observation().events()[event_index - 1]; + if (event_before.type() == mjxproto::EVENT_TYPE_RIICHI) { + feature[discard_offset + 9][tile_type] = 1; + } + } + + if (event_index == proto_.public_observation().events().size() - 1 && + event_is_action) { + int latest_event_offset = 80; + feature[latest_event_offset][tile_type] = 1; + } + } + } + } + + // dora + { + const int dora_offset = 70; + int dora_index = 0; + for (auto dora_indicator : proto_.public_observation().dora_indicators()) { + int dora_indicator_tile_type = dora_indicator / 4; + int dora_tile_type = -1; + if (dora_indicator_tile_type < 9) { + dora_tile_type = (dora_indicator_tile_type + 1) % 9; + } else if (dora_indicator_tile_type < 18) { + dora_tile_type = (dora_indicator_tile_type + 1) % 9 + 9; + } else if (dora_indicator_tile_type < 27) { + dora_tile_type = (dora_indicator_tile_type + 1) % 9 + 18; + } else if (dora_indicator_tile_type < 31) { + dora_tile_type = (dora_indicator_tile_type + 1) % 4 + 27; + } else { + dora_tile_type = (dora_indicator_tile_type + 1) % 3 + 31; + } + + feature[dora_offset + dora_index][dora_indicator_tile_type] = 1; + feature[dora_offset + dora_index + 4][dora_tile_type] = 1; + + dora_index++; + if (dora_index > 3) break; + } + } + + // wind + { + const int wind_offset = 78; + feature[wind_offset][27 + proto_.public_observation().init_score().round() / + 4] = 1; // 27: EW + + int self_wind = + (obs_who + proto_.public_observation().init_score().round()) / 4; + feature[wind_offset + 1][27 + self_wind] = 1; + } + + // legal_actions + { + const int legal_action_offset = 81; + for (auto action : proto_.legal_actions()) { + if (action.type() == mjxproto::ACTION_TYPE_DISCARD) { + int tile_type = Tile(action.tile()).TypeUint(); + feature[legal_action_offset][tile_type] = 1; + } else if (action.type() == mjxproto::ACTION_TYPE_CHI) { + auto open = Open(action.open()); + int center_tile_type = open.Tiles()[1].TypeUint(); + int stolen_tile_type = open.StolenTile().TypeUint(); + feature[legal_action_offset + 2 + (center_tile_type - stolen_tile_type)] + [stolen_tile_type] = 1; + } else if (action.type() == mjxproto::ACTION_TYPE_PON) { + auto open = Open(action.open()); + int stolen_tile_type = open.StolenTile().TypeUint(); + feature[legal_action_offset + 4][stolen_tile_type] = 1; + } else if (action.type() == mjxproto::ACTION_TYPE_CLOSED_KAN) { + auto open = Open(action.open()); + int stolen_tile_type = open.StolenTile().TypeUint(); + feature[legal_action_offset + 5][stolen_tile_type] = 1; + } else if (action.type() == mjxproto::ACTION_TYPE_OPEN_KAN) { + auto open = Open(action.open()); + int stolen_tile_type = open.StolenTile().TypeUint(); + feature[legal_action_offset + 6][stolen_tile_type] = 1; + } else if (action.type() == mjxproto::ACTION_TYPE_ADDED_KAN) { + auto open = Open(action.open()); + int stolen_tile_type = open.StolenTile().TypeUint(); + feature[legal_action_offset + 7][stolen_tile_type] = 1; + } else if (action.type() == mjxproto::ACTION_TYPE_RIICHI) { + mjxproto::Observation next_proto; + mjxproto::Event riichi_event; + riichi_event.set_who(action.who()); + riichi_event.set_type(mjxproto::EVENT_TYPE_RIICHI); + + next_proto.CopyFrom(proto_); + next_proto.mutable_public_observation()->mutable_events()->Add( + std::move(riichi_event)); + next_proto.clear_legal_actions(); + auto with_legal_a = Observation(next_proto); + with_legal_a.add_legal_actions( + with_legal_a.GenerateLegalActions(std::move(next_proto))); + + for (auto legal_action : with_legal_a.legal_actions()) { + int tile_type = Tile(legal_action.tile()).TypeUint(); + feature[legal_action_offset + 8][tile_type] = 1; + } + } else if (action.type() == mjxproto::ACTION_TYPE_RON) { + int tile_type = Tile(action.tile()).TypeUint(); + feature[legal_action_offset + 9][tile_type] = 1; + } else if (action.type() == mjxproto::ACTION_TYPE_TSUMO) { + int tile_type = Tile(action.tile()).TypeUint(); + feature[legal_action_offset + 10][tile_type] = 1; + } else if (action.type() == + mjxproto::ACTION_TYPE_ABORTIVE_DRAW_NINE_TERMINALS) { + for (auto t : proto_.private_observation().curr_hand().closed_tiles()) { + if (Tile(t).Is(TileSetType::kYaocyu)) { + int tile_type = Tile(t).TypeUint(); + feature[legal_action_offset + 11][tile_type] = 1; + } + } + } + } + } + + // index 4 + // この特徴量は、リーチの仕様がmjxとpymahjongで異なるため無意味なものになっている + // ひとまずindex 4とindex 30は同一のものとしておく + { + for (int i = 0; i < num_tile_type; i++) { + feature[4][i] = feature[30][i]; + } + } + + return feature; +} + std::vector Observation::GenerateLegalActions( const mjxproto::Observation &observation) { auto obs = Observation(observation); diff --git a/include/mjx/internal/observation.h b/include/mjx/internal/observation.h index ac8f3f51..e08f41e5 100644 --- a/include/mjx/internal/observation.h +++ b/include/mjx/internal/observation.h @@ -32,6 +32,7 @@ class Observation { const mjxproto::Observation& observation); [[nodiscard]] std::vector> ToFeaturesSmallV0() const; + [[nodiscard]] std::vector> ToFeaturesHan22V0() const; private: // TODO: remove friends and use proto() From fbde6dbafb3c790cb0b5c7a5b75b4a4c4c6560c7 Mon Sep 17 00:00:00 2001 From: OkanoShinri Date: Wed, 24 Aug 2022 00:10:09 +0000 Subject: [PATCH 2/6] add han22 to observation.cpp --- include/mjx/observation.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/include/mjx/observation.cpp b/include/mjx/observation.cpp index 2de5758b..80aea285 100644 --- a/include/mjx/observation.cpp +++ b/include/mjx/observation.cpp @@ -42,8 +42,9 @@ bool Observation::operator!=(const Observation& other) const noexcept { std::vector> Observation::ToFeatures2D( const std::string& version) const noexcept { auto obs = internal::Observation(proto_); - assert(version == "mjx-small-v0"); + assert(version == "mjx-small-v0" || version == "mjx-han22-v0"); if (version == "mjx-small-v0") return obs.ToFeaturesSmallV0(); + else if (version == "mjx-han22-v0") return obs.ToFeaturesHan22V0(); } std::vector Observation::legal_actions() const noexcept { From 41746bec9f94345ee9cc99453f5637b12e3cb044 Mon Sep 17 00:00:00 2001 From: OkanoShinri Date: Wed, 24 Aug 2022 00:12:22 +0000 Subject: [PATCH 3/6] rename feature --- include/mjx/observation.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/mjx/observation.cpp b/include/mjx/observation.cpp index 80aea285..ade6481f 100644 --- a/include/mjx/observation.cpp +++ b/include/mjx/observation.cpp @@ -42,9 +42,9 @@ bool Observation::operator!=(const Observation& other) const noexcept { std::vector> Observation::ToFeatures2D( const std::string& version) const noexcept { auto obs = internal::Observation(proto_); - assert(version == "mjx-small-v0" || version == "mjx-han22-v0"); + assert(version == "mjx-small-v0" || version == "han22-v0"); if (version == "mjx-small-v0") return obs.ToFeaturesSmallV0(); - else if (version == "mjx-han22-v0") return obs.ToFeaturesHan22V0(); + else if (version == "han22-v0") return obs.ToFeaturesHan22V0(); } std::vector Observation::legal_actions() const noexcept { From 4e5948efe60f2455d0c52115b0488bfba0c2c4f7 Mon Sep 17 00:00:00 2001 From: OkanoShinri Date: Wed, 24 Aug 2022 00:15:42 +0000 Subject: [PATCH 4/6] add han22 to observation.py --- mjx/observation.py | 240 +-------------------------------------------- 1 file changed, 3 insertions(+), 237 deletions(-) diff --git a/mjx/observation.py b/mjx/observation.py index 9aa6f632..0bd4ef1e 100644 --- a/mjx/observation.py +++ b/mjx/observation.py @@ -111,12 +111,11 @@ def show_svg(self, view_idx: Optional[int] = None) -> None: def to_features(self, feature_name: str): assert feature_name in ("mjx-small-v0", "han22-v0") - if feature_name == "han22-v0": - feature = self._get_han22_features() - return feature - assert self._cpp_obj is not None # TODO: use ndarray in C++ side + if feature_name == "han22-v0": + feature = np.array(self._cpp_obj.to_features_2d(feature_name), dtype=np.bool8) # type: ignore + return feature feature = np.array(self._cpp_obj.to_features_2d(feature_name), dtype=np.int32) # type: ignore return feature @@ -134,236 +133,3 @@ def _from_cpp_obj(cls, cpp_obj) -> Observation: obs = cls() obs._cpp_obj = cpp_obj return obs - - def _get_han22_features(self) -> np.ndarray: - feature = np.full((93, 34), False, dtype=bool) - proto = self.to_proto() - mj_table = MahjongTable.from_proto(proto) - - closed_tiles_ = list( - filter( - lambda tile_unit: tile_unit.tile_unit_type == EventType.DRAW, - mj_table.players[self.who()].tile_units, - ) - )[0] - closed_tiles_id = [tile.id() for tile in closed_tiles_.tiles] - closed_tiles_type = [id // 4 for id in closed_tiles_id] - - for tiletype in range(34): # tiletype: 0~33 - - # 0-5 - in_hand = closed_tiles_type.count(tiletype) - feature[0][tiletype] = in_hand > 0 - feature[1][tiletype] = in_hand > 1 - feature[2][tiletype] = in_hand > 2 - feature[3][tiletype] = in_hand == 4 - - for event in proto.public_observation.events: - if ( - ( - event.type == mjxproto.EVENT_TYPE_DISCARD - or event.type == mjxproto.EVENT_TYPE_TSUMOGIRI - ) - and event.who == proto.who - and event.tile // 4 == tiletype - ): - feature[4][tiletype] = True - break - - feature[5][tiletype] = tiletype in [4, 13, 22] and ( - tiletype * 34 in closed_tiles_id - or tiletype * 34 + 1 in closed_tiles_id - or tiletype * 34 + 2 in closed_tiles_id - or tiletype * 34 + 3 in closed_tiles_id - ) - - # 6-29 - for j in range(4): - player_id = (self.who() + j) % 4 - _calling_of_player_j = self._calling_of_player_i(tiletype, player_id, mj_table) - for k in range(6): - feature[6 + j * 6 + k][tiletype] = _calling_of_player_j[k] - - # 30-69 - for j in range(4): - player_id = (self.who() + j) % 4 - _discarded_tiles_from_player_j = self._discarded_tiles_from_player_i( - tiletype, player_id, mj_table - ) - for k in range(10): - feature[30 + j * 10 + k][tiletype] = _discarded_tiles_from_player_j[k] - - # 70-79 - for j in range(len(mj_table.doras)): - feature[70 + j][tiletype] = (mj_table.doras[j]) // 4 - 1 == tiletype - feature[74 + j][tiletype] = (mj_table.doras[j]) // 4 == tiletype - feature[78][tiletype] = [27, 28, 29, 30][ - (mj_table.round - 1) // 4 - ] == tiletype # 27=EW,28=SW,roundは1,2,3,.. - feature[79][tiletype] = [27, 28, 29, 30][mj_table.players[0].wind] == tiletype - - # 80 - if mj_table.latest_tile is not None: - feature[80][tiletype] = mj_table.latest_tile // 4 == tiletype - - # 82-84 - if tiletype <= 26: - if tiletype % 9 < 7: - feature[81][tiletype] = ( - tiletype + 1 in closed_tiles_type and tiletype + 2 in closed_tiles_type - ) - if 0 < tiletype % 9 < 8: - feature[82][tiletype] = ( - tiletype - 1 in closed_tiles_type and tiletype + 1 in closed_tiles_type - ) - if 1 < tiletype % 9: - feature[83][tiletype] = ( - tiletype - 2 in closed_tiles_type and tiletype - 1 in closed_tiles_type - ) - - # 85-92,81 - _information_for_available_actions = self._information_for_available_actions( - tiletype, proto - ) - for j in range(len(_information_for_available_actions) - 1): - feature[85 + j][tiletype] = _information_for_available_actions[j] - feature[92][tiletype] = feature[92][tiletype] and feature[0][tiletype] - feature[81][tiletype] = _information_for_available_actions[8] - - return feature - - def _calling_of_player_i(self, tile_type: int, player_id: int, mj_table: MahjongTable): - feature = [False] * 6 - tile_units = mj_table.players[player_id].tile_units - open_tile_units = list( - filter( - lambda tile_unit: tile_unit.tile_unit_type != EventType.DRAW - and tile_unit.tile_unit_type != EventType.DISCARD, - tile_units, - ) - ) - open_tiles_id_ = [ - [tile.id() for tile in open_tile_unit.tiles] for open_tile_unit in open_tile_units - ] - open_tiles_id = list(itertools.chain.from_iterable(open_tiles_id_)) - open_tiles_type = [id // 4 for id in open_tiles_id] - - in_furo = open_tiles_type.count(tile_type) - feature[0] = in_furo > 0 - feature[1] = in_furo > 1 - feature[2] = in_furo > 2 - feature[3] = in_furo == 4 - - stolen_tiles = [open_tile_unit.tiles[0].id() // 4 for open_tile_unit in open_tile_units] - feature[4] = tile_type in stolen_tiles - - feature[5] = tile_type in [4, 13, 22] and ( - tile_type * 34 in open_tiles_id - or tile_type * 34 + 1 in open_tiles_id - or tile_type * 34 + 2 in open_tiles_id - or tile_type * 34 + 3 in open_tiles_id - ) - - return feature - - def _discarded_tiles_from_player_i( - self, tile_type: int, player_id: int, mj_table: MahjongTable - ): - feature = [False] * 10 - tile_units = mj_table.players[player_id].tile_units - discard_tile_unit = list( - filter( - lambda tile_unit: tile_unit.tile_unit_type == EventType.DISCARD, - tile_units, - ) - )[0] - discard_tiles_id = [tile.id() for tile in discard_tile_unit.tiles] - discard_tiles_type = [id // 4 for id in discard_tiles_id] - - tile_in_discard = list( - filter( - lambda tile: tile.id() // 4 == tile_type, - discard_tile_unit.tiles, - ) - ) - - in_discard = discard_tiles_type.count(tile_type) - feature[0] = in_discard > 0 - feature[1] = in_discard > 1 - feature[2] = in_discard > 2 - feature[3] = in_discard == 4 - - for j in range(len(tile_in_discard)): - feature[4 + j] = not tile_in_discard[j].is_tsumogiri - if tile_in_discard[j].with_riichi: - feature[9] = True - - feature[8] = tile_type in [4, 13, 22] and ( - tile_type * 34 in discard_tiles_id - or tile_type * 34 + 1 in discard_tiles_id - or tile_type * 34 + 2 in discard_tiles_id - or tile_type * 34 + 3 in discard_tiles_id - ) - - return feature - - def _information_for_available_actions(self, tile_type: int, proto): - feature = [False] * (8 + 1) # 最後の1は81番用 - obs = Observation.from_proto(proto) - try: - legal_actions = Observation(obs.add_legal_actions(obs.to_json())).legal_actions() - except AssertionError: - legal_actions = obs.legal_actions() - - for action in legal_actions: - if ( - action.type() == ActionType.PON - and action.tile() is not None - and action.tile().type() == tile_type - ): - feature[0] = True - if ( - action.type() == ActionType.CLOSED_KAN - and action.tile() is not None - and action.tile().type() == tile_type - ): - feature[1] = True - if ( - action.type() == ActionType.OPEN_KAN - and action.tile() is not None - and action.tile().type() == tile_type - ): - feature[2] = True - if ( - action.type() == ActionType.ADDED_KAN - and action.tile() is not None - and action.tile().type() == tile_type - ): - feature[3] = True - if ( - action.type() == ActionType.RIICHI - and action.tile() is not None - and action.tile().type() == tile_type - ): - feature[4] = True - if ( - action.type() == ActionType.RON - and action.tile() is not None - and action.tile().type() == tile_type - ): - feature[5] = True - if ( - action.type() == ActionType.TSUMO - and action.tile() is not None - and action.tile().type() == tile_type - ): - feature[6] = True - if action.type() == ActionType.ABORTIVE_DRAW_NINE_TERMINALS: - feature[7] = True - if ( - action.type() == ActionType.DISCARD - and action.tile() is not None - and action.tile().type() == tile_type - ): - feature[8] = True - return feature From 5dff6c5fca2c36b60727af60d47bcf89dd77d6ee Mon Sep 17 00:00:00 2001 From: OkanoShinri Date: Wed, 24 Aug 2022 01:04:04 +0000 Subject: [PATCH 5/6] fix --- include/mjx/internal/observation.cpp | 10 +++++----- include/mjx/observation.cpp | 6 ++++-- mjx/observation.py | 4 +--- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/include/mjx/internal/observation.cpp b/include/mjx/internal/observation.cpp index df799bec..a5f79a00 100644 --- a/include/mjx/internal/observation.cpp +++ b/include/mjx/internal/observation.cpp @@ -342,13 +342,13 @@ std::vector> Observation::ToFeaturesHan22V0() const { if (dora_indicator_tile_type < 9) { dora_tile_type = (dora_indicator_tile_type + 1) % 9; } else if (dora_indicator_tile_type < 18) { - dora_tile_type = (dora_indicator_tile_type + 1) % 9 + 9; + dora_tile_type = (dora_indicator_tile_type - 9 + 1) % 9 + 9; } else if (dora_indicator_tile_type < 27) { - dora_tile_type = (dora_indicator_tile_type + 1) % 9 + 18; + dora_tile_type = (dora_indicator_tile_type - 18 + 1) % 9 + 18; } else if (dora_indicator_tile_type < 31) { - dora_tile_type = (dora_indicator_tile_type + 1) % 4 + 27; + dora_tile_type = (dora_indicator_tile_type - 27 + 1) % 4 + 27; } else { - dora_tile_type = (dora_indicator_tile_type + 1) % 3 + 31; + dora_tile_type = (dora_indicator_tile_type - 31 + 1) % 3 + 31; } feature[dora_offset + dora_index][dora_indicator_tile_type] = 1; @@ -381,7 +381,7 @@ std::vector> Observation::ToFeaturesHan22V0() const { auto open = Open(action.open()); int center_tile_type = open.Tiles()[1].TypeUint(); int stolen_tile_type = open.StolenTile().TypeUint(); - feature[legal_action_offset + 2 + (center_tile_type - stolen_tile_type)] + feature[legal_action_offset + 2 + (stolen_tile_type - center_tile_type)] [stolen_tile_type] = 1; } else if (action.type() == mjxproto::ACTION_TYPE_PON) { auto open = Open(action.open()); diff --git a/include/mjx/observation.cpp b/include/mjx/observation.cpp index ade6481f..301b7c9f 100644 --- a/include/mjx/observation.cpp +++ b/include/mjx/observation.cpp @@ -43,8 +43,10 @@ std::vector> Observation::ToFeatures2D( const std::string& version) const noexcept { auto obs = internal::Observation(proto_); assert(version == "mjx-small-v0" || version == "han22-v0"); - if (version == "mjx-small-v0") return obs.ToFeaturesSmallV0(); - else if (version == "han22-v0") return obs.ToFeaturesHan22V0(); + if (version == "mjx-small-v0") + return obs.ToFeaturesSmallV0(); + else if (version == "han22-v0") + return obs.ToFeaturesHan22V0(); } std::vector Observation::legal_actions() const noexcept { diff --git a/mjx/observation.py b/mjx/observation.py index 0bd4ef1e..d11dea2e 100644 --- a/mjx/observation.py +++ b/mjx/observation.py @@ -1,6 +1,5 @@ from __future__ import annotations -import itertools from typing import List, Optional import _mjx # type: ignore @@ -9,12 +8,11 @@ import mjxproto from mjx.action import Action -from mjx.const import ActionType, EventType, PlayerIdx, TileType +from mjx.const import PlayerIdx, TileType from mjx.event import Event from mjx.hand import Hand from mjx.tile import Tile from mjx.visualizer.svg import save_svg, show_svg, to_svg -from mjx.visualizer.visualizer import MahjongTable class Observation: From da18921acca1ed2daefba80395ea609dca7a57bc Mon Sep 17 00:00:00 2001 From: OkanoShinri Date: Wed, 24 Aug 2022 01:04:19 +0000 Subject: [PATCH 6/6] add tests --- tests_py/test_features.py | 60 ++++++++++++++++++++++++++------------- 1 file changed, 40 insertions(+), 20 deletions(-) diff --git a/tests_py/test_features.py b/tests_py/test_features.py index 9cc72f7e..0a8711bf 100644 --- a/tests_py/test_features.py +++ b/tests_py/test_features.py @@ -47,23 +47,23 @@ def test_table_wind(): def test_dora_indicator(): - json_str = '{"publicObservation":{"playerIds":["player_2","player_1","player_0","player_3"],"initScore":{"tens":[25000,25000,25000,25000]},"doraIndicators":[101],"events":[{"type":"EVENT_TYPE_DRAW"}]},"privateObservation":{"initHand":{"closedTiles":[24,3,87,124,37,42,58,134,92,82,122,18,117]},"drawHistory":[79],"currHand":{"closedTiles":[3,18,24,37,42,58,79,82,87,92,117,122,124,134]}},"legalActions":[{"tile":3},{"tile":18},{"tile":24},{"tile":37},{"tile":42},{"tile":58},{"type":"ACTION_TYPE_TSUMOGIRI","tile":79},{"tile":82},{"tile":87},{"tile":92},{"tile":117},{"tile":122},{"tile":124},{"tile":134}]}' + json_str = '{"who":3,"publicObservation":{"playerIds":["player_2","player_1","player_3","player_0"],"initScore":{"round":7,"honba":7,"tens":[31000,23000,23000,23000]},"doraIndicators":[106],"events":[{"type":"EVENT_TYPE_DRAW","who":3}]},"privateObservation":{"who":3,"initHand":{"closedTiles":[95,77,74,4,85,70,30,66,31,59,102,84,78]},"drawHistory":[37],"currHand":{"closedTiles":[4,30,31,37,59,66,70,74,77,78,84,85,95,102]}},"legalActions":[{"who":3,"tile":4},{"who":3,"tile":30},{"type":"ACTION_TYPE_TSUMOGIRI","who":3,"tile":37},{"who":3,"tile":59},{"who":3,"tile":66},{"who":3,"tile":70},{"who":3,"tile":74},{"who":3,"tile":77},{"who":3,"tile":84},{"who":3,"tile":95},{"who":3,"tile":102}]}' obs = Observation(json_str) feature = obs.to_features("han22-v0") # index70: If t is Dora indicator - # 24: 7s - assert feature[70][24] + # 26: 9s + assert feature[70][26] def test_dora(): - json_str = '{"publicObservation":{"playerIds":["player_2","player_1","player_0","player_3"],"initScore":{"tens":[25000,25000,25000,25000]},"doraIndicators":[101],"events":[{"type":"EVENT_TYPE_DRAW"}]},"privateObservation":{"initHand":{"closedTiles":[24,3,87,124,37,42,58,134,92,82,122,18,117]},"drawHistory":[79],"currHand":{"closedTiles":[3,18,24,37,42,58,79,82,87,92,117,122,124,134]}},"legalActions":[{"tile":3},{"tile":18},{"tile":24},{"tile":37},{"tile":42},{"tile":58},{"type":"ACTION_TYPE_TSUMOGIRI","tile":79},{"tile":82},{"tile":87},{"tile":92},{"tile":117},{"tile":122},{"tile":124},{"tile":134}]}' + json_str = '{"who":3,"publicObservation":{"playerIds":["player_2","player_1","player_3","player_0"],"initScore":{"round":7,"honba":7,"tens":[31000,23000,23000,23000]},"doraIndicators":[106],"events":[{"type":"EVENT_TYPE_DRAW","who":3}]},"privateObservation":{"who":3,"initHand":{"closedTiles":[95,77,74,4,85,70,30,66,31,59,102,84,78]},"drawHistory":[37],"currHand":{"closedTiles":[4,30,31,37,59,66,70,74,77,78,84,85,95,102]}},"legalActions":[{"who":3,"tile":4},{"who":3,"tile":30},{"type":"ACTION_TYPE_TSUMOGIRI","who":3,"tile":37},{"who":3,"tile":59},{"who":3,"tile":66},{"who":3,"tile":70},{"who":3,"tile":74},{"who":3,"tile":77},{"who":3,"tile":84},{"who":3,"tile":95},{"who":3,"tile":102}]}' obs = Observation(json_str) feature = obs.to_features("han22-v0") # index74: If t is Dora - # 24: 8s - assert feature[74][25] + # 18: 1s + assert feature[74][18] def test_closed_tiles(): @@ -71,8 +71,6 @@ def test_closed_tiles(): obs = Observation(json_str) feature = obs.to_features("han22-v0") - # index0: If player 0 has >= 1 t in hand - closed = feature[0] val = [ True, False, @@ -109,10 +107,22 @@ def test_closed_tiles(): False, True, ] - assert all(closed == val) + # index0: If player 0 has >= 1 t in hand + assert all(feature[0] == val) + + +def test_three_tiles_in_hand(): + json_str = '{"who":1,"publicObservation":{"playerIds":["player_1","player_2","player_0","player_3"],"initScore":{"tens":[25000,25000,25000,25000]},"doraIndicators":[6],"events":[{"type":"EVENT_TYPE_DRAW"},{"tile":27},{"type":"EVENT_TYPE_CHI","who":1,"open":16631},{"who":1,"tile":41},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":131},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":120},{"type":"EVENT_TYPE_DRAW"},{"tile":35},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":107},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":74},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":46},{"type":"EVENT_TYPE_DRAW"},{"tile":127},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":4},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":28},{"type":"EVENT_TYPE_DRAW","who":3},{"type":"EVENT_TYPE_TSUMOGIRI","who":3,"tile":1},{"type":"EVENT_TYPE_DRAW"},{"tile":75},{"type":"EVENT_TYPE_CHI","who":1,"open":43263},{"who":1,"tile":16},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":63},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":2},{"type":"EVENT_TYPE_DRAW"},{"tile":94},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":82},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":117},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":62},{"type":"EVENT_TYPE_DRAW"},{"tile":31},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":129},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":135},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":36},{"type":"EVENT_TYPE_DRAW"},{"tile":80},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":14},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":126},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":21},{"type":"EVENT_TYPE_DRAW"},{"tile":57},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":64},{"type":"EVENT_TYPE_DRAW","who":2},{"type":"EVENT_TYPE_TSUMOGIRI","who":2,"tile":130},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":47},{"type":"EVENT_TYPE_DRAW"},{"tile":110},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":66},{"type":"EVENT_TYPE_DRAW","who":2},{"type":"EVENT_TYPE_TSUMOGIRI","who":2,"tile":70},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":92},{"type":"EVENT_TYPE_DRAW"},{"tile":105},{"type":"EVENT_TYPE_DRAW","who":1}]},"privateObservation":{"who":1,"initHand":{"closedTiles":[81,107,79,82,41,99,4,14,22,97,29,66,129]},"drawHistory":[16,39,64,122,30,98,119,15],"currHand":{"closedTiles":[15,30,39,97,98,99,119,122],"opens":[16631,43263]}},"legalActions":[{"type":"ACTION_TYPE_TSUMOGIRI","who":1,"tile":15},{"who":1,"tile":30},{"who":1,"tile":39},{"who":1,"tile":97},{"who":1,"tile":119},{"who":1,"tile":122}]}' + obs = Observation(json_str) + feature = obs.to_features("han22-v0") + + # index2: If player 0 has >= 2 t in hand + # 24: 2p + assert feature[2][24] -def test_riichi(): + +def test_discarded_with_riichi(): json_str = '{"who":1,"publicObservation":{"playerIds":["rule-based-0","target-player","rule-based-2","rule-based-3"],"initScore":{"round":2,"tens":[25200,24000,27300,23500]},"doraIndicators":[88],"events":[{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":108},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":119},{"type":"EVENT_TYPE_DRAW"},{"tile":116},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":110},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":114},{"type":"EVENT_TYPE_DRAW","who":3},{"type":"EVENT_TYPE_TSUMOGIRI","who":3,"tile":111},{"type":"EVENT_TYPE_DRAW"},{"type":"EVENT_TYPE_TSUMOGIRI","tile":134},{"type":"EVENT_TYPE_DRAW","who":1},{"type":"EVENT_TYPE_TSUMOGIRI","who":1,"tile":122},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":39},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":84},{"type":"EVENT_TYPE_DRAW"},{"type":"EVENT_TYPE_TSUMOGIRI","tile":123},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":124},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":106},{"type":"EVENT_TYPE_DRAW","who":3},{"type":"EVENT_TYPE_TSUMOGIRI","who":3,"tile":25},{"type":"EVENT_TYPE_DRAW"},{"tile":37},{"type":"EVENT_TYPE_DRAW","who":1},{"type":"EVENT_TYPE_TSUMOGIRI","who":1,"tile":117},{"type":"EVENT_TYPE_DRAW","who":2},{"type":"EVENT_TYPE_TSUMOGIRI","who":2,"tile":36},{"type":"EVENT_TYPE_DRAW","who":3},{"type":"EVENT_TYPE_TSUMOGIRI","who":3,"tile":118},{"type":"EVENT_TYPE_DRAW"},{"tile":68},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":41},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":51},{"type":"EVENT_TYPE_PON","who":2,"open":19465},{"who":2,"tile":4},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":100},{"type":"EVENT_TYPE_CHI","open":60463},{"tile":42},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":69},{"type":"EVENT_TYPE_DRAW","who":2},{"type":"EVENT_TYPE_TSUMOGIRI","who":2,"tile":7},{"type":"EVENT_TYPE_PON","open":2570},{"tile":54},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":79},{"type":"EVENT_TYPE_DRAW","who":2},{"type":"EVENT_TYPE_TSUMOGIRI","who":2,"tile":120},{"type":"EVENT_TYPE_DRAW","who":3},{"type":"EVENT_TYPE_TSUMOGIRI","who":3,"tile":78},{"type":"EVENT_TYPE_DRAW"},{"type":"EVENT_TYPE_TSUMOGIRI","tile":1},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":85},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":43},{"type":"EVENT_TYPE_DRAW","who":3},{"type":"EVENT_TYPE_TSUMOGIRI","who":3,"tile":121},{"type":"EVENT_TYPE_DRAW"},{"tile":24},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":95},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":57},{"type":"EVENT_TYPE_DRAW","who":3},{"type":"EVENT_TYPE_RIICHI","who":3},{"who":3,"tile":59},{"type":"EVENT_TYPE_RIICHI_SCORE_CHANGE","who":3},{"type":"EVENT_TYPE_DRAW"},{"type":"EVENT_TYPE_TSUMOGIRI","tile":115},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":101},{"type":"EVENT_TYPE_DRAW","who":2},{"type":"EVENT_TYPE_TSUMOGIRI","who":2,"tile":70},{"type":"EVENT_TYPE_DRAW","who":3},{"type":"EVENT_TYPE_TSUMOGIRI","who":3,"tile":104},{"type":"EVENT_TYPE_DRAW"},{"type":"EVENT_TYPE_TSUMOGIRI","tile":112},{"type":"EVENT_TYPE_DRAW","who":1},{"type":"EVENT_TYPE_TSUMOGIRI","who":1,"tile":83},{"type":"EVENT_TYPE_PON","who":2,"open":31787},{"who":2,"tile":33},{"type":"EVENT_TYPE_DRAW","who":3},{"type":"EVENT_TYPE_TSUMO","who":3,"tile":67}]},"privateObservation":{"who":1,"initHand":{"closedTiles":[129,101,63,128,31,41,10,21,124,69,14,110,79]},"drawHistory":[26,122,95,117,55,85,58,13,61,44,83],"currHand":{"closedTiles":[10,13,14,21,26,31,44,55,58,61,63,128,129]}},"roundTerminal":{"finalScore":{"round":2,"tens":[22200,21000,21300,35500]},"wins":[{"who":3,"fromWho":3,"hand":{"closedTiles":[20,23,46,47,60,62,65,67,73,75,96,99,133,135]},"winTile":67,"fu":25,"ten":12000,"tenChanges":[-3000,-3000,-6000,13000],"yakus":[0,1,22,53],"fans":[1,1,2,2],"uraDoraIndicators":[16]}]},"legalActions":[{"who":1,"type":"ACTION_TYPE_DUMMY"}]}' obs = Observation(json_str) feature = obs.to_features("han22-v0") @@ -122,14 +132,24 @@ def test_riichi(): assert feature[59][14] -def test_three_tiles_in_hand(): - json_str = '{"who":1,"publicObservation":{"playerIds":["player_1","player_2","player_0","player_3"],"initScore":{"tens":[25000,25000,25000,25000]},"doraIndicators":[6],"events":[{"type":"EVENT_TYPE_DRAW"},{"tile":27},{"type":"EVENT_TYPE_CHI","who":1,"open":16631},{"who":1,"tile":41},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":131},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":120},{"type":"EVENT_TYPE_DRAW"},{"tile":35},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":107},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":74},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":46},{"type":"EVENT_TYPE_DRAW"},{"tile":127},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":4},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":28},{"type":"EVENT_TYPE_DRAW","who":3},{"type":"EVENT_TYPE_TSUMOGIRI","who":3,"tile":1},{"type":"EVENT_TYPE_DRAW"},{"tile":75},{"type":"EVENT_TYPE_CHI","who":1,"open":43263},{"who":1,"tile":16},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":63},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":2},{"type":"EVENT_TYPE_DRAW"},{"tile":94},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":82},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":117},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":62},{"type":"EVENT_TYPE_DRAW"},{"tile":31},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":129},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":135},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":36},{"type":"EVENT_TYPE_DRAW"},{"tile":80},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":14},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":126},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":21},{"type":"EVENT_TYPE_DRAW"},{"tile":57},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":64},{"type":"EVENT_TYPE_DRAW","who":2},{"type":"EVENT_TYPE_TSUMOGIRI","who":2,"tile":130},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":47},{"type":"EVENT_TYPE_DRAW"},{"tile":110},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":66},{"type":"EVENT_TYPE_DRAW","who":2},{"type":"EVENT_TYPE_TSUMOGIRI","who":2,"tile":70},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":92},{"type":"EVENT_TYPE_DRAW"},{"tile":105},{"type":"EVENT_TYPE_DRAW","who":1}]},"privateObservation":{"who":1,"initHand":{"closedTiles":[81,107,79,82,41,99,4,14,22,97,29,66,129]},"drawHistory":[16,39,64,122,30,98,119,15],"currHand":{"closedTiles":[15,30,39,97,98,99,119,122],"opens":[16631,43263]}},"legalActions":[{"type":"ACTION_TYPE_TSUMOGIRI","who":1,"tile":15},{"who":1,"tile":30},{"who":1,"tile":39},{"who":1,"tile":97},{"who":1,"tile":119},{"who":1,"tile":122}]}' +def test_can_riichi(): + json_str = '{"who":3,"publicObservation":{"playerIds":["player_2","player_0","player_1","player_3"],"initScore":{"round":7,"honba":7,"tens":[25000,25000,25000,25000]},"doraIndicators":[106],"events":[{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":71},{"type":"EVENT_TYPE_DRAW"},{"tile":42},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":129},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":67},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":60},{"type":"EVENT_TYPE_PON","open":23147},{"tile":107},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":14},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":116},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":8},{"type":"EVENT_TYPE_CHI","open":2167},{"tile":29},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":64},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":124},{"type":"EVENT_TYPE_DRAW","who":3},{"who":3,"tile":80},{"type":"EVENT_TYPE_DRAW"},{"tile":66},{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":41},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":32},{"type":"EVENT_TYPE_DRAW","who":3}]},"privateObservation":{"who":3,"initHand":{"closedTiles":[71,104,80,102,87,8,60,6,34,26,91,82,11]},"drawHistory":[78,23,79,15,18],"currHand":{"closedTiles":[6,11,15,18,23,26,34,78,79,82,87,91,102,104]}},"legalActions":[{"type":"ACTION_TYPE_RIICHI","who":3},{"who":3,"tile":6},{"who":3,"tile":11},{"who":3,"tile":15},{"type":"ACTION_TYPE_TSUMOGIRI","who":3,"tile":18},{"who":3,"tile":23},{"who":3,"tile":26},{"who":3,"tile":34},{"who":3,"tile":78},{"who":3,"tile":82},{"who":3,"tile":87},{"who":3,"tile":91},{"who":3,"tile":102},{"who":3,"tile":104}]}' obs = Observation(json_str) feature = obs.to_features("han22-v0") - # index2: If player 0 has >= 2 t in hand - # 24: 2p - assert feature[2][24] + # index89: If Riichi is possible by discarding t + # 8: 9m + assert feature[89][8] + + +def test_can_chi(): + json_str = '{"who":3,"publicObservation":{"playerIds":["player_3","player_2","player_0","player_1"],"initScore":{"round":1,"honba":1,"tens":[25000,25000,25000,25000]},"doraIndicators":[6],"events":[{"type":"EVENT_TYPE_DRAW","who":1},{"who":1,"tile":71},{"type":"EVENT_TYPE_DRAW","who":2},{"who":2,"tile":84}]},"privateObservation":{"who":3,"initHand":{"closedTiles":[50,20,29,39,111,76,62,88,2,112,120,94,97]},"currHand":{"closedTiles":[2,20,29,39,50,62,76,88,94,97,111,112,120]}},"legalActions":[{"type":"ACTION_TYPE_CHI","who":3,"open":52487},{"type":"ACTION_TYPE_NO","who":3}]}' + obs = Observation(json_str) + feature = obs.to_features("han22-v0") + + # index82: If at t can be chi(smallest) + # 21: 4s + assert feature[82][21] def test_kyuuhai(): @@ -139,7 +159,7 @@ def test_kyuuhai(): val = [ True, False, - True, + False, False, False, False, @@ -148,10 +168,10 @@ def test_kyuuhai(): True, False, False, - True, - True, False, - True, + False, + False, + False, False, False, True, @@ -162,7 +182,7 @@ def test_kyuuhai(): False, False, False, - True, + False, False, True, True,