Skip to content

Commit

Permalink
Train reward-shaping model with large data (#1134)
Browse files Browse the repository at this point in the history
* implement train.py

* fix feature

* fix

* fix

* fix

* add readme

* fix

* fix features

* fix

* fix

* fix typo

* fix
  • Loading branch information
nissymori authored Sep 16, 2022
1 parent 0246b1f commit 1a120f7
Show file tree
Hide file tree
Showing 17 changed files with 549 additions and 282 deletions.
7 changes: 5 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ mjx-py/.vscode/*
dist
.pytest_cache
.cache

.ipynb_checkpoints
workspace/suphnx-reward-shaping/resources/*
workspace/suphx-reward-shaping/resources/*
workspace/suphx-reward-shaping/trained_model/*
workspace/suphx-reward-shaping/result/*
.DS_Store
.vscode/
.python_versions
.python_version
Binary file modified workspace/.DS_Store
Binary file not shown.
40 changes: 0 additions & 40 deletions workspace/suphnx-reward-shaping/tests/test_train_helper.py

This file was deleted.

20 changes: 0 additions & 20 deletions workspace/suphnx-reward-shaping/tests/test_utils.py

This file was deleted.

41 changes: 0 additions & 41 deletions workspace/suphnx-reward-shaping/train.py

This file was deleted.

101 changes: 0 additions & 101 deletions workspace/suphnx-reward-shaping/train_helper.py

This file was deleted.

78 changes: 0 additions & 78 deletions workspace/suphnx-reward-shaping/utils.py

This file was deleted.

30 changes: 30 additions & 0 deletions workspace/suphx-reward-shaping/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
## Suphx-like reward shaping

## How to train the model

Prepare the directories for the data and the results under this directory. After that, we can train the model through the CLI.

```
$ python train.py 0.001 10 16 --use_saved_data 0 --data_path resources/mjxproto --result_path result
```

Here is the information about the arguments.

The first three positional arguments are the learning rate, the number of epochs, and the batch size, respectively.

`--use_saved_data` Set to 0 (the default) to preprocess the data from scratch; any other value reuses previously saved data.

`--round_candidates` Specifies which rounds to use for training.

`--data_path` Please specify the data path.

`--result_path` Please specify the result path.









66 changes: 66 additions & 0 deletions workspace/suphx-reward-shaping/tests/test_train_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os
import sys

import jax
import jax.numpy as jnp
import optax

sys.path.append("../")
from train_helper import initializa_params, loss, net, plot_result, save_params, train
from utils import to_data

layer_sizes = [3, 4, 5, 4]
feature_size = 19
seed = jax.random.PRNGKey(42)
save_dir = os.path.join(os.pardir, "trained_model/test_param.pickle")
result_dir = os.path.join(os.pardir, "result")

mjxprotp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources")


def test_initialize_params():
params = initializa_params(layer_sizes, feature_size, seed)
assert len(params) == 4


def test_train():
params = initializa_params(layer_sizes, feature_size, seed)
features, scores = to_data(mjxprotp_dir)
optimizer = optax.adam(0.05)
params, train_log, test_log = train(
params, optimizer, features, scores, features, scores, epochs=1, batch_size=1
)
assert len(params) == 4


def test_save_model():
params = initializa_params(layer_sizes, feature_size, seed)
features, scores = to_data(mjxprotp_dir)
optimizer = optax.adam(0.05)
params = train(params, optimizer, features, scores, features, scores, epochs=1, batch_size=1)
save_params(params, save_dir)


def test_plot_result():
params = initializa_params(layer_sizes, feature_size, seed)
features, scores = to_data(mjxprotp_dir)
optimizer = optax.adam(0.05)
params = train(params, optimizer, features, scores, features, scores, epochs=1, batch_size=1)
plot_result(params, features, scores, result_dir)


def test_net():
params = initializa_params(layer_sizes, feature_size, seed)
features, scores = to_data(mjxprotp_dir)
print(net(features[0], params), features, params)


def test_loss():
params = initializa_params(layer_sizes, feature_size, seed)
features, scores = to_data(mjxprotp_dir)
print(loss(params, features, scores))


if __name__ == "__main__":
test_net()
test_loss()
Loading

0 comments on commit 1a120f7

Please sign in to comment.