-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
102 lines (81 loc) · 4.18 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.monitor import Monitor
from gymnasium.wrappers import TimeLimit, FrameStack
from Simple_Shapes_RL.utils import NRepeat
import torch
import numpy as np
import os
import yaml
import re
from bim_gw.modules.domain_modules import VAE
from bim_gw.modules.domain_modules.simple_shapes import SimpleShapesAttributes
from bim_gw.modules import GlobalWorkspace
from Simple_Shapes_RL.Env import Simple_Env
from Simple_Shapes_RL.gw_wrapper import GWWrapper, NormWrapper
# Policy-network settings: ReLU activations and one entry in `net_arch` that
# gives the `pi` and `vf` branches three 128-unit layers each.
# NOTE(review): not referenced anywhere in the visible part of this script;
# presumably kept in sync with the training script -- confirm before removing.
policy_kwargs = {
    "activation_fn": torch.nn.ReLU,
    "net_arch": [
        {
            "pi": [128, 128, 128],
            "vf": [128, 128, 128],
        },
    ],
}
# For the mode named in the config, the list of observation modes the main
# loop evaluates, in order (the configured mode itself comes first).
MODE = {
    "attr": ["attr"],
    "v": ["v"],
    "gw_attr": ["gw_attr", "gw_v"],
    "gw_v": ["gw_v", "gw_attr"],
}

# Label used for each mode when building result directory and file names.
MODE_PATH = {
    "attr": "attr",
    "v": "v",
    "gw_attr": "CLIPattr",
    "gw_v": "CLIPv",
}
# Working directory at import time; the result files are written beneath it.
current_directory = os.getcwd()

# Export the same directory as SS_PATH -- presumably so `${SS_PATH}`
# placeholders in the YAML config resolve to it (verify against the config).
os.environ["SS_PATH"] = current_directory
# Matches a `${NAME}` placeholder at the start of a string; group 1 is NAME.
path_matcher = re.compile(r'\$\{([^}^{]+)\}')

# Pattern for float scalars. Its second alternative accepts exponent notation
# without a decimal point (e.g. `1e-4`). The pattern text is kept verbatim.
scientific_number_matcher = re.compile(u'''^(?:
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|[-+]?\\.(?:inf|Inf|INF)
|\\.(?:nan|NaN|NAN))$''', re.X
)


def path_constructor(loader, node):
    """Expand a leading ``${VAR}`` placeholder in a YAML scalar.

    Args:
        loader: The YAML loader invoking this constructor (unused).
        node: Scalar node; its ``value`` attribute holds the raw string.

    Returns:
        The string with the leading ``${VAR}`` replaced by the value of the
        environment variable ``VAR``. A string that does not start with a
        placeholder is returned unchanged.

    Raises:
        ValueError: If the referenced environment variable is not set.
    """
    value = node.value
    match = path_matcher.match(value)
    if match is None:
        # Only reachable through an explicit `!path` tag on a plain string;
        # there is nothing to substitute, so hand the value back as-is.
        return value

    env_var = match.group(1)
    try:
        prefix = os.environ[env_var]
    except KeyError:
        # Previously this surfaced as an opaque `None + str` TypeError.
        raise ValueError(
            f"Environment variable '{env_var}' referenced by YAML value "
            f"'{value}' is not set"
        ) from None
    return prefix + value[match.end():]
# Register the handler for `!path` nodes, then the implicit resolvers that
# decide which untagged scalars receive the `!path` and float tags.
yaml.add_constructor('!path', path_constructor)
yaml.add_implicit_resolver('!path', path_matcher)
yaml.add_implicit_resolver(
    u'tag:yaml.org,2002:float',
    scientific_number_matcher,
    first=list(u'-+0123456789.'),  # only scalars starting with these characters are tested
)
# Number of evaluation episodes recorded for each observation mode.
N_EVAL_EPISODES = 1000


def _build_env(gw_model, mode, config):
    """Build the single-environment VecEnv used to evaluate one mode.

    Args:
        gw_model: Dict holding the loaded 'VAE' and 'GW' modules, passed to
            GWWrapper as ``model``.
        mode: Observation mode passed to GWWrapper.
        config: Parsed YAML config; reads 'normalize', 'episode_len' and
            'n_repeats'.

    Returns:
        A DummyVecEnv wrapping one fully wrapped Simple_Env.
    """
    base_env = Simple_Env(render_mode=None)
    base_env = GWWrapper(base_env, model=gw_model, mode=mode)
    base_env = NormWrapper(base_env, norm=config['normalize'])
    base_env = TimeLimit(base_env, max_episode_steps=config['episode_len'])
    base_env = NRepeat(base_env, num_frames=config['n_repeats'])
    base_env = FrameStack(base_env, 4)
    base_env = Monitor(base_env, allow_early_resets=True)
    # A distinct name keeps the closure from depending on a variable that is
    # later rebound to the VecEnv itself.
    return DummyVecEnv([lambda: base_env])


def _run_episodes(model, env, n_episodes=N_EVAL_EPISODES):
    """Roll out ``n_episodes`` episodes and record return and length of each.

    Args:
        model: Loaded policy exposing ``predict(obs)``.
        env: VecEnv holding exactly one sub-environment.
        n_episodes: Number of completed episodes to record.

    Returns:
        Tuple ``(rewards, lengths)`` of 1-D arrays of size ``n_episodes``.
    """
    rewards = np.zeros(n_episodes)
    lengths = np.zeros(n_episodes)

    obs = env.reset()
    episode_reward = 0.0
    episode_len = 0
    finished = 0
    while finished < n_episodes:
        # The VecEnv returns batched values; index 0 is the only sub-env.
        action, _ = model.predict(obs[0])
        obs, reward, done, _ = env.step(np.array([action]))
        episode_len += 1
        # Accumulate as a Python float rather than a 1-element array.
        episode_reward += float(reward[0])
        if done[0]:
            rewards[finished] = episode_reward
            lengths[finished] = episode_len
            finished += 1
            episode_reward = 0.0
            episode_len = 0
            # No explicit reset: DummyVecEnv resets a finished sub-env inside
            # step() and `obs` already holds the next episode's first
            # observation.
    return rewards, lengths


def main():
    """Evaluate a saved PPO checkpoint and save per-episode statistics.

    Reads 'cfg/cfg_test.yaml', loads the VAE and GlobalWorkspace checkpoints
    named there, runs the policy in every mode listed in
    ``MODE[config['mode']]`` and writes reward and episode-length arrays under
    ``results/inference/<label>/`` in the working directory.
    """
    with open('cfg/cfg_test.yaml', encoding="utf-8") as f:
        config = yaml.full_load(f)

    # Prefer the GPU, but do not crash on CPU-only hosts.
    # NOTE(review): the project wrappers may still assume CUDA -- confirm.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    vae = VAE.load_from_checkpoint(config['models_path']['VAE'], strict=False).eval().to(device)
    domains = {'v': vae.eval(), 'attr': SimpleShapesAttributes(32).eval()}
    gw = GlobalWorkspace.load_from_checkpoint(
        config['models_path']['GW'], domain_mods=domains, strict=False
    ).eval().to(device)
    gw_model = {'VAE': vae, 'GW': gw}

    checkpoint = config['checkpoint']
    trained_on = MODE_PATH[config['mode']]

    # The policy and the output directory do not depend on the evaluated
    # mode, so prepare them once instead of on every loop iteration.
    # NOTE(review): absolute, machine-specific path kept as-is; consider
    # moving it into the config.
    model = PPO.load(f"/home/leopold/Documents/Projets/RL/RL_Simple_Shapes/models/CLIP/{checkpoint}/model")
    out_dir = os.path.join(current_directory, "results", "inference", trained_on)
    os.makedirs(out_dir, exist_ok=True)

    for mode in MODE[config['mode']]:
        env = _build_env(gw_model, mode, config)
        try:
            reward_array, len_array = _run_episodes(model, env, N_EVAL_EPISODES)
        finally:
            env.close()

        suffix = f"{MODE_PATH[mode]}_from_{trained_on}_{checkpoint}"
        np.save(os.path.join(out_dir, f"reward_{suffix}"), reward_array)
        np.save(os.path.join(out_dir, f"len_{suffix}"), len_array)


if __name__ == '__main__':
    main()