From c243a5144efe11bed94748a2a573a986537c6b43 Mon Sep 17 00:00:00 2001 From: Simon Lund Date: Sun, 19 Sep 2021 23:53:22 +0200 Subject: [PATCH] Last fine tuning --- src/main.py | 31 +++++++++++-------------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/src/main.py b/src/main.py index e857449..c588204 100644 --- a/src/main.py +++ b/src/main.py @@ -14,10 +14,7 @@ from cli import run_cli_cmnds -def main(): - model_name = "blitz5k" - model = neural_net.load(f"../models/{model_name}.pt") - +def test(): # env = discrete_env_with_nn(rf.right, model) # from gym_cartpole_swingup.envs import cartpole_swingup # env = cartpole_swingup.CartPoleSwingUpV1() @@ -29,33 +26,27 @@ def main(): input("continue?") histories = agents.run(ppo, env, 10) - with open("history_ppo.p", "wb") as f: - dill.dump(histories, f) - - evaluation.plot_angles(histories[0], model_name) + evaluation.plot_angles(histories[0], "no model") -def test2(): +def main(): model_name = "blitz5k" model = neural_net.load(f"../models/{model_name}.pt") env = neural_net.USUCEnvWithNN.create(model, rf.best, "../discrete-usuc-dataset") - env.reset(1) - history = utils.random_actions(env) - evaluation.plot_angles(history, model_name) - evaluation.plot_reward_angle(history) + ppo = agents.create("ppo", env) + agents.train(ppo, total_timesteps=10000) + agents.save(ppo, "../agents/ppo") -def analysis(): - # load history for analysis - with open("history_ppo.p", "rb") as f: - history = dill.load(f) + input("continue?") + histories = agents.run(ppo, env, 10) - evaluation.plot_reward_angle(history) + evaluation.plot_angles(histories[0], model_name) + evaluation.plot_reward_angle(histories[0]) if __name__ == "__main__": - test2() + # test() # main() - # analysis() run_cli_cmnds()