diff --git a/docs/source/fitting_code/running.md b/docs/source/fitting_code/running.md index 59c36345..ff9b38b4 100644 --- a/docs/source/fitting_code/running.md +++ b/docs/source/fitting_code/running.md @@ -221,6 +221,19 @@ One is free to set custom attributes in the constructor. The coefficient values are accesible via ``coefficient_values`` in the ``compute_chi2`` method. In order for the external chi2 to work, it is important one does not change the name of the ``compute_chi2`` method! +### Adding RG evolution +Renormalisation group evolution can be turned on in the fit by adding the following to the runcard. + +```yaml +rge: + init_scale: 5000.0 + obs_scale: dynamic # float or "dynamic" + smeft_accuracy: integrate # options: integrate, leadinglog + yukawa: top # options: top, full or none + adm_QCD: False # if true, the EW couplings are set to zero + rg_matrix: +``` + ## Running a fit with NS To run a fiy using Nested Sampling use the command ```bash diff --git a/docs/source/report/running.md b/docs/source/report/running.md index a07a5993..a6ed863a 100644 --- a/docs/source/report/running.md +++ b/docs/source/report/running.md @@ -168,6 +168,16 @@ fisher: figsize: [11, 15] # figure size title: true # if True display the fit label as title + plot: + summary_only: True # if True display only the fisher information per dataset group. If False will show the fine grained dataset per dataset + figsize: [11, 15] # figure size + title: true # if True display the fit label as title + column_names: # list of column names to be displayed, default is all + - group_1: "$\\rm group\\:1$" + - tt13: "$t\\bar{t}$" + - ... + together: ["fit_1", "fit_2"] # list of result IDs to be plotted together + ``` Finally the user has to specify two dictionaries where the informaions about diff --git a/src/smefit/analyze/fisher.py b/src/smefit/analyze/fisher.py index 098d4c36..2d813ece 100644 --- a/src/smefit/analyze/fisher.py +++ b/src/smefit/analyze/fisher.py @@ -1,8 +1,12 @@ # -*- coding: utf-8 -*- +import matplotlib as mpl +import matplotlib.patches as mpatches import matplotlib.pyplot as plt import numpy as np import pandas as pd from matplotlib import colors +from matplotlib.legend_handler import HandlerPatch +from matplotlib.patches import Polygon from mpl_toolkits.axes_grid1 import make_axes_locatable from rich.progress import track @@ -10,6 +14,37 @@ from .pca import impose_constrain +class HandlerTriangle(HandlerPatch): + def create_artists( + self, legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans + ): + center = (width / 2 - xdescent, height / 2 - ydescent) + size = min(width, height) / 2 + # Define the lower-left triangle vertices + + if orig_handle.xy[0, 0] < orig_handle.xy[1, 0]: + vertices = [ + (center[0] - size, center[1] - size), # Bottom-left + (center[0] + size, center[1] - size), # Bottom-right + (center[0] - size, center[1] + size), # Top-left + ] + else: + # (upper-right) + vertices = [ + (center[0] + size, center[1] + size), # Top-right + (center[0] - size, center[1] + size), # Top-left + (center[0] + size, center[1] - size), # Bottom-right + ] + p = mpatches.Polygon( + vertices, + closed=True, + facecolor=orig_handle.get_facecolor(), + edgecolor=orig_handle.get_edgecolor(), + ) + p.set_transform(trans) + return [p] + + class FisherCalculator: """Computes and writes the Fisher information table, and plots heat map. @@ -53,7 +88,6 @@ def compute_linear(self): fisher_tab = [] cnt = 0 for ndat in self.datasets.NdataExp: - fisher_row = np.zeros(self.free_parameters.size) idxs = slice(cnt, cnt + ndat) sigma = self.new_LinearCorrections[:, idxs] fisher_row = np.diag(sigma @ self.datasets.InvCovMat[idxs, idxs] @ sigma.T) @@ -317,42 +351,235 @@ def color(value, thr_val=10): ) return L - def plot( + @staticmethod + def unify_fishers(df, df_other): + + if df_other is None or df is None: + return None + + # Get the union of row and column indices + all_rows = df.index.union(df_other.index) + all_columns = df.columns.union(df_other.columns) + + # Reindex both DataFrames to have the same rows and columns + df = df.reindex(index=all_rows, columns=all_columns, fill_value=0) + df_other = df_other.reindex(index=all_rows, columns=all_columns, fill_value=0) + + return df, df_other + + @staticmethod + def set_ticks(ax, yticks, xticks, latex_names, x_labels): + ax.set_yticks(yticks, labels=latex_names[::-1], fontsize=15) + ax.set_xticks( + xticks, + labels=x_labels, + rotation=90, + fontsize=15, + ) + ax.xaxis.set_ticks_position("top") + ax.tick_params(which="major", top=False, bottom=False, left=False) + ax.set_xticks(xticks - 0.5, minor=True) + ax.set_yticks(yticks - 0.5, minor=True) + ax.tick_params(which="minor", bottom=False) + ax.grid(visible=True, which="minor", alpha=0.2) + + @staticmethod + def plot_values(ax, dfs, cmap, norm, labels=None): + """ + Plot the values of the Fisher information. + + Parameters + ---------- + ax: matplotlib.axes.Axes + axes object + dfs: list + list of pandas.DataFrame + cmap: matplotlib.colors.LinearSegmentedColormap + colour map + norm: matplotlib.colors.BoundaryNorm + normalisation of colorbar + labels: list, optional + label elements for legend + """ + + df_1 = dfs[0] + df_2 = dfs[1] if len(dfs) > 1 else None + cols, rows = df_1.shape + + # Initialize the delta shift for text positioning + delta_shift = 0 + + for i, row in enumerate(df_1.values.T): + for j, elem_1 in enumerate(row): + + # start filling from the top left corner + x, y = j, rows - 1 - i + ec_1 = "black" + + # if two fishers must be plotted together + if df_2 is not None: + + elem_2 = df_2.values.T[i, j] + + # move position numbers + delta_shift = 0.2 + + # highlight operators that exist in one but not the other + ec_1 = "C1" if elem_2 == 0 and elem_1 > 0 else "black" + + if elem_2 > 0: + ax.text( + x + delta_shift, + y + delta_shift, + f"{elem_2:.1f}", + va="center", + ha="center", + fontsize=10, + ) + + # Create a triangle patch for the second element + triangle2 = Polygon( + [ + [x + 0.5, y - 0.5], + [x + 0.5, y + 0.5], + [x - 0.5, y + 0.5], + ], + closed=True, + facecolor=cmap(norm(elem_2)), + edgecolor="black", + ) + ax.add_patch(triangle2) + + if elem_1 > 0: + + ax.text( + x - delta_shift, + y - delta_shift, + f"{elem_1:.1f}", + va="center", + ha="center", + fontsize=10, + ) + if df_2 is not None: + + # Create a triangle patch for the first element + triangle1 = Polygon( + [ + [x - 0.5, y - 0.5], + [x + 0.5, y - 0.5], + [x - 0.5, y + 0.5], + ], + closed=True, + facecolor=cmap(norm(elem_1)), + edgecolor=ec_1, + ) + ax.add_patch(triangle1) + + # Create legend elements for the patches + legend_elements = [ + mpatches.Polygon( + [[-0.5, -0.5], [0.5, -0.5], [-0.5, 0.5]], + closed=True, + fc="none", + edgecolor="black", + label=labels[0], + ), + mpatches.Polygon( + [[0.5, -0.5], [0.5, 0.5], [0.5, 0.5]], + closed=True, + fc="none", + edgecolor="black", + label=labels[1], + ), + ] + # Add the legend to the plot + ax.legend( + handles=legend_elements, + loc="upper center", + fontsize=25, + frameon=False, + ncol=2, + handler_map={mpatches.Polygon: HandlerTriangle()}, + bbox_to_anchor=(0.5, -0.02), + ) + else: + # Create a rectangle patch for the first element + rectangle = Polygon( + [ + [x - 0.5, y - 0.5], + [x + 0.5, y - 0.5], + [x + 0.5, y + 0.5], + [x - 0.5, y + 0.5], + ], + closed=True, + ec="grey", + color=cmap(norm(elem_1)), + ) + ax.add_patch(rectangle) + + # Set the x and y limits of the plot + ax.set_xlim(0, cols - 0.5) + ax.set_ylim(0, rows - 0.5) + # Set the aspect ratio of the plot to be equal + ax.set_aspect("equal", adjustable="box") + + def plot_heatmap( self, latex_names, fig_name, title=None, + other=None, summary_only=True, figsize=(11, 15), + labels=None, + column_names=None, ): - """Plot the heat map of Fisher table. - Parameters - ---------- - latex_names : list - list of coefficients latex names - fig_name: str - figure path - summary_only: - if True plot the fisher grouped per datsets, - else the fine grained dataset per dataset - figsize : tuple - figure size - title: str, None - plot title - """ - if summary_only: - fisher_df = self.summary_table - quad_fisher_df = self.summary_HOtable - else: - fisher_df = self.lin_fisher - quad_fisher_df = self.quad_fisher + fisher_df = self.summary_table if summary_only else self.lin_fisher + quad_fisher_df = self.summary_HOtable if summary_only else self.quad_fisher + + if other is not None: + + fisher_df_other = other.summary_table if summary_only else other.lin_fisher + quad_fisher_df_other = ( + other.summary_HOtable if summary_only else other.quad_fisher + ) + # unify the fisher tables and fill missing values by zeros + fisher_dfs = self.unify_fishers(fisher_df, fisher_df_other) + + # reshuffle the tables according to the latex names ordering + fisher_dfs = [ + fisher[latex_names.index.get_level_values(level=1)] + for fisher in fisher_dfs + ] + + if quad_fisher_df is not None: + quad_fisher_dfs = self.unify_fishers( + quad_fisher_df, quad_fisher_df_other + ) + + # reshuffle the tables according to the latex names ordering + quad_fisher_dfs = [ + fisher[latex_names.index.get_level_values(level=1)] + for fisher in quad_fisher_dfs + ] - fig = plt.figure(figsize=figsize) - if quad_fisher_df is not None: - ax = fig.add_subplot(121) else: - ax = plt.gca() + fisher_dfs = [fisher_df[latex_names.index.get_level_values(level=1)]] + if quad_fisher_df is not None: + quad_fisher_dfs = [ + quad_fisher_df[latex_names.index.get_level_values(level=1)] + ] + + # reshuffle column name ordering + if column_names is not None: + custom_ordering = [list(column.keys())[0] for column in column_names] + fisher_dfs = [fisher_df.loc[custom_ordering] for fisher_df in fisher_dfs] + x_labels = [list(column.values())[0] for column in column_names] + else: + x_labels = [ + f"\\rm{{{name}}}".replace("_", "\\_") for name in fisher_df.index + ] # colour map cmap_full = plt.get_cmap("Blues") @@ -362,56 +589,43 @@ def plot( ) norm = colors.BoundaryNorm(np.arange(110, step=10), cmap.N) - # ticks - yticks = np.arange(fisher_df.shape[1]) - xticks = np.arange(fisher_df.shape[0]) - x_labels = [f"\\rm{{{name}}}".replace("_", "\\_") for name in fisher_df.index] - - def set_ticks(ax): - ax.set_yticks(yticks, labels=latex_names, fontsize=15) - ax.set_xticks( - xticks, - labels=x_labels, - rotation=90, - fontsize=15, - ) - ax.tick_params(which="major", top=False, bottom=False, left=False) - # minor grid - ax.set_xticks(xticks - 0.5, minor=True) - ax.set_yticks(yticks - 0.5, minor=True) - ax.tick_params(which="minor", bottom=False) - ax.grid(visible=True, which="minor", alpha=0.2) - - def plot_values(ax, df): - for i, row in enumerate(df.values.T): - for j, elem in enumerate(row): - if elem > 0: - ax.text( - j, - i, - f"{elem:.1f}", - va="center", - ha="center", - fontsize=8, - ) + fig = plt.figure(figsize=figsize) + if quad_fisher_df is not None: + ax = fig.add_subplot(121) + else: + ax = plt.gca() - cax = ax.matshow(fisher_df.values.T, cmap=cmap, norm=norm) - plot_values(ax, fisher_df) - set_ticks(ax) + self.plot_values(ax, fisher_dfs, cmap, norm, labels) + + self.set_ticks( + ax, + np.arange(fisher_dfs[0].shape[1]), + np.arange(fisher_dfs[0].shape[0]), + latex_names, + x_labels, + ) ax.set_title(r"\rm Linear", fontsize=20, y=-0.08) cax1 = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.5) - colour_bar = fig.colorbar(cax, cax=cax1) + colour_bar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax1) if quad_fisher_df is not None: ax = fig.add_subplot(122) - cax = ax.matshow(quad_fisher_df.values.T, cmap=cmap, norm=norm) - plot_values(ax, quad_fisher_df) - set_ticks(ax) + self.plot_values(ax, quad_fisher_dfs, cmap, norm, labels) + + self.set_ticks( + ax, + np.arange(quad_fisher_dfs[0].shape[1]), + np.arange(quad_fisher_dfs[0].shape[0]), + latex_names, + x_labels, + ) ax.set_title(r"\rm Quadratic", fontsize=20, y=-0.08) - cax1 = make_axes_locatable(ax).append_axes("right", size="10%", pad=0.1) - colour_bar = fig.colorbar(cax, cax=cax1) + cax1 = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.5) + colour_bar = fig.colorbar( + mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax1 + ) - fig.subplots_adjust(top=0.85) + fig.subplots_adjust(top=0.9) colour_bar.set_label( r"${\rm Normalized\ Value}$", diff --git a/src/smefit/analyze/report.py b/src/smefit/analyze/report.py index f9f3236a..5c4eb16a 100644 --- a/src/smefit/analyze/report.py +++ b/src/smefit/analyze/report.py @@ -490,12 +490,7 @@ def pca( self._append_section("PCA", figs=figs_list, links=links_list) def fisher( - self, - norm="coeff", - summary_only=True, - plot=None, - fit_list=None, - log=False, + self, norm="coeff", summary_only=True, plot=None, fit_list=None, log=False ): """Fisher information table and plots runner. @@ -523,6 +518,7 @@ def fisher( else: fit_list = self.fits + fishers = {} for fit in fit_list: compute_quad = fit.config["use_quad"] fisher_cal = FisherCalculator(fit.coefficients, fit.datasets, compute_quad) @@ -533,6 +529,7 @@ def fisher( fisher_cal.summary_table = fisher_cal.groupby_data( fisher_cal.lin_fisher, self.data_info, norm, log ) + fishers[fit.name] = fisher_cal # if necessary compute the quadratic Fisher if compute_quad: @@ -546,27 +543,44 @@ def fisher( fisher_cal.quad_fisher, self.data_info, norm, log ) - # Write down the table in latex - free_coeff_config = self.coeff_info.loc[ - :, fit.coefficients.free_parameters.index - ] compile_tex( self.report, - fisher_cal.write_grouped( - free_coeff_config, self.data_info, summary_only - ), + fisher_cal.write_grouped(self.coeff_info, self.data_info, summary_only), f"fisher_{fit.name}", ) links_list.append((f"fisher_{fit.name}", f"Table {fit.label}")) if plot is not None: fit_plot = copy.deepcopy(plot) + fit_plot.pop("together", None) title = fit.label if fit_plot.pop("title") else None - fisher_cal.plot( - free_coeff_config, + fisher_cal.plot_heatmap( + self.coeff_info, f"{self.report}/fisher_heatmap_{fit.name}", title=title, **fit_plot, ) figs_list.append(f"fisher_heatmap_{fit.name}") + + # plot both fishers + if plot.get("together", False): + fisher_1 = fishers[plot["together"][0]] + fisher_2 = fishers[plot["together"][1]] + fit_plot = copy.deepcopy(plot) + fit_plot.pop("together") + + # show title of last fit + title = fit.label if fit_plot.pop("title") else None + + # make heatmap of fisher_1 and fisher_2 + fisher_2.plot_heatmap( + self.coeff_info, + f"{self.report}/fisher_heatmap_both", + title=title, + other=fisher_1, + labels=[fit.label for fit in self.fits], + **fit_plot, + ) + figs_list.append(f"fisher_heatmap_both") + self._append_section("Fisher", figs=figs_list, links=links_list) diff --git a/src/smefit/fit_manager.py b/src/smefit/fit_manager.py index 49853f17..b89ced01 100644 --- a/src/smefit/fit_manager.py +++ b/src/smefit/fit_manager.py @@ -1,6 +1,8 @@ # -*- coding: utf-8 -*- import json +import pickle +import jax.numpy as jnp import numpy as np import pandas as pd import yaml @@ -54,6 +56,7 @@ def __init__(self, path, name, label=None): self.has_posterior = self.config.get("has_posterior", True) self.results = None self.datasets = None + self.rgemat = None def __repr__(self): return self.name @@ -73,6 +76,16 @@ def load_results(self): with open(f"{self.path}/{self.name}/{file}.json", encoding="utf-8") as f: results = json.load(f) + # load the rge matrix in the result dir if it exists + try: + with open(f"{self.path}/{self.name}/rge_matrix.pkl", "rb") as f: + rgemats = pickle.load(f) + self.operators_to_keep = {op: {} for op in rgemats[0].index} + self.rgemat = jnp.stack([rgemat.values for rgemat in rgemats]) + + except FileNotFoundError: + print("No RGE matrix found in the result folder, skipping...") + # if the posterior is from single parameter fits # then each distribution might have a different number of samples is_single_param = results.get("single_parameter_fits", False) @@ -113,10 +126,13 @@ def load_configuration(self): def load_datasets(self): """Load all datasets.""" + self.datasets = load_datasets( self.config["data_path"], self.config["datasets"], - self.config["coefficients"], + self.config["coefficients"] + if self.rgemat is None + else self.operators_to_keep, self.config["use_quad"], self.config["use_theory_covmat"], False, # t0 is not used here because in the report we look at the experimental chi2 @@ -126,6 +142,7 @@ def load_datasets(self): self.config.get("rot_to_fit_basis", None), self.config.get("uv_couplings", False), self.config.get("external_chi2", False), + rgemat=self.rgemat, ) @property diff --git a/src/smefit/loader.py b/src/smefit/loader.py index ae26204f..e1430d6c 100644 --- a/src/smefit/loader.py +++ b/src/smefit/loader.py @@ -300,6 +300,7 @@ def load_theory( lin_dict = {} # save sm prediction at the chosen perturbative order + sm = np.array(raw_th_data[order]["SM"]) # split corrections into a linear and quadratic dict @@ -530,10 +531,7 @@ def construct_corrections_matrix_linear( cnt += n_dat if rgemat is not None: - if len(rgemat.shape) == 3: # dynamic scale, scale is datapoint specific - corr_values = jnp.einsum("ij, ijk -> ik", corr_values, rgemat) - else: # fixed scale so same rgemat for all datapoints - corr_values = jnp.einsum("ij, jk -> ik", corr_values, rgemat) + corr_values = jnp.einsum("ij, ijk -> ik", corr_values, rgemat) return corr_values @@ -553,6 +551,10 @@ def construct_corrections_matrix_quadratic( sorted_keys: numpy.ndarray list of sorted operator corrections, shape=(n rg generated coeff,) or shape=(n original coeff,) in the absence of rgemat + rgemat: numpy.ndarray, optional + solution matrix of the RGE, shape=(k, l, m) with k the number of datapoints, + l the number of generated coefficients under the RG and m the number of + original |EFT| coefficients specified in the runcard. Returns ------- @@ -579,12 +581,7 @@ def construct_corrections_matrix_quadratic( cnt += n_dat if rgemat is not None: - if len(rgemat.shape) == 3: # dynamic scale, scale is datapoint specific - corr_values = jnp.einsum( - "ijk, ijl, ikr -> ilr", corr_values, rgemat, rgemat - ) - else: # fixed scale so same rgemat for all datapoints - corr_values = jnp.einsum("ijk, jl, kr -> ilr", corr_values, rgemat, rgemat) + corr_values = jnp.einsum("ijk, ijl, ikr -> ilr", corr_values, rgemat, rgemat) return corr_values diff --git a/src/smefit/optimize/__init__.py b/src/smefit/optimize/__init__.py index 285ec5f8..0942b31e 100644 --- a/src/smefit/optimize/__init__.py +++ b/src/smefit/optimize/__init__.py @@ -7,7 +7,7 @@ from rich.style import Style from rich.table import Table -from smefit.rge import RGE +from smefit.rge.rge import RGE from .. import chi2, log from ..coefficients import CoefficientManager diff --git a/src/smefit/optimize/analytic.py b/src/smefit/optimize/analytic.py index a78304ab..ca6a93b1 100644 --- a/src/smefit/optimize/analytic.py +++ b/src/smefit/optimize/analytic.py @@ -4,7 +4,7 @@ from rich.style import Style from rich.table import Table -from smefit.rge import load_rge_matrix +from smefit.rge.rge import load_rge_matrix from .. import chi2, log from ..analyze.pca import impose_constrain @@ -92,6 +92,8 @@ def from_dict(cls, config): config["datasets"], config.get("theory_path", None), cutoff_scale, + config.get("result_path", None), + config.get("result_ID", None), ) _logger.info("The operators generated by the RGE are: ") _logger.info(list(operators_to_keep.keys())) diff --git a/src/smefit/optimize/ultranest.py b/src/smefit/optimize/ultranest.py index 1bb2e30d..a8b3be42 100644 --- a/src/smefit/optimize/ultranest.py +++ b/src/smefit/optimize/ultranest.py @@ -10,7 +10,7 @@ from rich.table import Table from ultranest import stepsampler -from smefit.rge import load_rge_matrix +from smefit.rge.rge import load_rge_matrix from .. import chi2, log from ..coefficients import CoefficientManager @@ -155,6 +155,8 @@ def from_dict(cls, config): config["datasets"], config.get("theory_path", None), cutoff_scale, + config.get("result_path", None), + config.get("result_ID", None), ) _logger.info("The operators generated by the RGE are: ") _logger.info(list(operators_to_keep.keys())) diff --git a/src/smefit/rge/__init__.py b/src/smefit/rge/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/smefit/rge.py b/src/smefit/rge/rge.py similarity index 92% rename from src/smefit/rge.py rename to src/smefit/rge/rge.py index a180603b..f4320c8a 100644 --- a/src/smefit/rge.py +++ b/src/smefit/rge/rge.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import pathlib +import pickle import warnings from copy import deepcopy from functools import partial, wraps @@ -13,7 +14,7 @@ from smefit import log from smefit.loader import Loader -from smefit.wcxf import inverse_wcxf_translate, wcxf_translate +from smefit.rge.wcxf import inverse_wcxf_translate, wcxf_translate ### Patch of a CKM function, so that the CP violating ### phase is set to gamma and not computed explicitly @@ -412,7 +413,13 @@ def load_scales( def load_rge_matrix( - rge_dict, coeff_list, datasets=None, theory_path=None, cutoff_scale=None + rge_dict, + coeff_list, + datasets=None, + theory_path=None, + cutoff_scale=None, + result_path=None, + result_ID=None, ): """ Load the RGE matrix for the SMEFT Wilson coefficients. @@ -440,12 +447,27 @@ def load_rge_matrix( yukawa = rge_dict.get("yukawa", "top") scale_variation = rge_dict.get("scale_variation", 1.0) rge_runner = RGE(coeff_list, init_scale, smeft_accuracy, adm_QCD, yukawa) + + # load precomputed RGE matrix if it exists + path_to_rge_mat = rge_dict.get("rg_matrix", False) + if path_to_rge_mat: + with open(path_to_rge_mat, "rb") as f: + rgemats = pickle.load(f) + stacked_mats = jnp.stack([rgemat.values for rgemat in rgemats]) + operators_to_keep = {op: {} for op in rgemats[0].index} + return stacked_mats, operators_to_keep + # if it is a float, it is a static scale if type(obs_scale) is float or type(obs_scale) is int: rgemat = rge_runner.RGEmatrix(obs_scale) gen_operators = list(rgemat.index) operators_to_keep = {k: {} for k in gen_operators} - return rgemat.values, operators_to_keep + + # prepend additional dimension for consistency with the dynamic scale case + stacked_mats = jnp.stack([rgemat.values]) + save_rg(pathlib.Path(result_path) / result_ID, rgemat=[rgemat]) + + return stacked_mats, operators_to_keep elif obs_scale == "dynamic": scales = load_scales( @@ -485,9 +507,29 @@ def load_rge_matrix( # now stack the matrices in a 3D array stacked_mats = jnp.stack([mat.values for mat in rgemat]) + + # save RGE matrix to result_path + save_rg(pathlib.Path(result_path) / result_ID, rgemat=rgemat) + return stacked_mats, operators_to_keep else: raise ValueError( f"obs_scale must be either a float/int or 'dynamic'. Passed: {obs_scale}" ) + + +def save_rg(path, rgemat): + """ + Save the RGE matrix to the result folder. + + Parameters + ---------- + path : pathlib.Path + path to the result folder + rgemat: list + List of RGE matrices for each datapoint + """ + if path is not None: + with open(path / "rge_matrix.pkl", "wb") as f: + pickle.dump(rgemat, f) diff --git a/src/smefit/wcxf.py b/src/smefit/rge/wcxf.py similarity index 100% rename from src/smefit/wcxf.py rename to src/smefit/rge/wcxf.py