Skip to content

Commit

Permalink
fix linear bars plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
Jhsmit committed Oct 31, 2023
1 parent e8611a4 commit efe7f5b
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 95 deletions.
184 changes: 89 additions & 95 deletions pyhdx/plot.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
from contextlib import contextmanager
from copy import copy
from pathlib import Path
Expand Down Expand Up @@ -552,14 +553,32 @@ def loss_figure(fit_result, **figure_kwargs):


def linear_bars_figure(
data,
reference=None,
groupby=None,
field="dG",
data: pd.DataFrame,
reference: Optional[str] = None,
groupby: Optional[str] = None,
field: str = "dG",
norm=None,
cmap=None,
**figure_kwargs,
):
"""
Generate a linear bars figure based on the provided data.
Args:
data: A pandas DataFrame containing the data to be plotted.
reference: An optional string representing the reference value for subtraction.
groupby: An optional string representing the column to group the data by.
field: A string representing the field to be plotted. Default is "dG".
norm: An optional normalization function.
cmap: An optional colormap.
**figure_kwargs: Additional keyword arguments to be passed to the figure.
Returns:
fig: The generated figure.
axes: The axes of the figure.
cbar: The colorbar of the figure.
"""

if reference is None and field == "dG":
cmap_default, norm_default = CMAP_NORM_DEFAULTS["dG"]
ylabel = dG_ylabel
Expand Down Expand Up @@ -589,14 +608,35 @@ def linear_bars_figure(
if norm is None:
raise ValueError("No valid Norm found")


reduced = data.xs(level=-1, key=field, axis=1)

if groupby:
grp_level = reduced.columns.names.index(groupby)
bar_level = 1 - grp_level
else:
grp_level, bar_level = 0, 1

flat = reduced.columns.to_flat_index().tolist()
series_list = [reduced[col] for col in reduced.columns]

# nest the individual pandas series in a dict according to grp / bar level
result = defaultdict(dict)
for tup, series in zip(flat, series_list):
result[tup[grp_level]][tup[bar_level]] = series

# subract reference values if given
if reference is not None:
for subdict in result.values():
ref_values = subdict.pop(reference)
for name, values in subdict.items():
subdict[name] = values - ref_values

fig, axes, cbar = linear_bars(
data,
reference=reference,
groupby=groupby,
field=field,
result,
norm=norm,
cmap=cmap,
sclf=sclf,
cbar_sclf=sclf,
**figure_kwargs,
)

Expand All @@ -605,58 +645,35 @@ def linear_bars_figure(
return fig, axes, cbar



def linear_bars(
data,
reference=None,
groupby=None,
field="dG",
norm=None,
cmap=None,
sclf=1.0,
sort=False,
data: dict[str, dict[str, pd.Series]],
norm,
cmap,
cbar_sclf=1.0,
**figure_kwargs,
):
# input data should always be 3 levels
# grouping is done by the first level
# second level gives each bar
# third level should have columns with the specified 'field'
if data.columns.nlevels == 2:
data = data.copy()
columns = pd.MultiIndex.from_tuples(
[("", *tup) for tup in data.columns], names=["group"] + data.columns.names
)
data.columns = columns

# todo this should be done by the 'user'
data = data.xs(level=-1, key=field, drop_level=False, axis=1)

groupby = groupby or data.columns.names[0]
grp_level = data.columns.names.index(groupby)
bars_level = 1 - grp_level

if isinstance(reference, int):
reference = data.columns.get_level_values(level=bars_level)[reference]
if reference:
ref_data = data.xs(key=reference, axis=1, level=bars_level)
sub = data.subtract(ref_data, axis=1).reorder_levels(data.columns.names, axis=1)

# fix column dtypes to preserve , reorder levels back to original order
columns = multiindex_astype(sub.columns, grp_level, "category")
categories = list(data.columns.unique(level=grp_level))
columns = multiindex_set_categories(columns, grp_level, categories, ordered=True)
sub.columns = columns
sub = sub.sort_index(axis=1, level=grp_level)
sub = sub.drop(axis=1, level=bars_level, labels=reference)

plot_data = sub
else:
plot_data = data
"""
Generate a linear bar plot with multiple subplots.
Args:
data: A dictionary containing the data for each subplot. The keys are the top-level labels, and the values are dictionaries containing the data for each subplot. The data is represented as a pandas Series object.
norm: A normalization object to be applied to the color scale.
cmap: A colormap object to be used for coloring the bars.
cbar_sclf: A scaling factor to be applied to the color scale. Default is 1.0.
**figure_kwargs: Additional keyword arguments to be passed to the figure.
Returns:
fig: The figure object containing the plot.
axes: The axes object containing the subplots.
cbar: The colorbar object.
"""

groups = plot_data.groupby(level=groupby, axis=1).groups
hspace = [elem for v in groups.values() for elem in [0] * (len(v) - 1) + [None]][:-1]

hspace = [elem for v in data.values() for elem in [0] * (len(v) - 1) + [None]][:-1]
ncols = 1
nrows = len(hspace) + 1

figure_width = figure_kwargs.pop("width", cfg.plotting.page_width) / 25.4
refaspect = figure_kwargs.pop("refaspect", cfg.plotting.linear_bars_aspect)
cbar_width = figure_kwargs.pop("cbar_width", cfg.plotting.cbar_width) / 25.4
Expand All @@ -665,55 +682,32 @@ def linear_bars(
nrows=nrows, ncols=ncols, refaspect=refaspect, width=figure_width, hspace=hspace
)
axes_iter = iter(axes)

if sort in ["ascending", "descending"]:
srt_groups = plot_data.groupby(level=bars_level, axis=1).groups
values = [plot_data.loc[:, grp].mean().mean() for grp in srt_groups.values()]
idx = np.argsort(values)
if sort == "descending":
idx = idx[::-1]
else:
idx = None

for grp_name, items in groups.items():
items = items.values[idx] if idx is not None else items.values
for i, item in enumerate(items): # these items need to be sorted
values = plot_data.xs(key=(*item[: plot_data.columns.nlevels - 1], field), axis=1)
rmin, rmax = values.index.min(), values.index.max()
extent = [rmin - 0.5, rmax + 0.5, 0, 1]
img = np.expand_dims(values, 0)

y_edges = [0, 1]
for top_level, subdict in data.items():
for i, (label, values) in enumerate(subdict.items()):
ax = next(axes_iter)
label = item[bars_level]
from matplotlib.axes import Axes

Axes.imshow( # TODO use proplot pcolor or related function
ax,
norm(img),
aspect="auto",
cmap=cmap,
vmin=0,
vmax=1,
interpolation="None",
extent=extent,
)
rmin, rmax = values.index.min(), values.index.max()
r_edges = pplt.arange(rmin - 0.5, rmax + 0.5, 1)
ax.pcolormesh(r_edges, y_edges, values.to_numpy().reshape(1, -1), cmap=cmap, norm=norm)
ax.format(yticks=[])

ax.text(
1.02,
0.5,
label,
horizontalalignment="left",
verticalalignment="center",
transform=ax.transAxes,
)
1.02,
0.5,
label,
horizontalalignment="left",
verticalalignment="center",
transform=ax.transAxes,
)

if i == 0:
ax.format(title=grp_name)
ax.format(title=top_level)

axes.format(xlabel=r_xlabel)

cmap_norm = copy(norm)
cmap_norm.vmin *= sclf
cmap_norm.vmax *= sclf
cmap_norm.vmin *= cbar_sclf
cmap_norm.vmax *= cbar_sclf

cbar = fig.colorbar(cmap, norm=cmap_norm, loc="b", width=cbar_width)

Expand Down
1 change: 1 addition & 0 deletions pyhdx/web/controllers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2487,6 +2487,7 @@ def _figure_updated(self, *events):
self.aspect = cfg.plotting[f"{self.figure}_aspect"]
self._excluded = ["ncols", "figure_selection"]

# scatter plot is DG only
elif self.figure == "scatter":
# move to function
df = self.plot_data
Expand Down

0 comments on commit efe7f5b

Please sign in to comment.