forked from jiupinjia/rocket-recycling
-
Notifications
You must be signed in to change notification settings - Fork 1
/
example_inference.py
29 lines (23 loc) · 905 Bytes
/
example_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from rocket import Rocket
from policy import ActorCritic
import os
import glob
# Select the compute device once at import time: the first CUDA GPU when
# torch reports one as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def find_latest_checkpoint(ckpt_folder):
    """Return the path of the last ``*.pt`` checkpoint in *ckpt_folder*.

    ``glob.glob`` returns matches in arbitrary, filesystem-dependent order,
    so the candidates are compared by name instead of trusting their position
    in the returned list.

    NOTE(review): this assumes checkpoint file names sort chronologically
    (for example a zero-padded counter) -- confirm against the code that
    writes the checkpoints.

    Args:
        ckpt_folder: Directory to search for ``*.pt`` files.

    Returns:
        The lexicographically greatest matching path, or ``None`` when the
        folder does not exist or contains no ``*.pt`` file.
    """
    candidates = glob.glob(os.path.join(ckpt_folder, '*.pt'))
    return max(candidates, default=None)


if __name__ == '__main__':

    task = 'hover'  # 'hover' or 'landing'
    max_steps = 800
    ckpt_folder = task + '_ckpt'

    env = Rocket(task=task, max_steps=max_steps)
    net = ActorCritic(input_dim=env.state_dims, output_dim=env.action_dims).to(device)

    ckpt_path = find_latest_checkpoint(ckpt_folder)
    if ckpt_path is None:
        # Keep going with the freshly initialised network instead of crashing
        # with an IndexError when no checkpoint has been saved yet.
        print('No checkpoint found in {}; running an untrained policy.'.format(ckpt_folder))
    else:
        # map_location remaps the stored tensors onto the selected device, so
        # a checkpoint written on a GPU machine still loads on a CPU-only one.
        checkpoint = torch.load(ckpt_path, map_location=device)
        net.load_state_dict(checkpoint['model_G_state_dict'])

    state = env.reset()
    for _ in range(max_steps):
        # Only the action is used during inference; the other two values
        # returned by get_action are discarded.
        action, _, _ = net.get_action(state)
        state, _, _, _ = env.step(action)
        env.render(window_name='test')
        # As in the original script, only a crash ends the rollout early; the
        # `done` flag returned by env.step is not consulted.
        if env.already_crash:
            break