Skip to content

Commit

Permalink
improving readability
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoterh committed Dec 18, 2024
1 parent 53c8a42 commit c598804
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 42 deletions.
66 changes: 50 additions & 16 deletions src/smefit/analyze/fisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,9 @@ def color(value, thr_val=10):
@staticmethod
def unify_fishers(df, df_other):

if df_other is None or df is None:
return None

# Get the union of row and column indices
all_rows = df.index.union(df_other.index)
all_columns = df.columns.union(df_other.columns)
Expand Down Expand Up @@ -383,17 +386,44 @@ def set_ticks(ax, yticks, xticks, latex_names, x_labels):

@staticmethod
def plot_values(ax, dfs, cmap, norm):
"""
Plot the values of the Fisher information.
Parameters
----------
ax: matplotlib.axes.Axes
axes object
dfs: list
list of pandas.DataFrame
cmap: matplotlib.colors.LinearSegmentedColormap
colour map
norm: matplotlib.colors.BoundaryNorm
normalisation of colorbar
"""

df_1 = dfs[0]
df_2 = dfs[1] if len(dfs) > 1 else None
cols, rows = df_1.shape

# Initialize the delta shift for text positioning
delta_shift = 0

for i, row in enumerate(df_1.values.T):
for j, elem_1 in enumerate(row):

# start filling from the top left corner
x, y = j, rows - 1 - i
ec_1 = "black"

# if two fishers must be plotted together
if df_2 is not None:

elem_2 = df_2.values.T[i, j]

# move position numbers
delta_shift = 0.2

# highlight operators that exist in one but not the other
ec_1 = "C1" if elem_2 == 0 and elem_1 > 0 else "black"

if elem_2 > 0:
Expand All @@ -406,6 +436,7 @@ def plot_values(ax, dfs, cmap, norm):
fontsize=8,
)

# Create a triangle patch for the second element
triangle2 = Polygon(
[
[x + 0.5, y - 0.5],
Expand All @@ -419,6 +450,7 @@ def plot_values(ax, dfs, cmap, norm):
ax.add_patch(triangle2)

if elem_1 > 0:

ax.text(
x - delta_shift,
y - delta_shift,
Expand All @@ -427,8 +459,9 @@ def plot_values(ax, dfs, cmap, norm):
ha="center",
fontsize=8,
)

if df_2 is not None:

# Create a triangle patch for the first element
triangle1 = Polygon(
[
[x - 0.5, y - 0.5],
Expand All @@ -441,30 +474,24 @@ def plot_values(ax, dfs, cmap, norm):
)
ax.add_patch(triangle1)

# Create legend elements for the patches
legend_elements = [
mpatches.Polygon(
[
[-0.5, -0.5],
[0.5, -0.5],
[-0.5, 0.5],
],
[[-0.5, -0.5], [0.5, -0.5], [-0.5, 0.5]],
closed=True,
fc="none",
edgecolor="black",
label="$\\rm w/\\;RGE$",
),
mpatches.Polygon(
[
[0.5, -0.5],
[0.5, 0.5],
[0.5, 0.5],
],
[[0.5, -0.5], [0.5, 0.5], [0.5, 0.5]],
closed=True,
fc="none",
edgecolor="black",
label="$\\rm w/o\\;RGE$",
),
]
# Add the legend to the plot
ax.legend(
handles=legend_elements,
loc="upper center",
Expand All @@ -474,8 +501,8 @@ def plot_values(ax, dfs, cmap, norm):
handler_map={mpatches.Polygon: HandlerTriangle()},
bbox_to_anchor=(0.5, -0.02),
)

else:
# Create a rectangle patch for the first element
rectangle = Polygon(
[
[x - 0.5, y - 0.5],
Expand All @@ -489,16 +516,18 @@ def plot_values(ax, dfs, cmap, norm):
)
ax.add_patch(rectangle)

# Set the x and y limits of the plot
ax.set_xlim(0, cols - 0.5)
ax.set_ylim(0, rows - 0.5)
# Set the aspect ratio of the plot to be equal
ax.set_aspect("equal", adjustable="box")

def plot_heatmap(

Check notice on line 525 in src/smefit/analyze/fisher.py

View check run for this annotation

codefactor.io / CodeFactor

src/smefit/analyze/fisher.py#L525

Too many local variables (24/22) (too-many-locals)
self,
latex_names,
fig_name,
title=None,
df_other=None,
other=None,
summary_only=True,
figsize=(11, 15),
column_names=None,
Expand All @@ -507,16 +536,21 @@ def plot_heatmap(
fisher_df = self.summary_table if summary_only else self.lin_fisher
quad_fisher_df = self.summary_HOtable if summary_only else self.quad_fisher

if df_other is not None:
if other is not None:

fisher_df_other = other.summary_table if summary_only else other.lin_fisher
quad_fisher_df_other = (
other.summary_HOtable if summary_only else other.quad_fisher
)
# unify the fisher tables and fill missing values by zeros
fisher_dfs = self.unify_fishers(fisher_df, df_other)
fisher_dfs = self.unify_fishers(fisher_df, fisher_df_other)
quad_fisher_dfs = self.unify_fishers(quad_fisher_df, quad_fisher_df_other)

Check notice on line 547 in src/smefit/analyze/fisher.py

View check run for this annotation

codefactor.io / CodeFactor

src/smefit/analyze/fisher.py#L547

Unused variable 'quad_fisher_dfs' (unused-variable)

# reshuffle the tables according to the latex names ordering
fisher_dfs = [
fisher[latex_names.index.get_level_values(level=1)]
for fisher in fisher_dfs
]

else:
fisher_dfs = [fisher_df[latex_names.index.get_level_values(level=1)]]

Expand Down
51 changes: 25 additions & 26 deletions src/smefit/analyze/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,13 +490,7 @@ def pca(
self._append_section("PCA", figs=figs_list, links=links_list)

def fisher(
self,
norm="coeff",
summary_only=True,
plot=None,
fit_list=None,
log=False,
order_rows=False,
self, norm="coeff", summary_only=True, plot=None, fit_list=None, log=False
):
"""Fisher information table and plots runner.
Expand Down Expand Up @@ -524,6 +518,7 @@ def fisher(
else:
fit_list = self.fits

fishers = {}
for fit in fit_list:
compute_quad = fit.config["use_quad"]
fisher_cal = FisherCalculator(fit.coefficients, fit.datasets, compute_quad)
Expand All @@ -534,7 +529,7 @@ def fisher(
fisher_cal.summary_table = fisher_cal.groupby_data(
fisher_cal.lin_fisher, self.data_info, norm, log
)
fit.fisher = fisher_cal.summary_table
fishers[fit.name] = fisher_cal

# if necessary compute the quadratic Fisher
if compute_quad:
Expand All @@ -548,39 +543,43 @@ def fisher(
fisher_cal.quad_fisher, self.data_info, norm, log
)

# Write down the table in latex
free_coeff_config = self.coeff_info

compile_tex(
self.report,
fisher_cal.write_grouped(
free_coeff_config, self.data_info, summary_only
),
fisher_cal.write_grouped(self.coeff_info, self.data_info, summary_only),
f"fisher_{fit.name}",
)
links_list.append((f"fisher_{fit.name}", f"Table {fit.label}"))

if plot is not None:
fit_plot = copy.deepcopy(plot)
fit_plot.pop("together")
title = fit.label if fit_plot.pop("title") else None
fisher_cal.plot_heatmap(
free_coeff_config,
self.coeff_info,
f"{self.report}/fisher_heatmap_{fit.name}",
title=title,
**fit_plot,
)
figs_list.append(f"fisher_heatmap_{fit.name}")

fit_plot = copy.deepcopy(plot)
title = fit.label if fit_plot.pop("title") else None
fisher_cal.plot_heatmap(
free_coeff_config,
f"{self.report}/fisher_heatmap_both",
title=title,
df_other=fit_list[0].fisher,
**fit_plot,
)
figs_list.append(f"fisher_heatmap_both")
# plot both fishers
if plot.get("together", False):
fisher_1 = fishers[plot["together"][0]]
fisher_2 = fishers[plot["together"][1]]
fit_plot = copy.deepcopy(plot)
fit_plot.pop("together")

# show title of last fit
title = fit.label if fit_plot.pop("title") else None

# make heatmap of fisher_1 and fisher_2
fisher_2.plot_heatmap(
self.coeff_info,

Check notice on line 577 in src/smefit/analyze/report.py

View check run for this annotation

codefactor.io / CodeFactor

src/smefit/analyze/report.py#L577

Using possibly undefined loop variable 'fit' (undefined-loop-variable)
f"{self.report}/fisher_heatmap_both",
title=title,
other=fisher_1,
**fit_plot,
)
figs_list.append(f"fisher_heatmap_both")

# self.fits[0].fisher
self._append_section("Fisher", figs=figs_list, links=links_list)

Check notice on line 585 in src/smefit/analyze/report.py

View check run for this annotation

codefactor.io / CodeFactor

src/smefit/analyze/report.py#L585

Using an f-string that does not have any interpolated variables (f-string-without-interpolation)

0 comments on commit c598804

Please sign in to comment.