Skip to content

Commit

Permalink
messy updates for SR1+SR3
Browse files Browse the repository at this point in the history
  • Loading branch information
Makayla Trask committed Mar 15, 2024
1 parent cd8eae2 commit f6355ad
Showing 1 changed file with 48 additions and 13 deletions.
61 changes: 48 additions & 13 deletions flamedisx/non_asymptotic_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,45 @@ def guess_value_lower_lim(self, guess_dict, signal_source_name,
if value < 0.1:
guess_dict[key] = 0.1
else:
assert transform_params[signal_source_name][0] == key, "Logic does not hold"
if transform_fns_inverse[signal_source_name](self.likelihood.sources[signal_source_name],
for transform_param_name in transform_params.keys():
if signal_source_name in transform_param_name:
assert transform_params[transform_param_name][0] == key, "Logic does not hold"
break

if transform_fns_inverse[transform_param_name](signal_source_name, self.likelihood.sources,
value) < 0.1:
guess_dict[key] = transform_fns[signal_source_name](self.likelihood.sources[signal_source_name],
guess_dict[key] = transform_fns[transform_param_name](signal_source_name, self.likelihood.sources,
0.1)

#print(f'signal source name: {signal_source_name}')
#print(f'guess_dict for {key}: {guess_dict}')
return guess_dict

def __call__(self, mu_test, signal_source_name, guess_dict,
transform_params, transform_fns, transform_fns_inverse):

# To fix the signal RM in the conditional fit
if f'{signal_source_name}_rate_multiplier' in self.likelihood.param_defaults:
fix_dict = {f'{signal_source_name}_rate_multiplier': mu_test}
else:
fix_dict = {transform_params[signal_source_name][0]: transform_fns[signal_source_name](self.likelihood.sources[signal_source_name],
for transform_param_name in transform_params.keys():
if signal_source_name in transform_param_name:
fix_dict = {transform_params[transform_param_name][0]: transform_fns[transform_param_name](signal_source_name, self.likelihood.sources,
mu_test)}

# source_name = signal_source_name.split("_")[0] was above not sure if will want later

guess_dict_nuisance = guess_dict.copy()
if f'{signal_source_name}_rate_multiplier' in self.likelihood.param_defaults:
guess_dict_nuisance.pop(f'{signal_source_name}_rate_multiplier')
else:
guess_dict_nuisance.pop(transform_params[signal_source_name][0])
for transform_param_name in transform_params.keys():
if signal_source_name in transform_param_name:
guess_dict_nuisance.pop(transform_params[transform_param_name][0])
break

guess_dict = self.guess_value_lower_lim(guess_dict, signal_source_name,
transform_params, transform_fns, transform_fns_inverse)

guess_dict_nuisance = self.guess_value_lower_lim(guess_dict_nuisance, signal_source_name,
transform_params, transform_fns, transform_fns_inverse)

Expand Down Expand Up @@ -317,8 +331,14 @@ def run_routine(self, mus_test=None, save_fits=False,
for background_source in self.background_source_names:
sources[background_source] = self.sources[background_source]
arguments[background_source] = self.arguments[background_source]
sources[signal_source] = self.sources[signal_source]
arguments[signal_source] = self.arguments[signal_source]

# ugly way of dealing with this -- there are more ugly things throughout that will get fixed
# just trying to get a fit first before prettying everything up
for signal_source in self.signal_source_names: # input name i.e. WIMP9
for key in self.sources.keys(): ## actual sources i.e. WIMP9_SR1, WIMP9_SR3
if signal_source in key:
sources[key] = self.sources[key]
arguments[key] = self.arguments[key]

# Create likelihood of TemplateSources
if self.ignore_rms is None:
Expand Down Expand Up @@ -365,15 +385,19 @@ def run_routine(self, mus_test=None, save_fits=False,
if f'{signal_source}_rate_multiplier' in likelihood.param_defaults:
simulate_dict_B.pop(f'{signal_source}_rate_multiplier')
else:
print(f'signal source in if generate b toys: {signal_source}')
simulate_dict_B.pop(self.transform_params[signal_source][0])

return simulate_dict_B, toy_data_B_all, constraint_extra_args_B_all

these_mus_test = mus_test[signal_source]
print(these_mus_test)
# Loop over signal rate multipliers
for mu_test in tqdm(these_mus_test, desc='Scanning over mus'):
# Case where we want observed test statistics
print(f'in run routine - mu_test: {mu_test}')
if observed_data is not None:
print('in if observed data is not none')
self.get_observed_test_stat(observed_test_stats, observed_data,
mu_test, signal_source, likelihood, save_fits=save_fits)
# Case where we want test statistic distributions
Expand Down Expand Up @@ -427,7 +451,8 @@ def sample_data_constraints(self, mu_test, signal_source_name, likelihood):
if f'{signal_source_name}_rate_multiplier' in likelihood.param_defaults:
simulate_dict[f'{signal_source_name}_rate_multiplier'] = mu_test
else:
simulate_dict[self.transform_params[signal_source_name][0]] = self.transform_fns[signal_source_name](likelihood.sources[signal_source_name],
source_name = signal_source_name.split("_")[0]
simulate_dict[self.transform_params[signal_source_name][0]] = self.transform_fns[signal_source_name](source_name, likelihood.sources,
mu_test)

toy_data = likelihood.simulate(**simulate_dict)
Expand Down Expand Up @@ -479,7 +504,8 @@ def toy_test_statistic_dist(self, test_stat_dists_SB, test_stat_dists_B,
if f'{signal_source_name}_rate_multiplier' in likelihood.param_defaults:
guess_dict_B[f'{signal_source_name}_rate_multiplier'] = 0.
else:
guess_dict_B[self.transform_params[signal_source_name][0]] = self.transform_fns[signal_source_name](likelihood.sources[signal_source_name],
source_name = signal_source_name.split("_")[0]
guess_dict_B[self.transform_params[signal_source_name][0]] = self.transform_fns[signal_source_name](source_name, likelihood.sources,
0.)

toy_data_B = self.toy_data_B[toy+(self.toy_batch*self.ntoys)]
Expand Down Expand Up @@ -527,17 +553,26 @@ def get_observed_test_stat(self, observed_test_stats, observed_data,

# Set data
likelihood.set_data(observed_data)

# Create test statistic
test_statistic = self.test_statistic(likelihood)

print(f'mu test in get_obs: {mu_test}')

# Guesses for fit
if f'{signal_source_name}_rate_multiplier' in likelihood.param_defaults:
guess_dict = {f'{signal_source_name}_rate_multiplier': mu_test}
else:
guess_dict = {self.transform_params[signal_source_name][0]: self.transform_fns[signal_source_name](likelihood.sources[signal_source_name],
mu_test)}

for transform_param_name in self.transform_params.keys():
if signal_source_name in transform_param_name:
guess_dict = {self.transform_params[transform_param_name][0]: self.transform_fns[transform_param_name](signal_source_name, likelihood.sources,
mu_test)}
break

for background_source in self.background_source_names:
guess_dict[f'{background_source}_rate_multiplier'] = self.expected_background_counts[background_source]

print(f'guess_dict for observed TS: {guess_dict}')
# Evaluate test statistic
ts_result = test_statistic(mu_test, signal_source_name, guess_dict,
self.transform_params, self.transform_fns, self.transform_fns_inverse)
Expand Down

0 comments on commit f6355ad

Please sign in to comment.