-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Train reward-shaping model with large data (#1134)
* implement train.py * fix feature * fix * fix * fix * add readme * fix * fix features * fix * fix * fix typo * fix
- Loading branch information
Showing
17 changed files
with
549 additions
and
282 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
40 changes: 0 additions & 40 deletions
40
workspace/suphnx-reward-shaping/tests/test_train_helper.py
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
## Suphnx-like reward shaping | ||
|
||
## How to train the model | ||
|
||
Prepare the directories for data and results under this directory. After that, we can train the model through the CLI.
|
||
``` | ||
$python train.py 0.001 10 16 --use_saved_data 0 --data_path resources/mjxproto --result_path result
``` | ||
|
||
Here is the information about the arguments.
|
||
The first three positional arguments are the learning rate, the number of epochs, and the batch size, respectively.
|
||
`--use_saved_data` Set this to 0 to preprocess the data from scratch; any non-zero value reuses previously saved data instead. The default is 0.
|
||
`--round_candidates` Specifies which rounds to use for training.
|
||
`--data_path` Please specify the data path. | ||
|
||
`--result_path` Please specify the result path. | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import os | ||
import sys | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import optax | ||
|
||
sys.path.append("../") | ||
from train_helper import initializa_params, loss, net, plot_result, save_params, train | ||
from utils import to_data | ||
|
||
layer_sizes = [3, 4, 5, 4] | ||
feature_size = 19 | ||
seed = jax.random.PRNGKey(42) | ||
save_dir = os.path.join(os.pardir, "trained_model/test_param.pickle") | ||
result_dir = os.path.join(os.pardir, "result") | ||
|
||
mjxprotp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "resources") | ||
|
||
|
||
def test_initialize_params(): | ||
params = initializa_params(layer_sizes, feature_size, seed) | ||
assert len(params) == 4 | ||
|
||
|
||
def test_train(): | ||
params = initializa_params(layer_sizes, feature_size, seed) | ||
features, scores = to_data(mjxprotp_dir) | ||
optimizer = optax.adam(0.05) | ||
params, train_log, test_log = train( | ||
params, optimizer, features, scores, features, scores, epochs=1, batch_size=1 | ||
) | ||
assert len(params) == 4 | ||
|
||
|
||
def test_save_model(): | ||
params = initializa_params(layer_sizes, feature_size, seed) | ||
features, scores = to_data(mjxprotp_dir) | ||
optimizer = optax.adam(0.05) | ||
params = train(params, optimizer, features, scores, features, scores, epochs=1, batch_size=1) | ||
save_params(params, save_dir) | ||
|
||
|
||
def test_plot_result(): | ||
params = initializa_params(layer_sizes, feature_size, seed) | ||
features, scores = to_data(mjxprotp_dir) | ||
optimizer = optax.adam(0.05) | ||
params = train(params, optimizer, features, scores, features, scores, epochs=1, batch_size=1) | ||
plot_result(params, features, scores, result_dir) | ||
|
||
|
||
def test_net(): | ||
params = initializa_params(layer_sizes, feature_size, seed) | ||
features, scores = to_data(mjxprotp_dir) | ||
print(net(features[0], params), features, params) | ||
|
||
|
||
def test_loss(): | ||
params = initializa_params(layer_sizes, feature_size, seed) | ||
features, scores = to_data(mjxprotp_dir) | ||
print(loss(params, features, scores)) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_net() | ||
test_loss() |
Oops, something went wrong.