Changes to allow multiple components to LL
Makayla Trask committed Mar 19, 2024
1 parent a5bc40f commit 1784a08
Showing 2 changed files with 35 additions and 75 deletions.
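For orientation: the diff below relies on a naming convention in which a logical signal source (e.g. WIMP9) is represented by one component source per dataset (e.g. WIMP9_SR1 and WIMP9_SR3), with the dataset read off the suffix of the source name. The sketch below is illustrative only; the helper names and the example source list are not part of flamedisx or of this commit.

```python
# Illustrative sketch (not part of this commit) of the source-naming
# convention used throughout the diff: "<signal>_<dataset>", e.g. "WIMP9_SR1".
# Helper names and the example source list are made up.

def dataset_for_source(source_name: str) -> str:
    """Dataset a component source belongs to, e.g. 'WIMP9_SR1' -> 'SR1'."""
    return source_name.split('_')[-1]

def components_of(signal_source: str, all_sources) -> list:
    """Per-dataset component sources of a logical source, e.g. 'WIMP9'."""
    return [s for s in all_sources if signal_source in s]

sources = ['WIMP9_SR1', 'WIMP9_SR3', 'ER_SR1', 'ER_SR3']
print(dataset_for_source('WIMP9_SR1'))   # SR1
print(components_of('WIMP9', sources))   # ['WIMP9_SR1', 'WIMP9_SR3']
```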
29 changes: 6 additions & 23 deletions flamedisx/likelihood.py
@@ -128,7 +128,6 @@ def __init__(
if isinstance(data, pd.DataFrame) or data is None:
# Only one dataset
data = {DEFAULT_DSETNAME: data}

if not isinstance(list(sources.values())[0], dict):
# Sources only specified for one dataset
assert len(data) == 1, "Specify which sources belong to which dataset"
@@ -272,31 +271,16 @@ def set_data(self,
UserWarning)
for s in self.sources.values():
s.set_data(None)
return

self.dsetnames = ['SR1','SR3']

self.sources_in_dset = dict()
self.sources_in_dset['SR1'] = []
self.sources_in_dset['SR3'] = []
for source in self.sources:
spl = source.split('_')
component_name = spl[-1]
if component_name == 'SR1':
self.dset_for_source[source] = 'SR1'
self.sources_in_dset['SR1'].append(source)
elif component_name == 'SR3':
self.dset_for_source[source] = 'SR3'
self.sources_in_dset['SR3'].append(source)
return

batch_info = np.zeros((len(self.dsetnames), 3), dtype=int)

for sname, source in self.sources.items():
dname = self.dset_for_source[sname]
if dname not in data.keys():
if dname not in data:
warnings.warn(f"Dataset {dname} not provided in set_data")
continue

# Copy ensures annotations don't clobber
source.set_data(deepcopy(data[dname]))

@@ -308,13 +292,12 @@
# Choose sensible default rate multiplier guesses:
# (1) Assume each free source produces just 1 event
for sname in self.sources:
if self.dset_for_source[sname] not in data.keys():
if self.dset_for_source[sname] not in data:
# This dataset is not being updated, skip
continue

rmname = sname + '_rate_multiplier'
if rmname in self.param_names:
n_expected = self.mu(source_name=sname,dataset_name=self.dset_for_source[sname]).numpy()
n_expected = self.mu(source_name=sname).numpy()
assert n_expected >= 0
self.param_defaults[rmname] = (
self.param_defaults[rmname] / n_expected)
@@ -909,4 +892,4 @@ def cov_to_std(cov):
"""
std_errs = np.diag(cov) ** 0.5
corr = cov * np.outer(1 / std_errs, 1 / std_errs)
return std_errs, corr
return std_errs, corr
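One piece of the likelihood.py changes above worth spelling out: the default guess for each free '<source>_rate_multiplier' is divided by the expected event count returned by self.mu(source_name=...), per the "(1) Assume each free source produces just 1 event" comment. A tiny numerical sketch, with made-up numbers:

```python
# Made-up numbers; only the arithmetic of the default-guess rescaling is real.
param_defaults = {'WIMP9_SR1_rate_multiplier': 1.0}
n_expected = 250.0   # hypothetical value of self.mu(source_name='WIMP9_SR1')
assert n_expected >= 0
param_defaults['WIMP9_SR1_rate_multiplier'] /= n_expected
print(param_defaults)   # {'WIMP9_SR1_rate_multiplier': 0.004}
```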
81 changes: 29 additions & 52 deletions flamedisx/non_asymptotic_inference.py
@@ -27,45 +27,31 @@ def guess_value_lower_lim(self, guess_dict, signal_source_name,
if value < 0.1:
guess_dict[key] = 0.1
else:
for transform_param_name in transform_params.keys():
if signal_source_name in transform_param_name:
assert transform_params[transform_param_name][0] == key, "Logic does not hold"
break

if transform_fns_inverse[transform_param_name](signal_source_name, self.likelihood.sources,
assert transform_params[signal_source_name][0] == key, "Logic does not hold"
if transform_fns_inverse[signal_source_name](signal_source_name, self.likelihood.sources,
value) < 0.1:
guess_dict[key] = transform_fns[transform_param_name](signal_source_name, self.likelihood.sources,
guess_dict[key] = transform_fns[signal_source_name](signal_source_name, self.likelihood.sources,
0.1)
#print(f'signal source name: {signal_source_name}')
#print(f'guess_dict for {key}: {guess_dict}')

return guess_dict

def __call__(self, mu_test, signal_source_name, guess_dict,
transform_params, transform_fns, transform_fns_inverse):

# To fix the signal RM in the conditional fit
if f'{signal_source_name}_rate_multiplier' in self.likelihood.param_defaults:
fix_dict = {f'{signal_source_name}_rate_multiplier': mu_test}
else:
for transform_param_name in transform_params.keys():
if signal_source_name in transform_param_name:
fix_dict = {transform_params[transform_param_name][0]: transform_fns[transform_param_name](signal_source_name, self.likelihood.sources,
fix_dict = {transform_params[signal_source_name][0]: transform_fns[signal_source_name](signal_source_name, self.likelihood.sources,
mu_test)}

# source_name = signal_source_name.split("_")[0] was above not sure if will want later

guess_dict_nuisance = guess_dict.copy()
if f'{signal_source_name}_rate_multiplier' in self.likelihood.param_defaults:
guess_dict_nuisance.pop(f'{signal_source_name}_rate_multiplier')
else:
for transform_param_name in transform_params.keys():
if signal_source_name in transform_param_name:
guess_dict_nuisance.pop(transform_params[transform_param_name][0])
break
guess_dict_nuisance.pop(transform_params[signal_source_name][0])

guess_dict = self.guess_value_lower_lim(guess_dict, signal_source_name,
transform_params, transform_fns, transform_fns_inverse)

guess_dict_nuisance = self.guess_value_lower_lim(guess_dict_nuisance, signal_source_name,
transform_params, transform_fns, transform_fns_inverse)
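A rough sketch of the branching in __call__ above, which fixes the signal strength for the conditional fit either directly via '<signal>_rate_multiplier' (when that parameter exists) or via a shared transformed parameter; the shapes of transform_params and transform_fns are inferred from the call sites in this diff, not taken from flamedisx:

```python
# Rough sketch; transform_params[source][0] is assumed to name the shared fit
# parameter and transform_fns[source](name, sources, mu) to map a target
# expectation mu onto a value of that parameter, as the call sites suggest.
def build_conditional_fit_inputs(signal_source_name, mu_test, guess_dict,
                                 param_defaults, transform_params,
                                 transform_fns, sources):
    rm_name = f'{signal_source_name}_rate_multiplier'
    if rm_name in param_defaults:
        # Signal strength is an ordinary rate multiplier: fix it directly.
        fixed_param = rm_name
        fix_dict = {rm_name: mu_test}
    else:
        # Signal strength enters through a shared transformed parameter.
        fixed_param = transform_params[signal_source_name][0]
        fix_dict = {fixed_param: transform_fns[signal_source_name](
            signal_source_name, sources, mu_test)}
    # The fixed parameter must not also appear among the nuisance guesses.
    guess_dict_nuisance = {k: v for k, v in guess_dict.items()
                           if k != fixed_param}
    return fix_dict, guess_dict_nuisance
```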

@@ -331,14 +317,16 @@ def run_routine(self, mus_test=None, save_fits=False,
for background_source in self.background_source_names:
sources[background_source] = self.sources[background_source]
arguments[background_source] = self.arguments[background_source]

# ugly way of dealing with this -- there are more ugly things throughout that will get fixed
# just trying to get a fit first before prettying everything up

temp_signal_source = [signal_source]
for signal_source in self.signal_source_names: # input name i.e. WIMP9
for key in self.sources.keys(): ## actual sources i.e. WIMP9_SR1, WIMP9_SR3
if signal_source in key:
sources[key] = self.sources[key]
arguments[key] = self.arguments[key]
temp_signal_source.append(key)

signal_source = temp_signal_source

# Create likelihood of TemplateSources
if self.ignore_rms is None:
@@ -360,8 +348,8 @@
**kwargs)

rm_bounds = dict()
if signal_source in self.rm_bounds.keys():
rm_bounds[signal_source] = self.rm_bounds[signal_source]
if signal_source[0] in self.rm_bounds.keys():
rm_bounds[signal_source[0]] = self.rm_bounds[signal_source[0]]
for background_source in self.background_source_names:
if background_source in self.rm_bounds.keys():
rm_bounds[background_source] = self.rm_bounds[background_source]
@@ -378,38 +366,34 @@
constraint_extra_args_B_all = []
for i in tqdm(range(self.ntoys), desc='Background-only toys'):
simulate_dict_B, toy_data_B, constraint_extra_args_B = \
self.sample_data_constraints(0., signal_source, likelihood)
self.sample_data_constraints(0., signal_source[1], likelihood)
toy_data_B_all.append(toy_data_B)
constraint_extra_args_B_all.append(constraint_extra_args_B)

if f'{signal_source}_rate_multiplier' in likelihood.param_defaults:
simulate_dict_B.pop(f'{signal_source}_rate_multiplier')
if f'{signal_source[0]}_rate_multiplier' in likelihood.param_defaults:
simulate_dict_B.pop(f'{signal_source[0]}_rate_multiplier')
else:
print(f'signal source in if generate b toys: {signal_source}')
simulate_dict_B.pop(self.transform_params[signal_source][0])
simulate_dict_B.pop(self.transform_params[signal_source[1]][0])

return simulate_dict_B, toy_data_B_all, constraint_extra_args_B_all

these_mus_test = mus_test[signal_source]
print(these_mus_test)
these_mus_test = mus_test[signal_source[0]]
# Loop over signal rate multipliers
for mu_test in tqdm(these_mus_test, desc='Scanning over mus'):
# Case where we want observed test statistics
print(f'in run routine - mu_test: {mu_test}')
if observed_data is not None:
print('in if observed data is not none')
self.get_observed_test_stat(observed_test_stats, observed_data,
mu_test, signal_source, likelihood, save_fits=save_fits)
mu_test, signal_source[1], likelihood, save_fits=save_fits)
# Case where we want test statistic distributions
else:
self.toy_test_statistic_dist(test_stat_dists_SB, test_stat_dists_B,
mu_test, signal_source, likelihood, save_fits=save_fits)
mu_test, signal_source[1], likelihood, save_fits=save_fits)

if observed_data is not None:
observed_test_stats_collection[signal_source] = observed_test_stats
observed_test_stats_collection[signal_source[1]] = observed_test_stats
else:
test_stat_dists_SB_collection[signal_source] = test_stat_dists_SB
test_stat_dists_B_collection[signal_source] = test_stat_dists_B
test_stat_dists_SB_collection[signal_source[1]] = test_stat_dists_SB
test_stat_dists_B_collection[signal_source[1]] = test_stat_dists_B

if observed_data is not None:
return observed_test_stats_collection
@@ -451,8 +435,7 @@ def sample_data_constraints(self, mu_test, signal_source_name, likelihood):
if f'{signal_source_name}_rate_multiplier' in likelihood.param_defaults:
simulate_dict[f'{signal_source_name}_rate_multiplier'] = mu_test
else:
source_name = signal_source_name.split("_")[0]
simulate_dict[self.transform_params[signal_source_name][0]] = self.transform_fns[signal_source_name](source_name, likelihood.sources,
simulate_dict[self.transform_params[signal_source_name][0]] = self.transform_fns[signal_source_name](signal_source_name, likelihood.sources,
mu_test)

toy_data = likelihood.simulate(**simulate_dict)
@@ -504,8 +487,7 @@ def toy_test_statistic_dist(self, test_stat_dists_SB, test_stat_dists_B,
if f'{signal_source_name}_rate_multiplier' in likelihood.param_defaults:
guess_dict_B[f'{signal_source_name}_rate_multiplier'] = 0.
else:
source_name = signal_source_name.split("_")[0]
guess_dict_B[self.transform_params[signal_source_name][0]] = self.transform_fns[signal_source_name](source_name, likelihood.sources,
guess_dict_B[self.transform_params[signal_source_name][0]] = self.transform_fns[signal_source_name](signal_source_name, likelihood.sources,
0.)

toy_data_B = self.toy_data_B[toy+(self.toy_batch*self.ntoys)]
@@ -553,26 +535,21 @@ def get_observed_test_stat(self, observed_test_stats, observed_data,

# Set data
likelihood.set_data(observed_data)

# Create test statistic
test_statistic = self.test_statistic(likelihood)

print(f'mu test in get_obs: {mu_test}')

# Guesses for fit
if f'{signal_source_name}_rate_multiplier' in likelihood.param_defaults:
guess_dict = {f'{signal_source_name}_rate_multiplier': mu_test}
else:
for transform_param_name in self.transform_params.keys():
if signal_source_name in transform_param_name:
guess_dict = {self.transform_params[transform_param_name][0]: self.transform_fns[transform_param_name](signal_source_name, likelihood.sources,
mu_test)}
break
guess_dict = {self.transform_params[transform_param_name][0]: self.transform_fns[transform_param_name](signal_source_name,
likelihood.sources, mu_test)}
break # only want sigma_ratio in there once since it represents a combination of the signal source for each data set

for background_source in self.background_source_names:
guess_dict[f'{background_source}_rate_multiplier'] = self.expected_background_counts[background_source]

print(f'guess_dict for observed TS: {guess_dict}')
# Evaluate test statistic
ts_result = test_statistic(mu_test, signal_source_name, guess_dict,
self.transform_params, self.transform_fns, self.transform_fns_inverse)
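For readers unfamiliar with the transform machinery these changes lean on: transform_params, transform_fns and transform_fns_inverse are never defined in the lines shown, so the structure below is reconstructed from the call sites (and from the sigma_ratio comment in get_observed_test_stat). The shared-parameter name, the linear scaling and the numbers are assumptions for illustration only.

```python
# Assumed structure, reconstructed from the call sites in this diff:
#   transform_params[source][0] -> name of the shared fit parameter
#   transform_fns[source](name, sources, mu) -> parameter value giving expectation mu
#   transform_fns_inverse[source](name, sources, value) -> expectation at that value
# The shared-parameter name and the linear scaling below are made up.
SCALE = {'WIMP9_SR1': 100.0, 'WIMP9_SR3': 150.0}

transform_params = {k: ('WIMP9_sigma_ratio',) for k in SCALE}
transform_fns = {k: (lambda name, sources, mu, k=k: mu / SCALE[k]) for k in SCALE}
transform_fns_inverse = {k: (lambda name, sources, value, k=k: value * SCALE[k])
                         for k in SCALE}

mu_test = 30.0
shared_param = transform_params['WIMP9_SR1'][0]
value = transform_fns['WIMP9_SR1']('WIMP9_SR1', None, mu_test)
round_trip = transform_fns_inverse['WIMP9_SR1']('WIMP9_SR1', None, value)
print({shared_param: value}, round_trip)   # {'WIMP9_sigma_ratio': 0.3} 30.0
```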
