From 1a120f748151b64becd41448cd43165e23302cc2 Mon Sep 17 00:00:00 2001 From: Nishimori Soichiro <55940613+nissymori@users.noreply.github.com> Date: Fri, 16 Sep 2022 19:26:48 +0900 Subject: [PATCH] Train reward-shaping model with large data (#1134) * implement train.py * fix feature * fix * fix * fix * add readme * fix * fix features * fix * fix * fix typo * fix --- .gitignore | 7 +- workspace/.DS_Store | Bin 6148 -> 6148 bytes .../tests/test_train_helper.py | 40 ---- .../suphnx-reward-shaping/tests/test_utils.py | 20 -- workspace/suphnx-reward-shaping/train.py | 41 ---- .../suphnx-reward-shaping/train_helper.py | 101 ---------- workspace/suphnx-reward-shaping/utils.py | 78 -------- workspace/suphx-reward-shaping/README.md | 30 +++ .../2022060100gm-00a9-0000-3e8b8aaf.json | 0 .../2022060100gm-00a9-0000-3ffa4858.json | 0 .../2022060100gm-00a9-0000-6db179be.json | 0 .../2022060100gm-00a9-0000-7c8869db.json | 0 .../tests/test_train_helper.py | 66 +++++++ .../suphx-reward-shaping/tests/test_utils.py | 49 +++++ workspace/suphx-reward-shaping/train.py | 76 ++++++++ .../suphx-reward-shaping/train_helper.py | 180 ++++++++++++++++++ workspace/suphx-reward-shaping/utils.py | 143 ++++++++++++++ 17 files changed, 549 insertions(+), 282 deletions(-) delete mode 100644 workspace/suphnx-reward-shaping/tests/test_train_helper.py delete mode 100644 workspace/suphnx-reward-shaping/tests/test_utils.py delete mode 100644 workspace/suphnx-reward-shaping/train.py delete mode 100644 workspace/suphnx-reward-shaping/train_helper.py delete mode 100644 workspace/suphnx-reward-shaping/utils.py create mode 100644 workspace/suphx-reward-shaping/README.md rename workspace/{suphnx-reward-shaping => suphx-reward-shaping}/tests/resources/2022060100gm-00a9-0000-3e8b8aaf.json (100%) rename workspace/{suphnx-reward-shaping => suphx-reward-shaping}/tests/resources/2022060100gm-00a9-0000-3ffa4858.json (100%) rename workspace/{suphnx-reward-shaping => suphx-reward-shaping}/tests/resources/2022060100gm-00a9-0000-6db179be.json (100%) rename workspace/{suphnx-reward-shaping => suphx-reward-shaping}/tests/resources/2022060100gm-00a9-0000-7c8869db.json (100%) create mode 100644 workspace/suphx-reward-shaping/tests/test_train_helper.py create mode 100644 workspace/suphx-reward-shaping/tests/test_utils.py create mode 100644 workspace/suphx-reward-shaping/train.py create mode 100644 workspace/suphx-reward-shaping/train_helper.py create mode 100644 workspace/suphx-reward-shaping/utils.py diff --git a/.gitignore b/.gitignore index 150c879a..751f4b68 100644 --- a/.gitignore +++ b/.gitignore @@ -20,8 +20,11 @@ mjx-py/.vscode/* dist .pytest_cache .cache + .ipynb_checkpoints -workspace/suphnx-reward-shaping/resources/* +workspace/suphx-reward-shaping/resources/* +workspace/suphx-reward-shaping/trained_model/* +workspace/suphx-reward-shaping/result/* .DS_Store .vscode/ -.python_versions +.python_version diff --git a/workspace/.DS_Store b/workspace/.DS_Store index 304a472acf898c08aaed3fe1a50add69c96672e5..13486b380e70182d3e804578f1c12d2f327a7ee3 100644 GIT binary patch delta 164 zcmZoMXfc=|#>B!ku~2NHo+2ar#(>?7iyN4k7}+QPWzySh#iYR~kdtm0oSdIqzyJj4 zPxB){G9tP8E-pzq`AI-Aj(#CK?Nyma9H9~^SS17sG7!dOKM!C4Tad)OlyNgV2R{eU b4Vw#@zcWwf7jfiZWME(d*|s@CWDPR_!=Nt& delta 71 zcmZoMXfc=|#>B)qu~2NHo+2aj#(>?7jLeh&vgmELV$opSoW!=2abW}VW_AvK4xqBl Zf*jwOC-aLqaxee^BLf4=<_M8B%mAo65f}gf diff --git a/workspace/suphnx-reward-shaping/tests/test_train_helper.py b/workspace/suphnx-reward-shaping/tests/test_train_helper.py deleted file mode 100644 index 47326c2e..00000000 --- 
a/workspace/suphnx-reward-shaping/tests/test_train_helper.py +++ /dev/null @@ -1,40 +0,0 @@ -import os -import sys - -import jax -import jax.numpy as jnp -import optax - -sys.path.append("../") -from train_helper import evaluate, initializa_params, train -from utils import to_data - -layer_sizes = [3, 4, 5, 1] -feature_size = 6 -seed = jax.random.PRNGKey(42) - -mjxprotp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources") - - -def test_initialize_params(): - params = initializa_params(layer_sizes, feature_size, seed) - assert len(params) == 4 - - -def test_train(): - params = initializa_params(layer_sizes, feature_size, seed) - featurs, scores = to_data(mjxprotp_dir) - optimizer = optax.adam(0.05) - params = train(params, optimizer, featurs, scores, epochs=1, batch_size=1) - assert len(params) == 4 - - -def test_evaluate(): - params = initializa_params(layer_sizes, feature_size, seed) - featurs, scores = to_data(mjxprotp_dir) - loss = evaluate(params, featurs, scores, batch_size=2) - assert loss >= 0 - - -if __name__ == "__main__": - test_train() diff --git a/workspace/suphnx-reward-shaping/tests/test_utils.py b/workspace/suphnx-reward-shaping/tests/test_utils.py deleted file mode 100644 index 0a6f7cfa..00000000 --- a/workspace/suphnx-reward-shaping/tests/test_utils.py +++ /dev/null @@ -1,20 +0,0 @@ -import json -import os -import sys - -from google.protobuf import json_format - -sys.path.append("../../../") -import mjxproto - -sys.path.append("../") -from utils import to_data - -mjxprotp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources") - - -def test_to_dataset(): - num_resources = len(os.listdir(mjxprotp_dir)) - features, scores = to_data(mjxprotp_dir) - assert features.shape == (num_resources, 6) - assert scores.shape == (num_resources, 1) diff --git a/workspace/suphnx-reward-shaping/train.py b/workspace/suphnx-reward-shaping/train.py deleted file mode 100644 index c3134555..00000000 --- a/workspace/suphnx-reward-shaping/train.py +++ /dev/null @@ -1,41 +0,0 @@ -import argparse -import math -import os -import sys - -import jax -import jax.numpy as jnp -import optax -from train_helper import evaluate, initializa_params, train -from utils import normalize, to_data - -mjxprotp_dir = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "resources/mjxproto" -) # please specify your mjxproto dir - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("lr", help="Enter learning rate", type=float) - parser.add_argument("epochs", help="Enter epochs", type=int) - parser.add_argument("batch_size", help="Enter batch_size", type=int) - - args = parser.parse_args() - - X, Y = to_data(mjxprotp_dir) - X = normalize(X) - Y = normalize(Y) - - train_x = X[: math.floor(len(X) * 0.8)] - train_y = Y[: math.floor(len(X) * 0.8)] - test_x = X[math.floor(len(X) * 0.8) :] - test_y = Y[math.floor(len(X) * 0.8) :] - - layer_size = [32, 32, 1] - seed = jax.random.PRNGKey(42) - - params = initializa_params(layer_size, 6, seed) - optimizer = optax.adam(learning_rate=args.lr) - - params = train(params, optimizer, train_x, train_y, args.epochs, args.batch_size) - - print(evaluate(params, test_x, test_y, args.batch_size)) diff --git a/workspace/suphnx-reward-shaping/train_helper.py b/workspace/suphnx-reward-shaping/train_helper.py deleted file mode 100644 index f2618001..00000000 --- a/workspace/suphnx-reward-shaping/train_helper.py +++ /dev/null @@ -1,101 +0,0 @@ -from typing import Dict, List - -import jax -import jax.nn as nn 
-import jax.numpy as jnp -import numpy as np -import optax -import tensorflow as tf -from jax import grad -from jax import jit as jjit -from jax import numpy as jnp -from jax import value_and_grad, vmap - - -def initializa_params(layer_sizes: List[int], features: int, seed) -> Dict: - """ - 重みを初期化する関数. 線形層を前提としている. - Xavier initializationを採用 - """ - params = {} - - for i, units in enumerate(layer_sizes): - if i == 0: - w = jax.random.uniform( - key=seed, - shape=(features, units), - minval=-np.sqrt(6) / np.sqrt(units), - maxval=-np.sqrt(6) / np.sqrt(units), - dtype=jnp.float32, - ) - else: - w = jax.random.uniform( - key=seed, - shape=(layer_sizes[i - 1], units), - minval=-np.sqrt(6) / np.sqrt(units + layer_sizes[i - 1]), - maxval=np.sqrt(6) / np.sqrt(units + layer_sizes[i - 1]), - dtype=jnp.float32, - ) - params["linear" + str(i)] = w - return params - - -def relu(x: jnp.ndarray) -> jnp.ndarray: - return jnp.maximum(0, x) - - -def net(x: jnp.ndarray, params: optax.Params) -> jnp.ndarray: - for k, param in params.items(): - x = jnp.dot(x, param) - x = jax.nn.relu(x) - return x - - -def loss(params: optax.Params, batched_x: jnp.ndarray, batched_y: jnp.ndarray) -> jnp.ndarray: - preds = net(batched_x, params) - loss_value = optax.l2_loss(preds, batched_y).sum(axis=-1) - return loss_value.mean() - - -def train( - params: optax.Params, - optimizer: optax.GradientTransformation, - X: jnp.ndarray, - Y: jnp.ndarray, - epochs: int, - batch_size: int, - buffer_size=3, -) -> optax.Params: - """ - 学習用の関数. 線形層を前提としており, バッチ処理やシャッフルのためにtensorflowを使っている. - """ - dataset = tf.data.Dataset.from_tensor_slices((X, Y)) - batched_dataset = dataset.shuffle(buffer_size=buffer_size).batch( - batch_size, drop_remainder=True - ) - opt_state = optimizer.init(params) - - @jax.jit - def step(params, opt_state, batch, labels): - loss_value, grads = jax.value_and_grad(loss)(params, batch, labels) - updates, opt_state = optimizer.update(grads, opt_state, params) - params = optax.apply_updates(params, updates) - return params, opt_state, loss_value - - for i in range(epochs): - for batched_x, batched_y in batched_dataset: - params, opt_state, loss_value = step( - params, opt_state, batched_x.numpy(), batched_y.numpy() - ) - if i % 100 == 0: # print MSE every 100 epochs - print(f"step {i}, loss: {loss_value}") - return params - - -def evaluate(params: optax.Params, X: jnp.ndarray, Y: jnp.ndarray, batch_size: int) -> float: - dataset = tf.data.Dataset.from_tensor_slices((X, Y)) - batched_dataset = dataset.batch(batch_size, drop_remainder=True) - cum_loss = 0 - for batch_x, batch_y in batched_dataset: - cum_loss += loss(params, batch_x.numpy(), batch_y.numpy()) - return cum_loss / len(batched_dataset) diff --git a/workspace/suphnx-reward-shaping/utils.py b/workspace/suphnx-reward-shaping/utils.py deleted file mode 100644 index 7ac64b26..00000000 --- a/workspace/suphnx-reward-shaping/utils.py +++ /dev/null @@ -1,78 +0,0 @@ -import json -import os -import random -import sys -from typing import Dict, Iterator, List, Optional, Tuple - -import jax -import jax.numpy as jnp -import numpy as np -from google.protobuf import json_format - -sys.path.append("../../") -sys.path.append("../../../") -import mjxproto - -oka = [90, 40, 0, -130] - - -def to_data(mjxprotp_dir: str) -> Tuple[jnp.ndarray, jnp.ndarray]: - """ - jsonが入っているディレクトリを引数としてjax.numpyのデータセットを作る. 
-    """
-    features: List = []
-    scores: List = []
-    for _json in os.listdir(mjxprotp_dir):
-        _json = os.path.join(mjxprotp_dir, _json)
-        assert ".json" in _json
-        with open(_json, "r") as f:
-            lines = f.readlines()
-        _dicts = [json.loads(round) for round in lines]
-        states = [json_format.ParseDict(d, mjxproto.State()) for d in _dicts]
-        target: int = random.randint(0, 3)
-        features.append(to_feature(states, target))
-        scores.append(to_final_scores(states, target))
-    features_array: jnp.ndarray = jnp.array(features)
-    scores_array: jnp.ndarray = jnp.array(scores)
-    return features_array, scores_array
-
-
-def normalize(array: jnp.ndarray) -> jnp.ndarray:
-    mean = array.mean(axis=0)
-    std = array.mean(axis=0)
-    return (array - mean) / std
-
-
-def _select_one_round(states: List[mjxproto.State]) -> mjxproto.State:
-    """
-    データセットに本質的で無い相関が生まれることを防ぐために一半荘につき1ペアのみを使う.
-    """
-    idx: int = random.randint(0, len(states) - 1)
-    return states[idx]
-
-
-def _calc_curr_pos(init_pos: int, round: int) -> int:
-    return init_pos + round % 4
-
-
-def to_feature(states: List[mjxproto.State], target) -> List[int]:
-    """
-    特徴量 = [終了時の点数, 自風, 親, 局, 本場, 詰み棒]
-    """
-    state = _select_one_round(states)
-    ten: int = state.round_terminal.final_score.tens[target]
-    honba: int = state.round_terminal.final_score.honba
-    tsumibo: int = state.round_terminal.final_score.riichi
-    round: int = state.round_terminal.final_score.round
-    wind: int = _calc_curr_pos(target, round)
-    oya: int = _calc_curr_pos(0, round)
-    return [ten, honba, tsumibo, round, wind, oya]
-
-
-def to_final_scores(states: List[mjxproto.State], target) -> List[int]:
-    final_state = states[-1]
-    final_scores = final_state.round_terminal.final_score.tens
-    target_score = final_scores[target]
-    sorted_scores = sorted(final_scores)
-    rank = sorted_scores.index(target_score)
-    return [oka[rank]]
diff --git a/workspace/suphx-reward-shaping/README.md b/workspace/suphx-reward-shaping/README.md
new file mode 100644
index 00000000..53ef36ea
--- /dev/null
+++ b/workspace/suphx-reward-shaping/README.md
@@ -0,0 +1,30 @@
+## Suphx-like reward shaping
+
+## How to train the model
+
+Prepare the directories for data and result under this directory. After that, we can train the model through the CLI.
+
+```
+$ python train.py 0.001 10 16 --use_saved_data 0 --data_path resources/mjxproto --result_path result
+```
+
+Here is the information about the arguments.
+
+The first three positional arguments are the learning rate, epochs, and batch size, respectively.
+
+`--use_saved_data` 0 means the training data is built from the mjxproto logs; any other value loads previously saved data. The default is 0.
+
+`--round_candidates` Specifies which rounds are used for training.
+
+`--data_path` Please specify the data path.
+
+`--result_path` Please specify the result path.
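+
+## How to use the trained model
+
+A minimal sketch for loading the pickled parameters and predicting the game reward for one 19-dimensional feature vector (the layout produced by `utils.to_feature`). The pickle path is an assumption; point it at wherever `save_params` wrote the parameters in your run.
+
+```
+import pickle
+
+import jax.numpy as jnp
+
+from train_helper import net
+
+with open("result/params.pickle", "rb") as f:  # assumed save location
+    params = pickle.load(f)
+
+# 19 features: four scores, starting seat's wind one-hot, dealer one-hot,
+# remaining dealer turns per seat, and round / honba / riichi sticks.
+x = jnp.array([0.25] * 4 + [1, 0, 0, 0] + [1, 0, 0, 0] + [1, 2, 2, 2] + [0.0, 0.0, 0.0])
+print(net(x, params))  # four predicted game rewards (scaled by 1/100)
+```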
+ + + + + + + + + diff --git a/workspace/suphnx-reward-shaping/tests/resources/2022060100gm-00a9-0000-3e8b8aaf.json b/workspace/suphx-reward-shaping/tests/resources/2022060100gm-00a9-0000-3e8b8aaf.json similarity index 100% rename from workspace/suphnx-reward-shaping/tests/resources/2022060100gm-00a9-0000-3e8b8aaf.json rename to workspace/suphx-reward-shaping/tests/resources/2022060100gm-00a9-0000-3e8b8aaf.json diff --git a/workspace/suphnx-reward-shaping/tests/resources/2022060100gm-00a9-0000-3ffa4858.json b/workspace/suphx-reward-shaping/tests/resources/2022060100gm-00a9-0000-3ffa4858.json similarity index 100% rename from workspace/suphnx-reward-shaping/tests/resources/2022060100gm-00a9-0000-3ffa4858.json rename to workspace/suphx-reward-shaping/tests/resources/2022060100gm-00a9-0000-3ffa4858.json diff --git a/workspace/suphnx-reward-shaping/tests/resources/2022060100gm-00a9-0000-6db179be.json b/workspace/suphx-reward-shaping/tests/resources/2022060100gm-00a9-0000-6db179be.json similarity index 100% rename from workspace/suphnx-reward-shaping/tests/resources/2022060100gm-00a9-0000-6db179be.json rename to workspace/suphx-reward-shaping/tests/resources/2022060100gm-00a9-0000-6db179be.json diff --git a/workspace/suphnx-reward-shaping/tests/resources/2022060100gm-00a9-0000-7c8869db.json b/workspace/suphx-reward-shaping/tests/resources/2022060100gm-00a9-0000-7c8869db.json similarity index 100% rename from workspace/suphnx-reward-shaping/tests/resources/2022060100gm-00a9-0000-7c8869db.json rename to workspace/suphx-reward-shaping/tests/resources/2022060100gm-00a9-0000-7c8869db.json diff --git a/workspace/suphx-reward-shaping/tests/test_train_helper.py b/workspace/suphx-reward-shaping/tests/test_train_helper.py new file mode 100644 index 00000000..c7260aa4 --- /dev/null +++ b/workspace/suphx-reward-shaping/tests/test_train_helper.py @@ -0,0 +1,66 @@ +import os +import sys + +import jax +import jax.numpy as jnp +import optax + +sys.path.append("../") +from train_helper import initializa_params, loss, net, plot_result, save_params, train +from utils import to_data + +layer_sizes = [3, 4, 5, 4] +feature_size = 19 +seed = jax.random.PRNGKey(42) +save_dir = os.path.join(os.pardir, "trained_model/test_param.pickle") +result_dir = os.path.join(os.pardir, "result") + +mjxprotp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources") + + +def test_initialize_params(): + params = initializa_params(layer_sizes, feature_size, seed) + assert len(params) == 4 + + +def test_train(): + params = initializa_params(layer_sizes, feature_size, seed) + features, scores = to_data(mjxprotp_dir) + optimizer = optax.adam(0.05) + params, train_log, test_log = train( + params, optimizer, features, scores, features, scores, epochs=1, batch_size=1 + ) + assert len(params) == 4 + + +def test_save_model(): + params = initializa_params(layer_sizes, feature_size, seed) + features, scores = to_data(mjxprotp_dir) + optimizer = optax.adam(0.05) + params = train(params, optimizer, features, scores, features, scores, epochs=1, batch_size=1) + save_params(params, save_dir) + + +def test_plot_result(): + params = initializa_params(layer_sizes, feature_size, seed) + features, scores = to_data(mjxprotp_dir) + optimizer = optax.adam(0.05) + params = train(params, optimizer, features, scores, features, scores, epochs=1, batch_size=1) + plot_result(params, features, scores, result_dir) + + +def test_net(): + params = initializa_params(layer_sizes, feature_size, seed) + features, scores = to_data(mjxprotp_dir) + 
print(net(features[0], params), features, params) + + +def test_loss(): + params = initializa_params(layer_sizes, feature_size, seed) + features, scores = to_data(mjxprotp_dir) + print(loss(params, features, scores)) + + +if __name__ == "__main__": + test_net() + test_loss() diff --git a/workspace/suphx-reward-shaping/tests/test_utils.py b/workspace/suphx-reward-shaping/tests/test_utils.py new file mode 100644 index 00000000..e61a98d6 --- /dev/null +++ b/workspace/suphx-reward-shaping/tests/test_utils.py @@ -0,0 +1,49 @@ +import json +import os +import sys + +from google.protobuf import json_format + +sys.path.append("../../../") +import mjxproto + +sys.path.append("../") +from utils import _calc_wind, _preprocess_scores, to_data, to_final_game_reward + +mjxprotp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources") + + +def test_preprocess(): + scores = [0, 100000, 200000, 300000] + print(_preprocess_scores(scores, 1)) + assert _preprocess_scores(scores, 0) == [0, 3, 2, 1] + assert _preprocess_scores(scores, 1) == [1, 0, 3, 2] + assert _preprocess_scores(scores, 2) == [2, 1, 0, 3] + assert _preprocess_scores(scores, 3) == [3, 2, 1, 0] + + +def test_calc_wind(): + assert _calc_wind(1, 0) == 1 + assert _calc_wind(1, 3) == 2 + + +def test_to_final_game_reward(): + _dir = os.path.join(mjxprotp_dir, os.listdir(mjxprotp_dir)[0]) + with open(_dir, "r") as f: + lines = f.readlines() + _dicts = [json.loads(round) for round in lines] + states = [json_format.ParseDict(d, mjxproto.State()) for d in _dicts] + assert to_final_game_reward(states) == [0.9, 0.0, -1.35, 0.45] + + +def test_to_data(): + num_resources = len(os.listdir(mjxprotp_dir)) + features, scores = to_data(mjxprotp_dir) + assert features.shape == (num_resources, 19) + assert scores.shape == (num_resources, 4) + + +if __name__ == "__main__": + test_to_data() + test_to_final_game_reward() + test_calc_wind() diff --git a/workspace/suphx-reward-shaping/train.py b/workspace/suphx-reward-shaping/train.py new file mode 100644 index 00000000..30b44242 --- /dev/null +++ b/workspace/suphx-reward-shaping/train.py @@ -0,0 +1,76 @@ +import argparse +import math +import os +import pickle + +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import optax +from train_helper import initializa_params, plot_result, save_params, train +from utils import to_data + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("lr", help="Enter learning rate", type=float) + parser.add_argument("epochs", help="Enter epochs", type=int) + parser.add_argument("batch_size", help="Enter batch_size", type=int) + parser.add_argument("is_round_one_hot", nargs="?", default="0") + parser.add_argument("--use_saved_data", nargs="?", default="0") + parser.add_argument("--round_candidates", type=int, default=None) + parser.add_argument("--data_path", default="resources/mjxproto") + parser.add_argument("--result_path", default="result") + + args = parser.parse_args() + + mjxprotp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), args.data_path) + + result_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), args.result_path) + + if args.use_saved_data == "0": + X, Y = to_data(mjxprotp_dir, round_candidates=[args.round_candidates]) + if args.round_candidates: + jnp.save(os.path.join(result_dir, "features" + str(args.round_candidates)), X) + jnp.save(os.path.join(result_dir, "labels" + str(args.round_candidates)), Y) + else: + jnp.save(os.path.join(result_dir, "features"), X) + 
jnp.save(os.path.join(result_dir, "labels"), Y) + else: + if args.round_candidates: + X: jnp.ndarray = jnp.load( + os.path.join(result_dir, "features" + str(args.round_candidates) + ".npy") + ) + Y: jnp.ndarray = jnp.load( + os.path.join(result_dir, "labels" + str(args.round_candidates) + ".npy") + ) + else: + X: jnp.ndarray = jnp.load(os.path.join(result_dir, "features.npy")) + Y: jnp.ndarray = jnp.load(os.path.join(result_dir, "labels.npy")) + + train_x = X[: math.floor(len(X) * 0.8)] + train_y = Y[: math.floor(len(X) * 0.8)] + test_x = X[math.floor(len(X) * 0.8) :] + test_y = Y[math.floor(len(X) * 0.8) :] + + layer_size = [32, 32, 4] + seed = jax.random.PRNGKey(42) + + if args.is_round_one_hot == "0": + params = initializa_params(layer_size, 19, seed) + else: + params = initializa_params(layer_size, 26, seed) # featureでroundがone-hotになっている. + + optimizer = optax.adam(learning_rate=args.lr) + + params, train_log, test_log = train( + params, optimizer, train_x, train_y, test_x, test_y, args.epochs, args.batch_size + ) + + plt.plot(train_log, label="train") + plt.plot(test_log, label="val") + plt.legend() + plt.savefig(os.path.join(result_dir, "log/leaning_curve.png")) + + save_params(params, result_dir) + + plot_result(params, X, Y, result_dir, round_candidates=[args.round_candidates]) diff --git a/workspace/suphx-reward-shaping/train_helper.py b/workspace/suphx-reward-shaping/train_helper.py new file mode 100644 index 00000000..437de3a5 --- /dev/null +++ b/workspace/suphx-reward-shaping/train_helper.py @@ -0,0 +1,180 @@ +import os +import pickle +from cProfile import label +from re import I +from typing import Dict, List, Optional + +import jax +import jax.nn as nn +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +import optax +import tensorflow as tf +from jax import grad +from jax import jit as jjit +from jax import numpy as jnp +from jax import value_and_grad, vmap + + +def initializa_params(layer_sizes: List[int], features: int, seed) -> Dict: + """ + 重みを初期化する関数. 線形層を前提としている. 
+ Xavier initializationを採用 + """ + params = {} + + for i, units in enumerate(layer_sizes): + if i == 0: + w = jax.random.uniform( + key=seed, + shape=(features, units), + minval=-np.sqrt(6) / np.sqrt(units), + maxval=np.sqrt(6) / np.sqrt(units), + dtype=jnp.float32, + ) + else: + w = jax.random.uniform( + key=seed, + shape=(layer_sizes[i - 1], units), + minval=-np.sqrt(6) / np.sqrt(units + layer_sizes[i - 1]), + maxval=np.sqrt(6) / np.sqrt(units + layer_sizes[i - 1]), + dtype=jnp.float32, + ) + params["linear" + str(i)] = w + return params + + +def relu(x: jnp.ndarray) -> jnp.ndarray: + return jnp.maximum(0, x) + + +def net(x: jnp.ndarray, params: optax.Params) -> jnp.ndarray: + for i, param in enumerate(params.values()): + x = jnp.dot(x, param) + if i + 1 < len(params.values()): + x = jax.nn.relu(x) + return x + + +def loss(params: optax.Params, batched_x: jnp.ndarray, batched_y: jnp.ndarray) -> jnp.ndarray: + preds = net(batched_x, params) + loss_value = optax.l2_loss(preds, batched_y) + return loss_value.mean() + + +def train_one_step(params: optax.Params, opt_state, batched_dataset, optimizer, epoch): + @jax.jit + def step(params: optax.Params, opt_state, batch, labels): + loss_value, grads = jax.value_and_grad(loss)(params, batch, labels) + updates, opt_state = optimizer.update(grads, opt_state, params) + params = optax.apply_updates(params, updates) + return params, opt_state, loss_value + + cum_loss = 0 + for batched_x, batched_y in batched_dataset: + params, opt_state, loss_value = step( + params, opt_state, batched_x.numpy(), batched_y.numpy(), optimizer + ) + cum_loss += loss_value + if epoch % 100 == 0: # print MSE every 100 epochs + pred = net(batched_x[0].numpy(), params) + print(f"step {epoch}, pred {pred}, actual {batched_y[0]}") + return params, cum_loss / len(batched_dataset) + + +def evaluate_one_step(params: optax.Params, batched_dataset) -> float: + cum_loss = 0 + for batched_x, batched_y in batched_dataset: + cum_loss += loss(params, batched_x.numpy(), batched_y.numpy()) + return cum_loss / len(batched_dataset) + + +def train( + params: optax.Params, + optimizer: optax.GradientTransformation, + X_train: jnp.ndarray, + Y_train: jnp.ndarray, + X_test: jnp.ndarray, + Y_test: jnp.ndarray, + epochs: int, + batch_size: int, + buffer_size=3, +) -> optax.Params: + """ + 学習用の関数. 線形層を前提としており, バッチ処理やシャッフルのためにtensorflowを使っている. 
+ """ + dataset_train = tf.data.Dataset.from_tensor_slices((X_train, Y_train)) + batched_dataset_train = dataset_train.shuffle(buffer_size=buffer_size).batch( + batch_size, drop_remainder=True + ) + dataset_test = tf.data.Dataset.from_tensor_slices((X_test, Y_test)) + batched_dataset_test = dataset_test.batch(batch_size, drop_remainder=True) + opt_state = optimizer.init(params) + + train_log, test_log = [], [] + + @jax.jit + def step(params, opt_state, batch, labels): + loss_value, grads = jax.value_and_grad(loss)(params, batch, labels) + updates, opt_state = optimizer.update(grads, opt_state, params) + params = optax.apply_updates(params, updates) + return params, opt_state, loss_value + + for i in range(epochs): + cum_loss = 0 + for batched_x, batched_y in batched_dataset_train: + params, opt_state, loss_value = step( + params, opt_state, batched_x.numpy(), batched_y.numpy() + ) + cum_loss += loss_value + if i % 100 == 0: # print MSE every 100 epochs + pred = net(batched_x[0].numpy(), params) + print(f"step {i}, loss: {loss_value}, pred {pred}, actual {batched_y[0]}") + mean_train_loss = cum_loss / len(batched_dataset_train) + + mean_test_loss = evaluate_one_step(params, batched_dataset_test) + + # record mean of train loss and test loss per epoch + train_log.append(float(np.array(mean_train_loss).item(0))) + test_log.append(float(np.array(mean_test_loss).item(0))) + return params, train_log, test_log + + +def save_params(params: optax.Params, save_dir): + with open(save_dir + "params.pickle", "wb") as f: + pickle.dump(params, f) + + +def plot_result( + params: optax.Params, X, Y, result_dir, is_round_one_hot=False, round_candidates=None +): + fig = plt.figure(figsize=(10, 5)) + axes = fig.subplots(1, 2) + if not round_candidates: + round_candidates = [i for i in range(8)] + for i in round_candidates: # 通常の局数分 + log_score = [] + log_pred = [] + for j in range(60): + x = jnp.array(_create_data_for_plot(j * 1000, i, is_round_one_hot)) + pred = net(x, params) + log_score.append(j * 1000) + log_pred.append(pred * 100) + axes[0].plot(log_score, log_pred, label="round_" + str(i)) + axes[1].plot(log_score, log_pred, ".", label="round_" + str(i)) + plt.legend() + save_dir = os.path.join(result_dir, "prediction_at_round" + str(i) + ".png") + plt.savefig(save_dir) + + +def _create_data_for_plot(score, round, is_round_one_hot) -> List: + scores = [score / 100000] + [(100000 - score) / 300000] * 3 + wind = [1, 0, 0, 0] + oya = [1, 0, 0, 0] + if is_round_one_hot: + rounds = [0] * 8 + rounds[round] = 1 + return scores + wind + oya + rounds + [0, 0] + else: + return scores + wind + oya + [round / 7, 0, 0] diff --git a/workspace/suphx-reward-shaping/utils.py b/workspace/suphx-reward-shaping/utils.py new file mode 100644 index 00000000..9c51c644 --- /dev/null +++ b/workspace/suphx-reward-shaping/utils.py @@ -0,0 +1,143 @@ +import json +import os +import random +import sys +from typing import Dict, Iterator, List, Optional, Tuple + +import jax +import jax.numpy as jnp +import numpy as np +from google.protobuf import json_format + +sys.path.append("../../") +sys.path.append("../../../") +import mjxproto + +game_rewards = [90, 45, 0, -135] + + +def to_data( + mjxprotp_dir: str, round_candidates: Optional[List[int]] = None, model=None, use_model=False +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + jsonが入っているディレクトリを引数としてjax.numpyのデータセットを作る. 
+    """
+    features: List = []
+    scores: List = []
+    for _json in os.listdir(mjxprotp_dir):
+        _json = os.path.join(mjxprotp_dir, _json)
+        assert ".json" in _json
+        with open(_json, "r") as f:
+            lines = f.readlines()
+        _dicts = [json.loads(round) for round in lines]
+        states = [json_format.ParseDict(d, mjxproto.State()) for d in _dicts]
+        feature = to_feature(states, round_candidates=round_candidates)
+        features.append(feature)
+        if use_model:
+            scores.append(model(jnp.array(feature)))
+        else:
+            scores.append(to_final_game_reward(states))
+    features_array: jnp.ndarray = jnp.array(features)
+    scores_array: jnp.ndarray = jnp.array(scores)
+    return features_array, scores_array
+
+
+def _select_one_round(
+    states: List[mjxproto.State], candidates: Optional[List[int]] = None
+) -> mjxproto.State:
+    """
+    データセットに本質的で無い相関が生まれることを防ぐために一半荘につき1ペアのみを使う.
+    """
+    if candidates:
+        if min(candidates) > len(states) - 1:  # 候補と対応する局がない場合, 一番近いものを返す.
+            return states[len(states) - 1]
+        idx = random.choice(candidates)
+        return states[idx]
+    else:
+        idx: int = random.randint(0, len(states) - 1)
+        return states[idx]
+
+
+def _calc_curr_pos(init_pos: int, round: int) -> int:
+    pos = (-init_pos + round) % 4
+    assert 0 <= pos <= 3
+    return pos
+
+
+def _calc_wind(init_pos: int, round: int) -> int:
+    pos = (-init_pos + round) % 4
+    if pos == 1:
+        return 3
+    if pos == 3:
+        return 1
+    return pos
+
+
+def _to_one_hot(total_num: int, idx: int) -> List[int]:
+    _l = [0] * total_num
+    _l[idx] = 1
+    return _l
+
+
+def _clip_round(round: int, lim=7) -> int:
+    """
+    天鳳では最長西4局まで行われるが, 南4局以降はサドンデスなので同一視する.
+    """
+    if round < 7:
+        return round
+    else:
+        return 7
+
+
+def _preprocess_scores(scores, target: int) -> List:
+    """
+    局終了時の点数を100000で割って自家, 下家, 対面, 上家の順に並び替える.
+    """
+    _self: int = scores[target] / 100000
+    _left: int = scores[target - 1] / 100000
+    _front: int = scores[target - 2] / 100000
+    _right: int = scores[target - 3] / 100000
+    return [_self, _left, _front, _right]
+
+
+def _remaining_oya(round: int):  # 局終了時の残りの親の数
+    return [2 - (round // 4 + ((round % 4) >= i)) for i in range(4)]
+
+
+def to_feature(
+    states: List[mjxproto.State],
+    is_round_one_hot=False,
+    round_candidates: Optional[List[int]] = None,
+) -> List:
+    """
+    特徴量 = [4playerの点数, 起家の風:one-hot, 親:one-hot, 残りの親の数, 局, 本場, 詰み棒]
+    """
+    state = _select_one_round(states, candidates=round_candidates)
+    scores: List = [i / 100000 for i in state.round_terminal.final_score.tens]
+    honba: int = state.round_terminal.final_score.honba
+    tsumibo: int = state.round_terminal.final_score.riichi
+    round: int = _clip_round(state.round_terminal.final_score.round)
+    wind: List[int] = _to_one_hot(4, _calc_wind(0, round))  # 起家の風のみを入力
+    oya: List[int] = _to_one_hot(4, _calc_curr_pos(0, round))
+    remaining_oya = _remaining_oya(round)
+    if is_round_one_hot:
+        one_hot_round: List[int] = _to_one_hot(8, round)
+        feature = (
+            scores + wind + oya + remaining_oya + one_hot_round + [honba / 4, tsumibo / 4]
+        )  # len(feature) = 26
+    else:
+        feature = (
+            scores + wind + oya + remaining_oya + [round / 7, honba / 4, tsumibo / 4]
+        )  # len(feature) = 19
+    return feature
+
+
+def to_final_game_reward(states: List[mjxproto.State]) -> List:
+    """
+    順位点. 起家から順番に. 4次元.
+    """
+    final_state = states[-1]
+    final_scores = final_state.round_terminal.final_score.tens
+    sorted_scores = sorted(final_scores, reverse=True)
+    ranks = [sorted_scores.index(final_scores[i]) for i in range(4)]
+    return [game_rewards[i] / 100 for i in ranks]
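For reference, a standalone worked example of the label construction in `to_final_game_reward`: the final scores are ranked and each seat's rank is mapped to the fixed game rewards `[90, 45, 0, -135]`, divided by 100. This sketch re-implements the same mapping for illustration only; the sample scores are invented.

```
game_rewards = [90, 45, 0, -135]  # same constant as utils.py above


def final_game_reward(final_scores):
    # Rank each seat by its final score; ties share the better rank via list.index.
    sorted_scores = sorted(final_scores, reverse=True)
    ranks = [sorted_scores.index(s) for s in final_scores]
    return [game_rewards[r] / 100 for r in ranks]


# Seat 0 finishes 1st, seat 1 3rd, seat 2 4th, seat 3 2nd (invented scores).
print(final_game_reward([40000, 20000, 5000, 35000]))  # [0.9, 0.0, -1.35, 0.45]
```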