Skip to content

Commit

Permalink
adding self, fixing plotting utilities
Browse files Browse the repository at this point in the history
  • Loading branch information
beckynevin committed Jan 29, 2024
1 parent 5145299 commit eac8bdf
Showing 1 changed file with 42 additions and 28 deletions.
70 changes: 42 additions & 28 deletions src/scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def predict(input, model):
"""
return 0

def simulator()


class Display:
Expand Down Expand Up @@ -87,7 +86,7 @@ def mackelab_corner_plot(self,
figsize=(5, 5)
)
axes[0, 1].plot([truth_list[1]], [truth_list[0]], marker="o", color="red")
def improved_corner_plot(self, posterior, params):
def improved_corner_plot(self, posterior):
"""
Improved corner plot
"""
Expand Down Expand Up @@ -136,7 +135,11 @@ def generate_sbc_samples(self,
)
return thetas, ys, ranks, dap_samples

def sbc_statistics(ranks, thetas, dap_samples, num_posterior_samples):
def sbc_statistics(self,
ranks,
thetas,
dap_samples,
num_posterior_samples):
'''
The KS p-values are vanishingly small here, so we can reject the null hypothesis (that the marginal rank distributions are equivalent to a uniform distribution). The inference clearly went wrong.
Expand All @@ -148,7 +151,8 @@ def sbc_statistics(ranks, thetas, dap_samples, num_posterior_samples):
ranks, thetas, dap_samples, num_posterior_samples=num_posterior_samples
)
return check_stats
def plot_1d_ranks(ranks,
def plot_1d_ranks(self,
ranks,
num_posterior_samples,
labels_list,
colorlist,
Expand Down Expand Up @@ -185,7 +189,8 @@ def plot_1d_ranks(ranks,
if save:
plt.savefig(path+'sbc_ranks.pdf')

def plot_cdf_1d_ranks(ranks,
def plot_cdf_1d_ranks(self,
ranks,
num_posterior_samples,
labels_list,
colorlist,
Expand Down Expand Up @@ -214,23 +219,25 @@ def plot_cdf_1d_ranks(ranks,
if save:
plt.savefig(path+'sbc_ranks_cdf.pdf')

def calculate_coverage_fraction(posterior,
truth_array,
x_observed,
def calculate_coverage_fraction(self,
posterior,
thetas,
ys,
percentile_list,
samples_per_inference = 1000):
samples_per_inference=1_000):
"""
posterior --> the trained posterior
x_observed --> the data used for inference
truth_array --> true parameter values
thetas --> true parameter values
ys --> the "observed" data used for inference
"""
# this holds all posterior samples for each inference run
all_samples = np.empty((len(x_observed), samples_per_inference, np.shape(truth_array)[1]))
all_samples = np.empty((len(ys), samples_per_inference, np.shape(thetas)[1]))
count_array = []
# make this for loop into a progress bar:
for i in tqdm(range(len(x_observed)), desc='Processing observations', unit='obs'):
for i in tqdm(range(len(ys)), desc='Processing observations', unit='obs'):
# sample from the trained posterior n_sample times for each observation
samples = posterior.sample(sample_shape=(samples_per_inference,), x=x_observed[i]).cpu()
samples = posterior.sample(sample_shape=(samples_per_inference,), x=ys[i]).cpu()

'''
# plot posterior samples
Expand Down Expand Up @@ -258,17 +265,18 @@ def calculate_coverage_fraction(posterior,
# this is asking if the true parameter value is contained between the
# upper and lower confidence intervals
# checks separately for each side of the 50th percentile
count = np.logical_and(confidence_u - truth_array.T[:,i] > 0, truth_array.T[:,i] - confidence_l > 0)
count = np.logical_and(confidence_u - thetas.T[:,i] > 0, thetas.T[:,i] - confidence_l > 0)
count_vector.append(count)
# each time the above is > 0, adds a count
count_array.append(count_vector)
count_sum_array = np.sum(count_array, axis=0)
frac_lens_within_vol = np.array(count_sum_array)
return all_samples, np.array(frac_lens_within_vol)/len(x_observed)
return all_samples, np.array(frac_lens_within_vol)/len(ys)



def plot_coverage_fraction(posterior,
def plot_coverage_fraction(self,
posterior,
thetas,
ys,
samples_per_inference,
Expand Down Expand Up @@ -337,7 +345,10 @@ def run_all_sbc(self,
colorlist,
num_sbc_runs=1_000,
num_posterior_samples=1_000,
params):
samples_per_inference=1_000,
plot=True,
save=False,
path='../plots/'):
"""
Runs and displays mackelab's SBC (simulation-based calibration)
Expand All @@ -351,23 +362,26 @@ def run_all_sbc(self,
num_sbc_runs,
num_posterior_samples)

stats = self.sbc_statistics(ranks, thetas, dap_samples, num_posterior_samples)
stats = self.sbc_statistics(ranks,
thetas,
dap_samples,
num_posterior_samples)
print(stats)
self.plot_1d_ranks(ranks,
num_posterior_samples,
labels_list,
colorlist,
plot=False,
save=True,
path='../../plots/')
plot=plot,
save=save,
path=path)

self.plot_cdf_1d_ranks(ranks,
num_posterior_samples,
labels_list,
colorlist,
plot=False,
save=True,
path='../../plots/')
plot=plot,
save=save,
path=path)

self.plot_coverage_fraction(posterior,
thetas,
Expand All @@ -376,9 +390,9 @@ def run_all_sbc(self,
labels_list,
colorlist,
n_percentile_steps=21,
plot=False,
save=True,
path='plots/')
plot=plot,
save=save,
path=path)



Expand Down

0 comments on commit eac8bdf

Please sign in to comment.