Skip to content

Commit

Permalink
Towards a slightly but not entirely less hacky combined likelihood.
Browse files Browse the repository at this point in the history
  • Loading branch information
robertsjames committed Feb 28, 2024
1 parent 0e42675 commit ca86332
Showing 1 changed file with 39 additions and 11 deletions.
50 changes: 39 additions & 11 deletions flamedisx/non_asymptotic_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,18 @@ class TestStatistic():
def __init__(self, likelihood):
self.likelihood = likelihood

def __call__(self, mu_test, signal_source_name, guess_dict):
def __call__(self, mu_test, signal_source_name, guess_dict, transform_params):
# To fix the signal RM in the conditional fit
fix_dict = {f'{signal_source_name}_rate_multiplier': mu_test}
if f'{signal_source_name}_rate_multiplier' in self.likelihood.param_defaults:
fix_dict = {f'{signal_source_name}_rate_multiplier': mu_test}
else:
fix_dict = {transform_params[signal_source_name][0]: mu_test}

guess_dict_nuisance = guess_dict.copy()
guess_dict_nuisance.pop(f'{signal_source_name}_rate_multiplier')
if f'{signal_source_name}_rate_multiplier' in self.likelihood.param_defaults:
guess_dict_nuisance.pop(f'{signal_source_name}_rate_multiplier')
else:
guess_dict_nuisance.pop(transform_params[signal_source_name][0])

# Conditional fit
bf_conditional = self.likelihood.bestfit(fix=fix_dict, guess=guess_dict_nuisance, suppress_warnings=True)
Expand Down Expand Up @@ -164,6 +170,9 @@ def __init__(
sample_other_constraints: ty.Dict[str, ty.Callable] = None,
rm_bounds: ty.Dict[str, ty.Tuple[float, float]] = None,
log_constraint_fn: ty.Callable = None,
ignore_rms: ty.Tuple[str] = None,
transform_params: ty.Dict[str, ty.Tuple[str, ty.Tuple]] = None,
likelihood_class = fd.LogLikelihood,
ntoys=1000,
batch_size=10000):

Expand Down Expand Up @@ -206,6 +215,13 @@ def log_constraint_fn(**kwargs):
self.sample_other_constraints = sample_other_constraints
self.rm_bounds = rm_bounds

self.ignore_rms = ignore_rms
if transform_params is None:
transform_params = dict()
self.transform_params = transform_params

self.likelihood_class = likelihood_class

def run_routine(self, mus_test=None, save_fits=False,
observed_data=None,
observed_test_stats=None,
Expand Down Expand Up @@ -274,11 +290,23 @@ def run_routine(self, mus_test=None, save_fits=False,
arguments[signal_source] = self.arguments[signal_source]

# Create likelihood of TemplateSources
likelihood = fd.LogLikelihood(sources=sources,
arguments=arguments,
progress=False,
batch_size=self.batch_size,
free_rates=tuple([sname for sname in sources.keys()]))
if self.ignore_rms is None:
ignore_rms = ()
else:
ignore_rms = self.ignore_rms
free_rates = tuple([sname for sname in sources.keys() if sname not in ignore_rms])

kwargs = dict()
for source in sources:
if source in self.transform_params:
kwargs[self.transform_params[source][0]] = self.transform_params[source][1]

likelihood = self.likelihood_class(sources=sources,
arguments=arguments,
progress=False,
batch_size=self.batch_size,
free_rates=free_rates,
**kwargs)

rm_bounds = dict()
if signal_source in self.rm_bounds.keys():
Expand Down Expand Up @@ -396,7 +424,7 @@ def toy_test_statistic_dist(self, test_stat_dists_SB, test_stat_dists_B,
if value < 0.1:
guess_dict_SB[key] = 0.1
# Evaluate test statistic
ts_result_SB = test_statistic_SB(mu_test, signal_source_name, guess_dict_SB)
ts_result_SB = test_statistic_SB(mu_test, signal_source_name, guess_dict_SB, self.transform_params)
# Save test statistic, and possibly fits
ts_values_SB.append(ts_result_SB[0])
if save_fits:
Expand Down Expand Up @@ -424,7 +452,7 @@ def toy_test_statistic_dist(self, test_stat_dists_SB, test_stat_dists_B,
# Create test statistic
test_statistic_B = self.test_statistic(likelihood)
# Evaluate test statistic
ts_result_B = test_statistic_B(mu_test, signal_source_name, guess_dict_B)
ts_result_B = test_statistic_B(mu_test, signal_source_name, guess_dict_B, self.transform_params)
# Save test statistic, and possibly fits
ts_values_B.append(ts_result_B[0])
if save_fits:
Expand Down Expand Up @@ -466,7 +494,7 @@ def get_observed_test_stat(self, observed_test_stats, observed_data,
if value < 0.1:
guess_dict[key] = 0.1
# Evaluate test statistic
ts_result = test_statistic(mu_test, signal_source_name, guess_dict)
ts_result = test_statistic(mu_test, signal_source_name, guess_dict, self.transform_params)

# Add to the test statistic collection
observed_test_stats.add_test_stat(mu_test, ts_result[0])
Expand Down

0 comments on commit ca86332

Please sign in to comment.