-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpyflyt_evaluate.py
100 lines (82 loc) · 2.88 KB
/
pyflyt_evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from waypoints_flat_wrapper import BSVFlattenWaypointEnv
from gymnasium import spaces
import numpy as np
from stable_baselines3 import PPO
from gymnasium.envs.registration import register
import gymnasium as gym
import argparse
import time
from waypoints_env import Quadx_waypoints_2
def register_the_env():
    """Register the waypoint environment with Gymnasium.

    The environment is published under the id
    ``PyFlyt/QuadXBSV-Waypoints-v1``, is backed by ``Quadx_waypoints_2`` and
    is limited to 200 steps per episode.
    """
    env_spec = {
        "id": 'PyFlyt/QuadXBSV-Waypoints-v1',
        "entry_point": Quadx_waypoints_2,
        "max_episode_steps": 200,
    }
    register(**env_spec)
def parse_args(argv=None):
    """Parse the target waypoint from command-line arguments.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse reads the process command line, which is
            the behavior existing callers rely on. Passing an explicit list
            lets the parser be driven programmatically.

    Returns:
        A NumPy array of shape ``(1, 3)`` holding the three floats given to
        ``--waypoints``.

    Raises:
        SystemExit: Raised by argparse when ``--waypoints`` is missing or does
            not receive exactly three float values.
    """
    parser = argparse.ArgumentParser(description="PyFlyt Waypoint Evaluation")
    parser.add_argument(
        '--waypoints',
        nargs=3,
        type=float,
        required=True,
        help="Waypoint coordinates (x, y, z)",
    )
    args = parser.parse_args(argv)
    # Reshape to a single-row matrix: one waypoint, three coordinates.
    return np.array(args.waypoints).reshape(1, 3)
def evaluate(model_path, waypoints_input, max_steps=100):
    """Roll out a saved PPO policy toward a waypoint and print the outcome.

    The environment is registered, wrapped in ``BSVFlattenWaypointEnv``, given
    a 4-dimensional ``[-1, 1]`` action space and reset with
    ``waypoints_input``. The policy is then queried deterministically, one
    step at a time, until the episode resolves:

    * ``"success"`` is printed as soon as ``info['num_targets_reached']``
      equals 1;
    * ``"failure"`` is printed when the environment reports out-of-bounds, a
      collision, termination or truncation, or when more than ``max_steps``
      steps have passed without any of the above.

    The environment is always closed, including when loading the model or
    stepping the environment raises.

    Args:
        model_path: Path of the saved model handed to ``PPO.load``.
        waypoints_input: Waypoints forwarded to ``env.reset`` as
            ``BSV_waypoints``.
        max_steps: Number of unresolved steps tolerated before the rollout is
            declared a failure. Defaults to 100, the previously hard-coded
            limit.

    Returns:
        A tuple ``(obs_array, reward_array, action_array)`` of NumPy arrays
        with one entry per environment step taken.
    """
    # Waypoint used only to construct the wrapper; the requested target is
    # supplied through reset() below.
    # NOTE(review): presumably reset() replaces this value - confirm in the
    # wrapper implementation.
    init_waypoints = [[0, 1, 1]]
    register_the_env()
    env = BSVFlattenWaypointEnv(
        gym.make(id='PyFlyt/QuadXBSV-Waypoints-v1', flight_mode=-1),
        BSV_waypoints=init_waypoints,
    )

    obs_list = []
    reward_list = []
    action_list = []
    try:
        # Override the wrapper's action space with a symmetric 4-D box.
        # NOTE(review): presumably this matches the space the policy was
        # trained with - confirm against the training script.
        env.action_space = spaces.Box(
            low=np.full(4, -1.0),
            high=np.full(4, 1.0),
            dtype=np.float64,
        )
        model_loaded = PPO.load(model_path, env=env)
        obs, info = env.reset(BSV_waypoints=waypoints_input)

        step = 0
        # Every terminated/truncated step breaks out below, so the loop needs
        # no condition of its own.
        while True:
            action, _states = model_loaded.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, info = env.step(action)
            obs_list.append(obs)
            reward_list.append(reward)
            action_list.append(action)

            # Success is checked first so that reaching the target on the
            # same step the episode ends still counts as a success.
            if info['num_targets_reached'] == 1:
                print("success")
                break
            if info['out_of_bounds'] or info['collision'] or terminated or truncated:
                print("failure")
                break

            step += 1
            if step > max_steps:
                print("failure")
                break
    finally:
        # Release the simulator even if loading or stepping failed.
        env.close()

    obs_array = np.array(obs_list)
    reward_array = np.array(reward_list)
    action_array = np.array(action_list)
    return obs_array, reward_array, action_array
if __name__ == "__main__":
    # Script entry point: evaluate the bundled model against the waypoint
    # supplied on the command line.
    evaluate('src/best_model.zip', parse_args())