-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
42 lines (32 loc) · 1.54 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from simulator import Simulator
from analysis import state_heatmap, action_heatmap, weight_heatmap, policy_heatmap, success_metrics, plot_trends
from simulation_config import config_
import numpy as np
from itertools import product
def verbose_helper(config, results, n_reps, **kwargs):
    """Compute success metrics when they were requested.

    Args:
        config: Simulation configuration, passed through unchanged.
        results: Output of ``Simulator.run``, passed through unchanged.
        n_reps: Number of repetitions, passed through unchanged.
        **kwargs: Optional ``success_metrics`` entry, a mapping whose items
            are unpacked as keyword arguments into ``success_metrics``.
            Every other keyword is accepted and ignored.

    Returns:
        Whatever ``success_metrics`` returns, or None when no
        ``success_metrics`` entry was supplied.
    """
    # NOTE(review): main() passes test_ratio=..., but it is never forwarded
    # anywhere from here -- confirm whether success_metrics should get it.
    if 'success_metrics' not in kwargs:
        return None
    metric_options = kwargs['success_metrics']
    return success_metrics(config, results, n_reps, **metric_options)
def plot_helper(config, results, n_reps, **kwargs):
    """Run every plotting routine whose key appears in ``kwargs``.

    Recognised keys, processed in this fixed order, each mapped to a
    plotting function from ``analysis``:
    ``state_heatmap``, ``action_heatmap``, ``weight_heatmap``,
    ``policy_heatmap`` and ``trends`` (handled by ``plot_trends``).

    Each selected function is called as
    ``fn(config, results, n_reps, **kwargs[key])``. Unrecognised keys are
    ignored. Returns None.
    """
    dispatch = (
        ('state_heatmap', state_heatmap),
        ('action_heatmap', action_heatmap),
        ('weight_heatmap', weight_heatmap),
        ('policy_heatmap', policy_heatmap),
        ('trends', plot_trends),
    )
    for option_key, plotter in dispatch:
        if option_key in kwargs:
            plotter(config, results, n_reps, **kwargs[option_key])
def grid_search():
    """Placeholder that is not implemented yet; does nothing and returns None.

    NOTE(review): presumably meant to sweep configuration values (the module
    imports ``itertools.product``) -- TODO confirm intended behaviour.
    """
    return None
def main(config):
    """Build a simulator from ``config``, run it, then optionally report and plot.

    Required keys (read with ``[]``):
        ``environment_params``, ``model_params``: passed to ``Simulator``.
        ``n_reps``, ``epochs``, ``seed``, ``thin``: passed to ``Simulator.run``
            as ``reps``, ``steps``, ``seed`` and ``thin``.
        ``verbose``: if truthy, metrics are computed via ``verbose_helper``
            (which also reads ``test_ratio`` and ``verbose_params``).
        ``plot``: if truthy, plots are produced via ``plot_helper``
            (which also reads ``plot_params``).

    Returns:
        The result of ``verbose_helper`` when ``verbose`` is truthy,
        otherwise None.
    """
    simulator = Simulator(config['environment_params'], config['model_params'])
    n_reps = config['n_reps']
    run_options = dict(
        reps=n_reps,
        steps=config['epochs'],
        seed=config['seed'],
        thin=config['thin'],
    )
    results = simulator.run(**run_options)

    # NOTE(review): verbose_helper only forwards kwargs['success_metrics'], so
    # test_ratio has no visible effect; and if verbose_params itself contains
    # 'test_ratio', this call raises TypeError (duplicate keyword) -- confirm.
    metrics = (
        verbose_helper(
            config,
            results,
            n_reps,
            test_ratio=config['test_ratio'],
            **config['verbose_params'],
        )
        if config['verbose']
        else None
    )

    if config['plot']:
        plot_helper(config, results, n_reps, **config['plot_params'])

    return metrics
if __name__ == "__main__":
    # Script entry point: run with the configuration imported from
    # simulation_config.
    main(config=config_)