Skip to content

Commit

Permalink
Add discovery stuff (bit rough for now).
Browse files Browse the repository at this point in the history
  • Loading branch information
robertsjames committed Jan 10, 2024
1 parent 70125f7 commit a12beeb
Showing 1 changed file with 40 additions and 5 deletions.
45 changes: 40 additions & 5 deletions flamedisx/non_asymptotic_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,8 @@ def run_routine(self, mus_test=None, save_fits=False,
observed_test_stats=None,
generate_B_toys=False,
simulate_dict_B=None, toy_data_B=None, constraint_extra_args_B=None,
toy_batch=0):
toy_batch=0,
discovery=False):
"""If observed_data is passed, evaluate observed test statistics. Otherwise,
obtain test statistic distributions (for both S+B and B-only).
Expand Down Expand Up @@ -315,7 +316,8 @@ def run_routine(self, mus_test=None, save_fits=False,
# Case where we want test statistic distributions
else:
self.toy_test_statistic_dist(test_stat_dists_SB, test_stat_dists_B,
mu_test, signal_source, likelihood, save_fits=save_fits)
mu_test, signal_source, likelihood,
save_fits=save_fits, discovery=discovery)

if observed_data is not None:
observed_test_stats_collection[signal_source] = observed_test_stats
Expand Down Expand Up @@ -366,7 +368,8 @@ def sample_data_constraints(self, mu_test, signal_source_name, likelihood):
return simulate_dict, toy_data, constraint_extra_args

def toy_test_statistic_dist(self, test_stat_dists_SB, test_stat_dists_B,
mu_test, signal_source_name, likelihood, save_fits=False):
mu_test, signal_source_name, likelihood,
save_fits=False, discovery=False):
"""Internal function to get test statistic distribution.
"""
ts_values_SB = []
Expand Down Expand Up @@ -396,7 +399,10 @@ def toy_test_statistic_dist(self, test_stat_dists_SB, test_stat_dists_B,
if value < 0.1:
guess_dict_SB[key] = 0.1
# Evaluate test statistic
ts_result_SB = test_statistic_SB(mu_test, signal_source_name, guess_dict_SB)
if discovery:
ts_result_SB = test_statistic_SB(0., signal_source_name, guess_dict_SB)
else:
ts_result_SB = test_statistic_SB(mu_test, signal_source_name, guess_dict_SB)
# Save test statistic, and possibly fits
ts_values_SB.append(ts_result_SB[0])
if save_fits:
Expand Down Expand Up @@ -424,7 +430,10 @@ def toy_test_statistic_dist(self, test_stat_dists_SB, test_stat_dists_B,
# Create test statistic
test_statistic_B = self.test_statistic(likelihood)
# Evaluate test statistic
ts_result_B = test_statistic_B(mu_test, signal_source_name, guess_dict_B)
if discovery:
ts_result_B = test_statistic_B(0., signal_source_name, guess_dict_B)
else:
ts_result_B = test_statistic_B(mu_test, signal_source_name, guess_dict_B)
# Save test statistic, and possibly fits
ts_values_B.append(ts_result_B[0])
if save_fits:
Expand Down Expand Up @@ -667,3 +676,29 @@ def get_bands(self, conf_level=0.1, quantiles=[0, 1, -1, 2, -2],
bands[signal_source] = these_bands

return bands

def get_bands_discovery(self, quantiles=[0, 1, -1]):
"""
"""
bands = dict()

# Loop over signal sources
for signal_source in self.signal_source_names:
# Get test statistic distribitions
test_stat_dists_SB = self.test_stat_dists_SB[signal_source]
test_stat_dists_B = self.test_stat_dists_B[signal_source]

assert len(test_stat_dists_SB.ts_dists.keys()) == 1, 'Currently only support a single signal strength'

these_p_vals = (100. - stats.percentileofscore(list(test_stat_dists_B.ts_dists.values())[0],
list(test_stat_dists_SB.ts_dists.values())[0],
kind='weak')) / 100.
these_p_vals = these_p_vals[these_p_vals > 0.]
these_disco_sigs = stats.norm.ppf(1. - these_p_vals)

these_bands = dict()
for quantile in quantiles:
these_bands[quantile] = np.quantile(np.sort(these_disco_sigs), stats.norm.cdf(quantile))
bands[signal_source] = these_bands

return bands

0 comments on commit a12beeb

Please sign in to comment.