Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

実験結果のまとめとnumpy配列への変換 #1148

Merged
merged 3 commits into from
Oct 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions workspace/suphx-reward-shaping/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@


def evaluate_abs(
params: optax.Params, X, score, batch_size, use_logistic=False, use_clip=False
) -> float: # 前処理する前のスケールでの絶対誤差
params: optax.Params,
X,
score,
batch_size,
use_logistic=False,
use_clip=False,
) -> float: # 前処理する前のスケールでの絶対誤差
dataset = tf.data.Dataset.from_tensor_slices((X, score))
batched_dataset = dataset.batch(batch_size, drop_remainder=True)
cum_loss = 0
Expand All @@ -37,6 +42,9 @@ def eval_abs_loss(meth, _type, result_dir):
use_clip = False
if _type == "_use_clip_":
use_clip = True
if _type == "_no_logistic_after":
_type = "_no_logistic_"
use_clip = True
params_list: List = (
[
jnp.load(
Expand Down Expand Up @@ -75,15 +83,19 @@ def eval_abs_loss(meth, _type, result_dir):
32,
use_logistic=use_logistic,
use_clip=use_clip,
)
) # テストデータでの絶対誤差
abs_losses.append(float(np.array(abs_loss).item(0)))
return abs_losses


if __name__ == "__main__":

train_meth = ["suphx", "TD"]
types = ["_no_logistic_", "_use_logistic_", "_use_clip_"]
types = [
"_no_logistic_",
"_use_logistic_",
"_no_logistic__use_clip_",
"_no_logistic_after",
]

result_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "result")

Expand Down
1 change: 1 addition & 0 deletions workspace/suphx-reward-shaping/evaluate.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Run the evaluation script; evaluate.py is resolved relative to the current working directory.
python evaluate.py
24 changes: 24 additions & 0 deletions workspace/suphx-reward-shaping/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import os

import jax
import jax.numpy as jnp
import numpy as np

"""
各局ごとにモデルが一つ存在する.
モデルのアーキテクチャは以下.
入力次元: 19
Layer1: 19 * 32
Activation: relu
Layer2: 32 * 32
Activation: relu
Layer3: 32 * 4
Clip(0, 1) game rewardを[0, 1]でnormalizeしているため.
"""


def predict(x: np.ndarray, W1: np.ndarray, W2: np.ndarray, W3: np.ndarray):
    """Run the bias-free three-layer MLP forward pass with plain NumPy.

    The two hidden layers use a ReLU activation; the output layer is
    clipped to [0, 1].

    Args:
        x: Input features.
        W1: Weight matrix of the first layer.
        W2: Weight matrix of the second layer.
        W3: Weight matrix of the output layer.

    Returns:
        The clipped output of the last layer.
    """
    hidden = x
    # Hidden layers: linear map followed by ReLU.
    for weight in (W1, W2):
        hidden = np.maximum(0, np.dot(hidden, weight))
    # Output layer: linear map clamped into the [0, 1] range.
    logits = np.dot(hidden, W3)
    return np.clip(logits, a_min=0, a_max=1)
57 changes: 57 additions & 0 deletions workspace/suphx-reward-shaping/tests/test_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
保存したnumpyの重みによる推論がjaxで学習した重みの推論と一致することを確認する.
"""

import os
import sys

import jax.numpy as jnp
import numpy as np

sys.path.append("../")
from inference import predict
from train_helper import net

# Directories holding the exported NumPy weights and the JAX parameters.
# Both are relative to the parent of the current working directory.
numpy_dir = os.path.join(os.pardir, "weights/numpy")
jax_dir = os.path.join(os.pardir, "weights/jax")


def _jax_param_path(round_idx):
    """Return the pickle path of the JAX parameters for one round."""
    return os.path.join(jax_dir, "params_no_logistic_" + str(round_idx) + ".pickle")


def _numpy_weight_path(round_idx, layer_idx):
    """Return the .npy path of one layer's weights for one round."""
    file_name = "weights_no_logistic_TD_" + str(round_idx) + "_layer_" + str(layer_idx) + ".npy"
    return os.path.join(numpy_dir, file_name)


# One JAX parameter object per round (rounds 0-7).
jax_params = [jnp.load(_jax_param_path(round_idx), allow_pickle=True) for round_idx in range(8)]

# numpy_params[round][layer]: three weight arrays for each of the 8 rounds.
numpy_params = [
    [np.load(_numpy_weight_path(round_idx, layer_idx), allow_pickle=True) for layer_idx in range(3)]
    for round_idx in range(8)
]


# The same all-ones input of length 19, once as a JAX array and once as a NumPy array.
x_j: jnp.ndarray = jnp.array([1] * 19)
x_n: np.ndarray = np.array([1] * 19)

# Tolerance used when comparing the two implementations' outputs.
delta = 0.0001


def test_inference():
    """Check that NumPy inference matches the JAX network for every round.

    For each round, the JAX model (``net`` with ``use_clip=True``) and the
    NumPy ``predict`` function are evaluated on the same all-ones input, and
    the first four output components must agree to within ``delta``.
    """
    for jax_param, numpy_param in zip(jax_params, numpy_params):
        out_j = net(x_j, jax_param, use_clip=True)
        out_n = predict(x_n, numpy_param[0], numpy_param[1], numpy_param[2])
        for k in range(4):
            # Compare the absolute difference: a signed difference would let
            # any arbitrarily large negative deviation pass the check.
            assert abs(out_j[k] - out_n[k]) < delta


# Allow running the check directly as a script, without a test runner.
if __name__ == "__main__":
    test_inference()
21 changes: 21 additions & 0 deletions workspace/suphx-reward-shaping/to_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import os

import jax
import jax.numpy as jnp
import numpy as np


def save_as_numpy(param_dir, round, out_dir="numpy"):
    """Export each entry of a pickled parameter mapping to its own .npy file.

    The no-logistic TD model performed best, so its parameters are the ones
    being exported.

    Args:
        param_dir: Path of the pickled parameter file to load.
        round: Round index, used in the output file names.
        out_dir: Directory receiving the .npy files. Defaults to ``"numpy"``
            relative to the current working directory; it is created if it
            does not exist.

    The i-th value of the loaded mapping is written to
    ``<out_dir>/weights_no_logistic_TD_<round>_layer_<i>.npy``.
    """
    # NOTE(review): allow_pickle=True unpickles the file; only use it on trusted parameter files.
    params = jnp.load(param_dir, allow_pickle=True)
    # Saving does not create missing directories, so make sure the target exists.
    os.makedirs(out_dir, exist_ok=True)
    for i, param in enumerate(params.values()):
        file_name = "weights_no_logistic_TD_" + str(round) + "_layer_" + str(i) + ".npy"
        jnp.save(os.path.join(out_dir, file_name), jnp.array(param))


if __name__ == "__main__":
    # Convert the pickled parameters of every round (0-7) stored under
    # <directory of this file>/result/params.
    here = os.path.dirname(os.path.abspath(__file__))
    result_dir = os.path.join(here, "result")
    for round_idx in range(8):
        pickle_name = "params/params_no_logistic_" + str(round_idx) + ".pickle"
        save_as_numpy(os.path.join(result_dir, pickle_name), round_idx)
43 changes: 21 additions & 22 deletions workspace/suphx-reward-shaping/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,27 +194,26 @@ def save_params(params, opt, result_dir):
args = parser.parse_args()
mjxproto_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), args.data_path)
result_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), args.result_path)
if args.train:
for lr in [0.01, 0.001]:
print(
"start_training, round_wise: {}, use_logistic: {}, target_round: {}".format(
args.round_wise, args.use_logistic, args.target_round
)
for lr in [0.01, 0.001]:
print(
"start_training, round_wise: {}, use_logistic: {}, target_round: {}".format(
args.round_wise, args.use_logistic, args.target_round
)
if args.round_wise:
X, Y, scores = set_dataset_round_wise(mjxproto_dir, result_dir, args)
else:
X, Y, scores = set_dataset_whole(mjxproto_dir, result_dir, args)
params, train_log, val_log = run_training(X, Y, scores, args, lr)
save_params(params, args, result_dir)
save_learning_log(train_log, val_log, args, result_dir, lr)
plot_learning_log(train_log, val_log, args, result_dir, lr)
"""
if args.round_wise == 0:
for round in range(8):
for i in range(4):
plot_result(params, i, round, args, result_dir)
else:
)
if args.round_wise:
X, Y, scores = set_dataset_round_wise(mjxproto_dir, result_dir, args)
else:
X, Y, scores = set_dataset_whole(mjxproto_dir, result_dir, args)
params, train_log, val_log = run_training(X, Y, scores, args, lr)
save_params(params, args, result_dir)
save_learning_log(train_log, val_log, args, result_dir, lr)
plot_learning_log(train_log, val_log, args, result_dir, lr)
"""
if args.round_wise == 0:
for round in range(8):
for i in range(4):
plot_result(params, i, args.target_round, args, result_dir)
"""
plot_result(params, i, round, args, result_dir)
else:
for i in range(4):
plot_result(params, i, args.target_round, args, result_dir)
"""
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.