likelihood_free_inference.py
import torch
from utils import get_naming_str
from sbi import utils as utils
import multiprocessing as mp
import numpy as np
import pickle
import tqdm
import itertools
from getdist import MCSamples
import sys
sys.path.append('../multifrequency_pipeline')
sys.path.append('../harmonic_ILC_pipeline')
sys.path.append('../needlet_ILC_pipeline')
import multifrequency_data_vecs
import hilc_analytic
import nilc_data_vecs
import sbi_utils
import hyperparam_sweep


def get_prior(inp):
    '''
    ARGUMENTS
    ---------
    inp: Info object containing input parameter specifications

    RETURNS
    -------
    prior on Acomp1, etc. to use for likelihood-free inference
    '''
    num_dim = len(inp.comps)
    mean_tensor = torch.ones(num_dim)
    prior = utils.BoxUniform(low=mean_tensor - torch.tensor(inp.prior_half_widths),
                             high=mean_tensor + torch.tensor(inp.prior_half_widths))
    return prior
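
# Illustrative usage (not part of the pipeline's entry points; assumes `inp` provides
# the `comps` and `prior_half_widths` attributes used above):
#   prior = get_prior(inp)
#   theta = prior.sample((128,))  # tensor of shape (128, len(inp.comps))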


def get_observation(inp, pipeline, env):
    '''
    ARGUMENTS
    ---------
    inp: Info object containing input parameter specifications
    pipeline: str, either 'multifrequency', 'HILC', or 'NILC'
    env: environment object, only needed if pipeline=='NILC'

    RETURNS
    -------
    data_vec: ndarray containing outputs from simulation
        Clpq of shape (Nsims, Ncomps, Ncomps, Nbins) if HILC or NILC
        Clij of shape (Nsims, Nfreqs, Nfreqs, Nbins) if multifrequency
    '''
    sims_for_obs = min(inp.Nsims, 1000)
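    # For harmonic ILC, the observation is built in two stages: frequency-frequency power
    # spectra are simulated first, then propagated into component covariances (Clpq).
    # The multifrequency and NILC pipelines produce their data vectors in a single stage.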
    if pipeline == 'HILC':
        fname = 'Clpq'
        pool = mp.Pool(inp.num_parallel)
        args = [(inp, sim) for sim in range(sims_for_obs)]
        print(f'Running {sims_for_obs} simulations of frequency-frequency power spectra as part of observation vector calculation...', flush=True)
        Clij = list(tqdm.tqdm(pool.imap(hilc_analytic.get_freq_power_spec_star, args), total=sims_for_obs))
        pool.close()
        Clij = np.asarray(Clij, dtype=np.float32)
        pool = mp.Pool(inp.num_parallel)
        inp.Clij_theory = np.mean(Clij, axis=0)
        args = [(inp, Clij[sim]) for sim in range(sims_for_obs)]
        print(f'Running {sims_for_obs} simulations to average together for observation vector...', flush=True)
        Clpq = list(tqdm.tqdm(pool.imap(hilc_analytic.get_data_vecs_star, args), total=sims_for_obs))
        pool.close()
        data_vec = np.asarray(Clpq, dtype=np.float32)[:,:,:,0,:]  # shape (Nsims, Ncomps, Ncomps, Nbins)
    else:
        pool = mp.Pool(inp.num_parallel)
        print(f'Running {sims_for_obs} simulations to average together for observation vector...', flush=True)
        if pipeline == 'multifrequency':
            func = multifrequency_data_vecs.get_data_vectors_star
            args = [(inp, sim) for sim in range(sims_for_obs)]
        elif pipeline == 'NILC':
            func = nilc_data_vecs.get_data_vectors_star
            args = [(inp, env, sim) for sim in range(sims_for_obs)]
        data_vec = list(tqdm.tqdm(pool.imap(func, args), total=sims_for_obs))
        pool.close()
        if pipeline == 'NILC':
            fname = 'Clpq'
            data_vec = np.asarray(data_vec, dtype=np.float32)  # shape (Nsims, Ncomps, Ncomps, Nbins)
        else:
            fname = 'Clij'
            data_vec = np.asarray(data_vec, dtype=np.float32)[:,:,:,0,:]  # shape (Nsims, Nfreqs, Nfreqs, Nbins)

    naming_str = get_naming_str(inp, pipeline)
    pickle.dump(data_vec, open(f'{inp.output_dir}/data_vecs/{fname}_{naming_str}.p', 'wb'), protocol=4)
    print(f'\nsaved {inp.output_dir}/data_vecs/{fname}_{naming_str}.p', flush=True)
    return data_vec
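
# Illustrative call (not from the original pipeline driver): `env` is only used by the
# NILC pipeline, so it can be passed as None for the other pipelines, e.g.
#   obs = get_observation(inp, 'HILC', None)  # shape (Nsims, Ncomps, Ncomps, Nbins)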


def get_posterior(inp, pipeline, env):
    '''
    ARGUMENTS
    ---------
    inp: Info object containing input parameter specifications
    pipeline: str, either 'multifrequency', 'HILC', or 'NILC'
    env: environment object, only needed if pipeline=='NILC'

    RETURNS
    -------
    samples: torch tensor of shape (Nsims, Ncomps) containing Acomp1, etc. posteriors
    '''
    assert pipeline in {'multifrequency', 'HILC', 'NILC'}, "pipeline must be either 'multifrequency', 'HILC', or 'NILC'"
    prior = get_prior(inp)
    try:
        naming_str = get_naming_str(inp, pipeline)
        a_array = pickle.load(open(f'{inp.output_dir}/posteriors/a_array_{naming_str}.p', 'rb'))
        samples = torch.tensor(a_array.T)  # reconstruct samples from the cached array so the return at the end of the function is defined
    except Exception:
        observation_all_sims = get_observation(inp, pipeline, env)
        N = observation_all_sims.shape[1]
        observation_all_sims = np.array([observation_all_sims[:,i,j] for (i,j) in list(itertools.product(range(N), range(N)))])
        observation_all_sims = np.transpose(observation_all_sims, axes=(1,0,2)).reshape((-1, len(observation_all_sims)*inp.Nbins))
        mean_vec = np.mean(observation_all_sims, axis=0)
        std_dev_vec = np.std(observation_all_sims, axis=0)
        observation = np.zeros_like(mean_vec)
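
        # The simulator below standardizes each data vector as (x - mean_vec) / std_dev_vec,
        # so the zero vector defined above plays the role of the observed data, i.e. the mean
        # of the simulated observation vectors in standardized units.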
        def simulator(pars):
            '''
            ARGUMENTS
            ---------
            pars: [Acomp1, Acomp2, etc.] parameters (floats)

            RETURNS
            -------
            data_vec: torch tensor containing outputs from simulation
                Clpq of shape (Ncomps*Ncomps*Nbins,) if HILC or NILC
                Clij of shape (Nfreqs*Nfreqs*Nbins,) if multifrequency
            '''
            if pipeline == 'multifrequency':
                data_vec = multifrequency_data_vecs.get_data_vectors(inp, sim=None, pars=pars)[:,:,0,:]  # shape (Nfreqs, Nfreqs, Nbins)
            elif pipeline == 'HILC':
                Clij = hilc_analytic.get_freq_power_spec(inp, sim=None, pars=pars)  # shape (Nfreqs, Nfreqs, 1+Ncomps, ellmax+1)
                data_vec = hilc_analytic.get_data_vecs(inp, Clij)[:,:,0,:]  # shape (Ncomps, Ncomps, Nbins)
            elif pipeline == 'NILC':
                data_vec = nilc_data_vecs.get_data_vectors(inp, env, sim=None, pars=pars)  # shape (Ncomps, Ncomps, Nbins)
            data_vec = np.array([data_vec[i,j] for (i,j) in list(itertools.product(range(N), range(N)))]).flatten()
            data_vec = torch.tensor((data_vec - mean_vec) / std_dev_vec)
            return data_vec
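
        # Fit the posterior either via a hyperparameter sweep over network/training settings,
        # or via a single-round SNPE run with the fixed settings supplied in `inp`.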
        if inp.tune_hyperparameters:
            samples, mean_stds, error_of_stds = hyperparam_sweep.run_sweep(inp, prior, simulator, observation, pipeline)
            for i, par in enumerate([f'A{comp}' for comp in inp.comps]):
                print(f'mean of {par} posterior standard deviations over top 25% of sweeps: ', mean_stds[i], flush=True)
                print(f'standard deviation of {par} posterior standard deviations ("error of errors") over top 25% of sweeps: ', error_of_stds[i], flush=True)
        else:
            samples = sbi_utils.flexible_single_round_SNPE(inp, prior, simulator, observation,
                                                           learning_rate=inp.learning_rate,
                                                           stop_after_epochs=inp.stop_after_epochs,
                                                           clip_max_norm=inp.clip_max_norm,
                                                           num_transforms=inp.num_transforms,
                                                           hidden_features=inp.hidden_features)
        a_array = np.array(samples, dtype=np.float32).T
        naming_str = get_naming_str(inp, pipeline)
        pickle.dump(a_array, open(f'{inp.output_dir}/posteriors/a_array_{naming_str}.p', 'wb'))
        print(f'\nsaved {inp.output_dir}/posteriors/a_array_{naming_str}.p')

    print('Results from Likelihood-Free Inference', flush=True)
    print('----------------------------------------------', flush=True)
    names = [f'A{comp}' for comp in inp.comps]
    samples_MC = MCSamples(samples=a_array.T, names=names, labels=names)
    for par in names:
        print(samples_MC.getInlineLatex(par, limit=1), flush=True)
    return samples
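

# Minimal usage sketch (illustrative only; the real entry point that builds the Info
# object `inp` and the NILC environment `env` lives elsewhere in this repository):
#   samples = get_posterior(inp, 'multifrequency', None)
#   # `samples` has one column per A{comp} amplitude (see the get_posterior docstring)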