diff --git a/src/smefit/analyze/fisher.py b/src/smefit/analyze/fisher.py index dfd23393..3317f0e0 100644 --- a/src/smefit/analyze/fisher.py +++ b/src/smefit/analyze/fisher.py @@ -62,10 +62,30 @@ class FisherCalculator: """ - def __init__(self, coefficients, datasets, compute_quad): + def __init__(self, coefficients, datasets, compute_quad, rgemat=None): self.coefficients = coefficients self.free_parameters = self.coefficients.free_parameters.index self.datasets = datasets + self.rgemat = rgemat + + # find name of the operators that are constrained + self.constrained_coeffs = [ + name for name in self.coefficients.name if name not in self.free_parameters + ] + self.constrained_coeffs_idx = [ + idx + for idx, name in enumerate(self.coefficients.name) + if name not in self.free_parameters + ] + + # remove the constrained coefficient from the rgemat both in axis 1 and 2 + if rgemat is not None: + self.rgemat = np.delete( + self.rgemat, [idx for idx in self.constrained_coeffs_idx], axis=1 + ) + self.rgemat = np.delete( + self.rgemat, [idx for idx in self.constrained_coeffs_idx], axis=2 + ) # update eft corrections with the constraints if compute_quad: @@ -77,6 +97,9 @@ def __init__(self, coefficients, datasets, compute_quad): self.new_LinearCorrections = impose_constrain( self.datasets, self.coefficients ) + self.new_LinearCorrectionsNoRGE = impose_constrain( + self.datasets, self.coefficients, norge=True + ).T self.lin_fisher = None self.quad_fisher = None @@ -97,6 +120,29 @@ def compute_linear(self): fisher_tab, index=self.datasets.ExpNames, columns=self.free_parameters ) + def compute_wc_fisher(self): + fisher_tab = [] + operators = [ + name + for name in self.datasets.OperatorsNames + if name not in self.constrained_coeffs + ] + + for i, op in enumerate(operators): + idxs = slice(i, i + 1) + sliced_rgemat = self.rgemat[:, idxs, :] + sigma = np.einsum( + "ij, ijk -> ik", + self.new_LinearCorrectionsNoRGE[:, idxs], + sliced_rgemat, + ) + fisher_row = np.diag(sigma.T @ self.datasets.InvCovMat @ sigma) + fisher_tab.append(fisher_row) + + self.wc_fisher = pd.DataFrame( + fisher_tab, index=operators, columns=self.free_parameters + ) + def compute_quadratic(self, posterior_df, smeft_predictions): """Compute quadratic Fisher information.""" quad_fisher = [] @@ -530,10 +576,13 @@ def plot_heatmap( summary_only=True, figsize=(11, 15), column_names=None, + wc_fisher=False, + wc_to_latex=None, ): - fisher_df = self.summary_table if summary_only else self.lin_fisher quad_fisher_df = self.summary_HOtable if summary_only else self.quad_fisher + if wc_fisher: + wc_fisher_df = self.wc_fisher if other is not None: @@ -562,7 +611,6 @@ def plot_heatmap( quad_fisher_dfs = [ quad_fisher_df[latex_names.index.get_level_values(level=1)] ] - # reshuffle column name ordering if column_names is not None: custom_ordering = [list(column.keys())[0] for column in column_names] @@ -587,36 +635,61 @@ def plot_heatmap( else: ax = plt.gca() - self.plot_values(ax, fisher_dfs, cmap, norm) + if wc_fisher: + wc_fisher_dfs = [ + wc_fisher_df.loc[ + latex_names.index.get_level_values(level=1), + latex_names.index.get_level_values(level=1), + ] + ] - self.set_ticks( - ax, - np.arange(fisher_dfs[0].shape[1]), - np.arange(fisher_dfs[0].shape[0]), - latex_names, - x_labels, - ) - ax.set_title(r"\rm Linear", fontsize=20, y=-0.08) - cax1 = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.5) - colour_bar = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax1) + self.plot_values(ax, wc_fisher_dfs, cmap, norm) - if quad_fisher_df is not None: - ax = fig.add_subplot(122) - self.plot_values(ax, quad_fisher_dfs, cmap, norm) + self.set_ticks( + ax, + np.arange(wc_fisher_df.shape[1]), + np.arange(wc_fisher_df.shape[0]), + latex_names, + [wc_to_latex[op] for op in wc_fisher_dfs[0].index], + ) + ax.set_title(r"\rm RGE Fisher Information", fontsize=25) + cax1 = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.5) + colour_bar = fig.colorbar( + mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax1 + ) + else: + self.plot_values(ax, fisher_dfs, cmap, norm) self.set_ticks( ax, - np.arange(quad_fisher_dfs[0].shape[1]), - np.arange(quad_fisher_dfs[0].shape[0]), + np.arange(fisher_dfs[0].shape[1]), + np.arange(fisher_dfs[0].shape[0]), latex_names, x_labels, ) - ax.set_title(r"\rm Quadratic", fontsize=20, y=-0.08) + ax.set_title(r"\rm Linear", fontsize=20, y=-0.08) cax1 = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.5) colour_bar = fig.colorbar( mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax1 ) + if quad_fisher_df is not None: + ax = fig.add_subplot(122) + self.plot_values(ax, quad_fisher_dfs, cmap, norm) + + self.set_ticks( + ax, + np.arange(quad_fisher_dfs[0].shape[1]), + np.arange(quad_fisher_dfs[0].shape[0]), + latex_names, + x_labels, + ) + ax.set_title(r"\rm Quadratic", fontsize=20, y=-0.08) + cax1 = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.5) + colour_bar = fig.colorbar( + mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax1 + ) + fig.subplots_adjust(top=0.9) colour_bar.set_label( diff --git a/src/smefit/analyze/pca.py b/src/smefit/analyze/pca.py index 3f411247..aac074fd 100644 --- a/src/smefit/analyze/pca.py +++ b/src/smefit/analyze/pca.py @@ -150,7 +150,7 @@ def make_sym_matrix(vals, n_op): return m -def impose_constrain(dataset, coefficients, update_quad=False): +def impose_constrain(dataset, coefficients, update_quad=False, norge=False): """Propagate coefficient constraint into the theory tables. Note: only linear contraints are allowed in this method. @@ -187,7 +187,14 @@ def impose_constrain(dataset, coefficients, update_quad=False): temp_coeffs.set_constraints() # update linear corrections - new_linear_corrections.append(temp_coeffs.value @ dataset.LinearCorrections.T) + if norge: + new_linear_corrections.append( + temp_coeffs.value @ dataset.LinearCorrectionsNoRGE.T + ) + else: + new_linear_corrections.append( + temp_coeffs.value @ dataset.LinearCorrections.T + ) # update quadratic corrections, this will include some double counting in the mixed corrections if update_quad: diff --git a/src/smefit/analyze/report.py b/src/smefit/analyze/report.py index eab4adc2..76d46801 100644 --- a/src/smefit/analyze/report.py +++ b/src/smefit/analyze/report.py @@ -79,6 +79,12 @@ def __init__(self, report_path, result_path, report_config): self.data_info = self._load_grouped_data_info(report_config["data_info"]) # Loads coefficients grouped with latex name self.coeff_info = self._load_grouped_coeff_info(report_config["coeff_info"]) + + self.coeff_to_latex = {} + for _, entries in report_config["coeff_info"].items(): + for val in entries: + self.coeff_to_latex[val[0]] = val[1] + self.html_index = "" self.html_content = "" @@ -521,11 +527,17 @@ def fisher( fishers = {} for fit in fit_list: compute_quad = fit.config["use_quad"] - fisher_cal = FisherCalculator(fit.coefficients, fit.datasets, compute_quad) + fisher_cal = FisherCalculator( + fit.coefficients, fit.datasets, compute_quad, fit.rgemat + ) fisher_cal.compute_linear() + fisher_cal.compute_wc_fisher() fisher_cal.lin_fisher = fisher_cal.normalize( fisher_cal.lin_fisher, norm=norm, log=log ) + fisher_cal.wc_fisher = fisher_cal.normalize( + fisher_cal.wc_fisher, norm=norm, log=log + ) fisher_cal.summary_table = fisher_cal.groupby_data( fisher_cal.lin_fisher, self.data_info, norm, log ) @@ -558,6 +570,8 @@ def fisher( self.coeff_info, f"{self.report}/fisher_heatmap_{fit.name}", title=title, + wc_fisher=True, + wc_to_latex=self.coeff_to_latex, **fit_plot, ) figs_list.append(f"fisher_heatmap_{fit.name}") diff --git a/src/smefit/loader.py b/src/smefit/loader.py index c490c7cb..c72776ae 100644 --- a/src/smefit/loader.py +++ b/src/smefit/loader.py @@ -22,6 +22,7 @@ "SMTheory", "OperatorsNames", "LinearCorrections", + "LinearCorrectionsNoRGE", "QuadraticCorrections", "ExpNames", "NdataExp", @@ -704,6 +705,10 @@ def load_datasets( lin_corr_list, n_data_tot, sorted_keys, rgemat ) + lin_corr_values_norge = construct_corrections_matrix_linear( + lin_corr_list, n_data_tot, sorted_keys + ) + if use_quad: quad_corr_values = construct_corrections_matrix_quadratic( quad_corr_list, n_data_tot, sorted_keys, rgemat @@ -732,6 +737,7 @@ def load_datasets( np.array(sm_theory), sorted_keys, lin_corr_values, + lin_corr_values_norge, quad_corr_values, np.array(exp_name), np.array(n_data_exp),