Skip to content

Commit

Permalink
add autogeneration of combine.ini file
Browse files Browse the repository at this point in the history
  • Loading branch information
scarlehoff committed Mar 26, 2024
1 parent e35c305 commit 1b5659c
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 1 deletion.
119 changes: 119 additions & 0 deletions src/pinefarm/external/nnlojet/runcardgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from functools import cached_property
from pathlib import Path

import numpy as np
from yaml import safe_load

logger = logging.getLogger(__name__)
Expand All @@ -27,6 +28,23 @@
"""
INDT = " " # indentation

# Constants for the combine.ini config file
COMBINE_OUTPUT_FOLDER = "combined"
COMBINE_HEADER = f"""
[Paths]
raw_dir = .
out_dir = {COMBINE_OUTPUT_FOLDER}
"""
COMBINE_FOOTER = """
[Options]
recursive = True
weights = True
trim = (3.5, 0.01)
k-scan = (None, 3, 0.7)
"""


@dataclass
class Histogram:
Expand Down Expand Up @@ -321,3 +339,104 @@ def generate_runcard(
logger.info(f"Runcard written to {runcard_path}")

return runcard_path


## combine.ini generation


def _channel_selection(metadata, channels=None):
"""Generate a selection of channels compatible with the metadata.
Run over all possible channels in the metadata and add them
to the combine script in the right combination whenever they are
selected in the arguments of this function.
Parameters
----------
metadata: YamlLOJET
information from th e pinecard
channels: list(str)
list of channels to be run
"""
all_levels = {i.split("_")[0] for i in metadata.channels.keys()}

lo = ["LO"]
nlo = ["R", "V"]
nnlo = ["RR", "RV", "VV", "RRa", "RRb"]

def is_allowed(l):
"""Return false if the given level is not allowed."""
if channels is None:
return True
if l in channels:
return True
if l in ("RRa", "RRb") and ("RR" in channels):
return True
return False

ret = {}
# Now go over every level and include it in the combine.ini
# whenever it is both in channels and in the metadata
if add_lo := all_levels.intersection(lo):
if all(is_allowed(i) for i in add_lo):
ret["LO"] = list(add_lo)
if add_nlo := all_levels.intersection(nlo):
tmp = list(add_nlo.union(add_lo))
if all(is_allowed(i) for i in tmp):
ret["NLO"] = tmp
if add_nnlo := all_levels.intersection(nnlo):
tmp = list(add_nnlo.union(add_nlo).union(add_lo))
if all(is_allowed(i) for i in tmp):
ret["NNLO"] = tmp
# Add a exclusive NNLO level for debugging
if all(is_allowed(i) for i in add_nnlo):
ret["exclusive_nnlo"] = list(add_nnlo)

if not ret:
raise ValueError(f"No channel {channels} found in the pinecard")

return ret


def _generate_channel_merging(metadata, combinations):
"""Define how subchannels are to be merged.
Looking at metadata and the list of allowed channels, prepare
[Parts], [Merge] and [Final].
"""
allowed_levels = set(np.concatenate(list(combinations.values())))
merge_dict = metadata.active_channels(active_channels=allowed_levels)

parts_list = []
merge_list = []
for level_name, channel_list in merge_dict.items():
parts_list.append("\n".join(channel_list))
merge_list.append(f"{level_name} = " + " + ".join(channel_list))

ret = "\n[Parts]\n" + "\n".join(parts_list)
ret += "\n\n[Merge]\n" + "\n".join(merge_list)
ret += "\n\n[Final]"

for order, levels in combinations.items():
if all(l in merge_dict for l in levels):
ret += f"\n{order} = " + " + ".join(levels)

return ret


def generate_combine_ini(metadata, channels, output=Path(".")):
"""Generate a NNLOJET combine config file."""
# Initialize the file
cini_text = COMBINE_HEADER

# Define the list of observables
obs_list = "\n".join([i.name for i in metadata.histograms])
cini_text += f"\n[Observables]\ncross\n{obs_list}\n"

# Define which (nnlojet) channels are neded and how they are merged
combinations = _channel_selection(metadata, channels)
cini_text += _generate_channel_merging(metadata, combinations)

cini_text += COMBINE_FOOTER

(output / "combine.ini").write_text(cini_text)
3 changes: 2 additions & 1 deletion src/pinefarm/external/nnlojet/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from yaml import safe_load

from .. import interface
from .runcardgen import YamlLOJET, generate_runcard
from .runcardgen import YamlLOJET, generate_combine_ini, generate_runcard

# Reasonable default for warmup and production for DY
_DEFAULTS = {
Expand Down Expand Up @@ -71,6 +71,7 @@ def preparation(self):
iterations=nit,
)

generate_combine_ini(pinedata, channels, self.dest)
return True

def run(self):
Expand Down

0 comments on commit 1b5659c

Please sign in to comment.