-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path ppo_eval.py
29 lines (18 loc) · 1018 Bytes
/
ppo_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# Add the directory two levels above this script's folder to sys.path so the
# package can be imported from a source checkout. A pip-installed package does
# not need this.
import inspect
import os
import sys

# Directory containing this script (main() resolves model paths against it).
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
# NOTE: despite its name, this is the *grandparent* of the script's directory.
parentdir = os.path.dirname(os.path.dirname(currentdir))
# Use the public sys module; `os.sys` only exists as an undocumented side
# effect of os importing sys internally.
sys.path.insert(0, parentdir)
from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.env_util import make_vec_env
from pushGymEnv import pushGymEnv
def main(model_path=None, n_eval_episodes=10):
    """Load a trained PPO model and evaluate it on a rendered pushGymEnv.

    Args:
        model_path: Path to the saved SB3 model archive. When None, defaults
            to ``Results/model1_cont/model1_cont.zip`` relative to this
            script's directory.
        n_eval_episodes: Number of evaluation episodes to run.

    Prints the value returned by ``evaluate_policy`` with
    ``return_episode_rewards=True`` (per-episode rewards and lengths).
    """
    if model_path is None:
        model_path = os.path.join(currentdir, 'Results/model1_cont/model1_cont.zip')

    env = pushGymEnv(renders=True)
    try:
        model = PPO.load(model_path, env=env)
        print(evaluate_policy(model=model, env=env, n_eval_episodes=n_eval_episodes,
                              render=True, return_episode_rewards=True))
    finally:
        # Release the simulator/GUI connection even if loading or evaluation
        # fails. Assumes pushGymEnv follows the gym.Env API (which defines
        # close()), as SB3 requires of envs passed to PPO / evaluate_policy.
        env.close()
# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()