Skip to content

Commit

Permalink
Support for template products (quite fixed for now, will generalise).
Browse files Browse the repository at this point in the history
  • Loading branch information
robertsjames committed Nov 6, 2023
1 parent 9a34a57 commit e5baeae
Showing 1 changed file with 64 additions and 0 deletions.
64 changes: 64 additions & 0 deletions flamedisx/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,67 @@ def simulate(self, n_events, fix_truth=None, full_annotate=False,
return pd.DataFrame(dict(zip(
self.final_dimensions,
self._mh_events_per_bin.get_random(n_events).T)))


@export
class TemplateProductSource(fd.ColumnSource):
    """Source whose differential rate is a product of 2D template values.

    Each template is a 2D histogram over a pair of observables. For every
    event, each template is linearly interpolated (between bin centers) at
    the event's values of that template's two observables, and the
    interpolated values of all templates are multiplied together. The
    result is stored in a randomly named column of ``self.data``.

    Simulation draws each pair of observables independently from its own
    template histogram.
    """

    def __init__(
            self,
            templates=None,
            axis_names=None,
            *args,
            **kwargs):
        """Build interpolators and samplers for every template.

        :param templates: sequence of objects exposing ``histogram`` (a 2D
            array) and ``bin_edges`` (a pair of 1D bin-edge arrays).
        :param axis_names: sequence with one entry per template; each entry
            is a pair of column names for that template's two dimensions.
        :param args: forwarded to the parent class constructor.
        :param kwargs: forwarded to the parent class constructor.
        :raises ValueError: if ``templates`` or ``axis_names`` is missing,
            if their lengths differ, if a template is not 2D, or if an
            entry of ``axis_names`` does not hold exactly two names.
        """
        # Explicit checks instead of assert: asserts vanish under -O, and
        # the None defaults would otherwise fail with an opaque TypeError.
        if templates is None or axis_names is None:
            raise ValueError(
                "TemplateProductSource requires both templates and "
                "axis_names")
        if len(templates) != len(axis_names):
            raise ValueError(
                f"Got {len(templates)} templates but {len(axis_names)} "
                "sets of axis names; these must match")

        self.interp_2d_list = []
        self.final_dimensions_list = []
        self.mh_list = []

        # Get templates, bin_edges, and axis_names
        for template, axis in zip(templates, axis_names):
            this_template = template.histogram
            these_bin_edges = template.bin_edges
            if len(np.shape(this_template)) != 2:
                raise ValueError(
                    "Each template must be a 2D histogram, got shape "
                    f"{np.shape(this_template)}")
            if len(axis) != 2:
                raise ValueError(
                    "Each entry of axis_names must name exactly two "
                    f"dimensions, got {axis!r}")

            centers_dim_1 = 0.5 * (
                these_bin_edges[0][1:] + these_bin_edges[0][:-1])
            centers_dim_2 = 0.5 * (
                these_bin_edges[1][1:] + these_bin_edges[1][:-1])

            # interp2d wants z with shape (len(y), len(x)), hence the
            # transpose of the (dim_1, dim_2)-shaped template.
            # NOTE(review): interp2d is deprecated since SciPy 1.10 and
            # removed in 1.14; migrate (e.g. to RectBivariateSpline with
            # kx=ky=1) once the file-level imports can be updated.
            self.interp_2d_list.append(
                interp2d(centers_dim_1, centers_dim_2,
                         np.transpose(this_template)))
            self.final_dimensions_list.append(axis)
            self.mh_list.append(
                Histdd.from_histogram(this_template,
                                      bin_edges=these_bin_edges))

        # Flatten the per-template pairs; unlike sum(..., ()), this also
        # accepts pairs given as lists.
        self.final_dimensions = tuple(
            dim
            for dims in self.final_dimensions_list
            for dim in dims)

        self.mu = fd.np_to_tf(1.)

        # Generate a random column name to use to store the diff rates
        # of observed events
        self.column = (
            'template_diff_rate_'
            + ''.join(random.choices(string.ascii_lowercase, k=8)))

        super().__init__(*args, **kwargs)

    def _annotate(self):
        """Add columns needed in inference to self.data.

        Writes ``self.data[self.column]``: for each event, the product over
        all templates of the template interpolated at the event's values
        of that template's two dimensions.
        """
        # One float entry per event, multiplied up template by template.
        diff_rate = np.ones(len(self.data))
        for final_dims, interp_2d in zip(self.final_dimensions_list,
                                         self.interp_2d_list):
            # Fetch the two columns once, not once per row.
            values_dim_1 = self.data[final_dims[0]].to_numpy()
            values_dim_2 = self.data[final_dims[1]].to_numpy()
            # interp2d evaluates array inputs on the outer-product grid,
            # so it is called per event and the single value extracted.
            diff_rate *= np.array([
                interp_2d(x, y)[0]
                for x, y in zip(values_dim_1, values_dim_2)])
        self.data[self.column] = diff_rate

    def simulate(self, n_events, fix_truth=None, full_annotate=False,
                 keep_padding=False, **params):
        """Simulate n events.

        Each template's pair of dimensions is sampled independently from
        that template's histogram; the resulting columns are concatenated
        side by side.

        :param n_events: number of events to draw from every template.
        :param fix_truth: not supported; must be falsy.
        :param full_annotate: accepted but not used by this method.
        :param keep_padding: accepted but not used by this method.
        :param params: accepted but not used by this method.
        :return: DataFrame with one column per entry of the per-template
            axis names.
        :raises NotImplementedError: if ``fix_truth`` is truthy.
        """
        if fix_truth:
            raise NotImplementedError(
                "TemplateProductSource does not yet support fix_truth")
        assert isinstance(n_events, (int, float)), \
            f"n_events must be an int or float, not {type(n_events)}"

        sim_columns = [
            pd.DataFrame(dict(zip(final_dims, mh.get_random(n_events).T)))
            for final_dims, mh in zip(self.final_dimensions_list,
                                      self.mh_list)]
        return pd.concat(sim_columns, axis=1)

0 comments on commit e5baeae

Please sign in to comment.