diff --git a/src/pinefarm/external/nnlojet/runcardgen.py b/src/pinefarm/external/nnlojet/runcardgen.py
index b7aeb1d..17ecdcf 100755
--- a/src/pinefarm/external/nnlojet/runcardgen.py
+++ b/src/pinefarm/external/nnlojet/runcardgen.py
@@ -10,6 +10,7 @@
 from functools import cached_property
 from pathlib import Path
 
+import numpy as np
 from yaml import safe_load
 
 logger = logging.getLogger(__name__)
@@ -27,6 +28,23 @@
 """
 INDT = " "  # indentation
 
+# Constants for the combine.ini config file
+COMBINE_OUTPUT_FOLDER = "combined"
+COMBINE_HEADER = f"""
+[Paths]
+raw_dir = .
+out_dir = {COMBINE_OUTPUT_FOLDER}
+
+"""
+COMBINE_FOOTER = """
+
+[Options]
+recursive = True
+weights = True
+trim = (3.5, 0.01)
+k-scan = (None, 3, 0.7)
+"""
+
 
 @dataclass
 class Histogram:
@@ -321,3 +339,149 @@ def generate_runcard(
     logger.info(f"Runcard written to {runcard_path}")
 
     return runcard_path
+
+
+## combine.ini generation
+
+
+def _channel_selection(metadata, channels=None):
+    """Generate a selection of channels compatible with the metadata.
+
+    Run over all possible levels in the metadata and group them into the
+    perturbative orders to be written to the combine script. An order is
+    only included when every level it is built from (e.g., LO, R and V
+    for NLO) is allowed by the ``channels`` argument of this function.
+
+    Parameters
+    ----------
+    metadata: YamlLOJET
+        information from the pinecard
+    channels: list(str)
+        list of channels to be run, if None every channel is allowed
+
+    Returns
+    -------
+    dict
+        for each order (LO, NLO, NNLO, exclusive_nnlo) that can be built,
+        the sorted list of levels that it is made of
+
+    Raises
+    ------
+    ValueError
+        if no order can be built from the selected channels
+    """
+    all_levels = {i.split("_")[0] for i in metadata.channels.keys()}
+
+    lo = ["LO"]
+    nlo = ["R", "V"]
+    nnlo = ["RR", "RV", "VV", "RRa", "RRb"]
+
+    def is_allowed(level):
+        """Return False if the given level is not allowed."""
+        if channels is None:
+            return True
+        if level in channels:
+            return True
+        # Selecting RR also enables RRa and RRb
+        return level in ("RRa", "RRb") and "RR" in channels
+
+    ret = {}
+    # Now go over every level and include it in the combine.ini
+    # whenever it is both in channels and in the metadata.
+    # Levels are sorted since the iteration order of a set of strings
+    # changes between runs and the output file should be reproducible.
+    if add_lo := all_levels.intersection(lo):
+        if all(is_allowed(i) for i in add_lo):
+            ret["LO"] = sorted(add_lo)
+    if add_nlo := all_levels.intersection(nlo):
+        tmp = sorted(add_nlo.union(add_lo))
+        if all(is_allowed(i) for i in tmp):
+            ret["NLO"] = tmp
+    if add_nnlo := all_levels.intersection(nnlo):
+        tmp = sorted(add_nnlo.union(add_nlo).union(add_lo))
+        if all(is_allowed(i) for i in tmp):
+            ret["NNLO"] = tmp
+        # Add an exclusive NNLO level for debugging
+        if all(is_allowed(i) for i in add_nnlo):
+            ret["exclusive_nnlo"] = sorted(add_nnlo)
+
+    if not ret:
+        raise ValueError(f"No channel {channels} found in the pinecard")
+
+    return ret
+
+
+def _generate_channel_merging(metadata, combinations):
+    """Define how subchannels are to be merged.
+
+    Looking at metadata and the list of allowed channels, prepare
+    [Parts], [Merge] and [Final].
+
+    Parameters
+    ----------
+    metadata: YamlLOJET
+        information from the pinecard
+    combinations: dict
+        levels that make up each order, see ``_channel_selection``
+
+    Returns
+    -------
+    str
+        text of the [Parts], [Merge] and [Final] sections
+    """
+    allowed_levels = set(np.concatenate(list(combinations.values())))
+    merge_dict = metadata.active_channels(active_channels=allowed_levels)
+
+    parts_list = []
+    merge_list = []
+    for level_name, channel_list in merge_dict.items():
+        parts_list.append("\n".join(channel_list))
+        merge_list.append(f"{level_name} = " + " + ".join(channel_list))
+
+    ret = "\n[Parts]\n" + "\n".join(parts_list)
+    ret += "\n\n[Merge]\n" + "\n".join(merge_list)
+    ret += "\n\n[Final]"
+
+    # An order is only written if all of its levels are in merge_dict
+    for order, levels in combinations.items():
+        if all(level in merge_dict for level in levels):
+            ret += f"\n{order} = " + " + ".join(levels)
+
+    return ret
+
+
+def generate_combine_ini(metadata, channels, output=Path(".")):
+    """Generate a NNLOJET combine config file.
+
+    Parameters
+    ----------
+    metadata: YamlLOJET
+        information from the pinecard
+    channels: list(str)
+        list of channels to be run
+    output: Path
+        folder in which combine.ini is written
+
+    Returns
+    -------
+    Path
+        path to the combine.ini file
+    """
+    # Initialize the file
+    cini_text = COMBINE_HEADER
+
+    # Define the list of observables
+    obs_list = "\n".join([i.name for i in metadata.histograms])
+    cini_text += f"\n[Observables]\ncross\n{obs_list}\n"
+
+    # Define which (nnlojet) channels are needed and how they are merged
+    combinations = _channel_selection(metadata, channels)
+    cini_text += _generate_channel_merging(metadata, combinations)
+
+    cini_text += COMBINE_FOOTER
+
+    combine_path = Path(output) / "combine.ini"
+    combine_path.write_text(cini_text)
+    logger.info(f"Combine config written to {combine_path}")
+
+    return combine_path
diff --git a/src/pinefarm/external/nnlojet/runner.py b/src/pinefarm/external/nnlojet/runner.py
index 7d43b1a..329714d 100644
--- a/src/pinefarm/external/nnlojet/runner.py
+++ b/src/pinefarm/external/nnlojet/runner.py
@@ -3,7 +3,7 @@
 from yaml import safe_load
 
 from .. import interface
-from .runcardgen import YamlLOJET, generate_runcard
+from .runcardgen import YamlLOJET, generate_combine_ini, generate_runcard
 
 # Reasonable default for warmup and production for DY
 _DEFAULTS = {
@@ -71,6 +71,7 @@ def preparation(self):
             iterations=nit,
         )
 
+        generate_combine_ini(pinedata, channels, self.dest)
         return True
 
     def run(self):