-
Notifications
You must be signed in to change notification settings - Fork 3
/
test_env.py
60 lines (49 loc) · 1.83 KB
/
test_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import argparse
import gymnasium as gym
import numpy as np
from stable_baselines3.common.noise import (
NormalActionNoise,
OrnsteinUhlenbeckActionNoise,
)
import frasa_env
# gymnasium documents register_envs() as a no-op whose purpose is to keep the
# `frasa_env` import from being flagged/removed as unused; the environments are
# presumably registered as a side effect of `import frasa_env` itself.
gym.register_envs(frasa_env)


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of this environment smoke-test script."""
    parser = argparse.ArgumentParser(
        description="Test a frasa_env environment (default: frasa-standup-v0)"
    )
    parser.add_argument("--env", type=str, default="frasa-standup-v0", help="Environment to test")
    parser.add_argument("--random", action="store_true", help="Use random actions instead of zeros")
    # --normal and --orn select alternative noise processes, so passing both is
    # rejected instead of silently using one of them.
    noise_kind = parser.add_mutually_exclusive_group()
    noise_kind.add_argument("--normal", action="store_true", help="Use normal action noise")
    noise_kind.add_argument("--orn", action="store_true", help="Use Ornstein-Uhlenbeck action noise")
    parser.add_argument("--std", type=float, default=0.1, help="Standard deviation for the action noise")
    parser.add_argument("--theta", type=float, default=0.15, help="Theta for the Ornstein-Uhlenbeck noise")
    return parser.parse_args()


def _make_noise(args, action_space):
    """Build the action-noise process selected by ``--normal``/``--orn``.

    Args:
        args: Parsed command-line options (reads ``normal``, ``orn``, ``std``,
            ``theta``).
        action_space: The environment's action space; assumed to be an
            array-shaped space such as a gymnasium ``Box`` (TODO confirm).

    Returns:
        A zero-mean ``NormalActionNoise`` or ``OrnsteinUhlenbeckActionNoise``
        with per-dimension standard deviation ``args.std``, or ``None`` when no
        noise was requested.
    """
    if not (args.normal or args.orn):
        return None
    mean = np.zeros(action_space.shape)
    sigma = args.std * np.ones(action_space.shape)
    if args.normal:
        return NormalActionNoise(mean=mean, sigma=sigma)
    return OrnsteinUhlenbeckActionNoise(mean=mean, sigma=sigma, theta=args.theta)


def main() -> None:
    """Step the selected environment forever, printing a summary per episode.

    Actions are zeros (or uniform random samples with ``--random``), optionally
    perturbed by the selected noise process and clipped to the action-space
    bounds. The environment is rendered every step. When an episode terminates
    or is truncated, its length and undiscounted return are printed, and both
    the environment and the noise process are reset. Stop with Ctrl-C; the
    environment is closed on exit.
    """
    args = _parse_args()
    env = gym.make(args.env)
    noise = _make_noise(args, env.action_space)
    try:
        env.reset()
        step = 0
        returns = 0.0
        while True:
            step += 1
            action = env.action_space.sample()
            if not args.random:
                action = np.zeros_like(action)
            if noise is not None:
                action += noise()
                # Gaussian/OU noise is unbounded; keep the action inside the
                # space (assumes a Box-like space exposing low/high).
                action = np.clip(action, env.action_space.low, env.action_space.high)
            _, reward, terminated, truncated, _ = env.step(action)
            returns += reward
            env.render()
            if terminated or truncated:
                status = "truncated" if truncated else "done"
                print(f"Episode finished ({status}) after {step} steps, returns: {returns}")
                step = 0
                returns = 0.0
                env.reset()
                # OU noise is temporally correlated; restart it so its state
                # does not leak from one episode into the next.
                if noise is not None:
                    noise.reset()
    except KeyboardInterrupt:
        pass
    finally:
        env.close()


# Guarded so that importing this module (e.g. pytest collecting test_env.py)
# does not parse argv and enter the infinite loop.
if __name__ == "__main__":
    main()