Skip to content

Commit

Permalink
Template interpolation for all dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleAalbers committed Mar 22, 2024
1 parent 8fc2d2e commit 908d334
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 20 deletions.
46 changes: 27 additions & 19 deletions flamedisx/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from multihist import Histdd
import pandas as pd
from scipy.interpolate import interp2d
import scipy.interpolate

import numpy as np

Expand All @@ -26,17 +26,19 @@ class TemplateSource(fd.ColumnSource):
If None, get this info from template.
- events_per_bin: set to True if template specifies expected events per
bin, rather than differential rate.
- interpolate: if True, differential rates are interpolated linearly
between the bin centers.
For other arguments, see flamedisx.source.Source
"""

def __init__(
self,
template,
interp_2d=False,
bin_edges=None,
axis_names=None,
events_per_bin=False,
interpolate=False,
*args,
**kwargs):
# Get template, bin_edges, and axis_names
Expand All @@ -62,17 +64,6 @@ def __init__(
raise ValueError("Axis names missing or mismatched")
self.final_dimensions = axis_names

if interp_2d:
assert len(self.final_dimensions) == 2, "Interpolation only supported for 2D histogram!"
centers_dim_1 = 0.5 * (bin_edges[0][1:] + bin_edges[0][:-1])
centers_dim_2 = 0.5 * (bin_edges[1][1:] + bin_edges[1][:-1])
self.interp_2d = interp2d(
centers_dim_1, centers_dim_2,
np.transpose(template),
kind='linear')
else:
self.interp_2d = None

# Build a diff rate and events/bin multihist from the template
_mh = Histdd.from_histogram(template, bin_edges=bin_edges)
if events_per_bin:
Expand All @@ -84,6 +75,18 @@ def __init__(

self.mu = fd.np_to_tf(self._mh_events_per_bin.n)

if interpolate:
# Build an interpolator for the differential rate
bin_centers = [
0.5 * (edges[1:] + edges[:-1])
for edges in bin_edges]
self._interpolator = scipy.interpolate.RegularGridInterpolator(
points=tuple(bin_centers),
values=self._mh_diff_rate.histogram,
method='linear')
else:
self._interpolator = None

# Generate a random column name to use to store the diff rates
# of observed events
self.column = (
Expand All @@ -95,13 +98,18 @@ def __init__(
def _annotate(self):
"""Add columns needed in inference to self.data
"""
if self.interp_2d is not None:
self.data[self.column] = np.array([self.interp_2d(r[self.data.columns.get_loc(self.final_dimensions[0])],
r[self.data.columns.get_loc(self.final_dimensions[1])])[0]
for r in self.data.itertuples(index=False)])
# (n_dims, n_points) array of input data
data = np.stack([
self.data[dim].values
for dim in self.final_dimensions])

if self._interpolator:
# transpose since RegularGridInterpolator expects (n_points, n_dims)
result = self._interpolator(data.T)
else:
self.data[self.column] = self._mh_diff_rate.lookup(
*[self.data[x] for x in self.final_dimensions])
result = self._mh_diff_rate.lookup(*data)

self.data[self.column] = result

def simulate(self, n_events, fix_truth=None, full_annotate=False,
keep_padding=False, **params):
Expand Down
8 changes: 7 additions & 1 deletion tests/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_template_interpolation():
data = pd.DataFrame({'x': [0.5, 1.5, 3.5], 'y': [0.5, 0.5, 2.5]})

# Interpolate using flamedisx
s = fd.TemplateSource(mh, interp_2d=True)
s = fd.TemplateSource(mh, interpolate=True)
s.set_data(data)
dr_flamedisx = s.batched_differential_rate()

Expand All @@ -60,3 +60,9 @@ def test_template_interpolation():
values=mh.histogram,
method='linear')(z)
assert np.allclose(dr_flamedisx, dr_itp)

# With interpolation turned off, flamedisx just looks up the diff rates
s = fd.TemplateSource(mh, interpolate=False)
s.set_data(data)
dr_flamedisx_noitp = s.batched_differential_rate()
assert np.allclose(dr_flamedisx_noitp, mh.lookup(data['x'], data['y']))

0 comments on commit 908d334

Please sign in to comment.