diff --git a/pyhdx/plot.py b/pyhdx/plot.py index a8c83e31..ae12f389 100644 --- a/pyhdx/plot.py +++ b/pyhdx/plot.py @@ -1,3 +1,4 @@ +from collections import defaultdict from contextlib import contextmanager from copy import copy from pathlib import Path @@ -552,14 +553,32 @@ def loss_figure(fit_result, **figure_kwargs): def linear_bars_figure( - data, - reference=None, - groupby=None, - field="dG", + data: pd.DataFrame, + reference: Optional[str] = None, + groupby: Optional[str] = None, + field: str = "dG", norm=None, cmap=None, **figure_kwargs, ): + """ + Generate a linear bars figure based on the provided data. + + Args: + data: A pandas DataFrame containing the data to be plotted. + reference: An optional string representing the reference value for subtraction. + groupby: An optional string representing the column to group the data by. + field: A string representing the field to be plotted. Default is "dG". + norm: An optional normalization function. + cmap: An optional colormap. + **figure_kwargs: Additional keyword arguments to be passed to the figure. + + Returns: + fig: The generated figure. + axes: The axes of the figure. + cbar: The colorbar of the figure. + """ + if reference is None and field == "dG": cmap_default, norm_default = CMAP_NORM_DEFAULTS["dG"] ylabel = dG_ylabel @@ -589,14 +608,35 @@ def linear_bars_figure( if norm is None: raise ValueError("No valid Norm found") + + reduced = data.xs(level=-1, key=field, axis=1) + + if groupby: + grp_level = reduced.columns.names.index(groupby) + bar_level = 1 - grp_level + else: + grp_level, bar_level = 0, 1 + + flat = reduced.columns.to_flat_index().tolist() + series_list = [reduced[col] for col in reduced.columns] + + # nest the individual pandas series in a dict according to grp / bar level + result = defaultdict(dict) + for tup, series in zip(flat, series_list): + result[tup[grp_level]][tup[bar_level]] = series + + # subract reference values if given + if reference is not None: + for subdict in result.values(): + ref_values = subdict.pop(reference) + for name, values in subdict.items(): + subdict[name] = values - ref_values + fig, axes, cbar = linear_bars( - data, - reference=reference, - groupby=groupby, - field=field, + result, norm=norm, cmap=cmap, - sclf=sclf, + cbar_sclf=sclf, **figure_kwargs, ) @@ -605,58 +645,35 @@ def linear_bars_figure( return fig, axes, cbar + def linear_bars( - data, - reference=None, - groupby=None, - field="dG", - norm=None, - cmap=None, - sclf=1.0, - sort=False, + data: dict[str, dict[str, pd.Series]], + norm, + cmap, + cbar_sclf=1.0, **figure_kwargs, ): - # input data should always be 3 levels - # grouping is done by the first level - # second level gives each bar - # third level should have columns with the specified 'field' - if data.columns.nlevels == 2: - data = data.copy() - columns = pd.MultiIndex.from_tuples( - [("", *tup) for tup in data.columns], names=["group"] + data.columns.names - ) - data.columns = columns - - # todo this should be done by the 'user' - data = data.xs(level=-1, key=field, drop_level=False, axis=1) - - groupby = groupby or data.columns.names[0] - grp_level = data.columns.names.index(groupby) - bars_level = 1 - grp_level - - if isinstance(reference, int): - reference = data.columns.get_level_values(level=bars_level)[reference] - if reference: - ref_data = data.xs(key=reference, axis=1, level=bars_level) - sub = data.subtract(ref_data, axis=1).reorder_levels(data.columns.names, axis=1) - - # fix column dtypes to preserve , reorder levels back to original order - columns = multiindex_astype(sub.columns, grp_level, "category") - categories = list(data.columns.unique(level=grp_level)) - columns = multiindex_set_categories(columns, grp_level, categories, ordered=True) - sub.columns = columns - sub = sub.sort_index(axis=1, level=grp_level) - sub = sub.drop(axis=1, level=bars_level, labels=reference) - - plot_data = sub - else: - plot_data = data + """ + Generate a linear bar plot with multiple subplots. + + Args: + data: A dictionary containing the data for each subplot. The keys are the top-level labels, and the values are dictionaries containing the data for each subplot. The data is represented as a pandas Series object. + norm: A normalization object to be applied to the color scale. + cmap: A colormap object to be used for coloring the bars. + cbar_sclf: A scaling factor to be applied to the color scale. Default is 1.0. + **figure_kwargs: Additional keyword arguments to be passed to the figure. + + Returns: + fig: The figure object containing the plot. + axes: The axes object containing the subplots. + cbar: The colorbar object. + """ - groups = plot_data.groupby(level=groupby, axis=1).groups - hspace = [elem for v in groups.values() for elem in [0] * (len(v) - 1) + [None]][:-1] + hspace = [elem for v in data.values() for elem in [0] * (len(v) - 1) + [None]][:-1] ncols = 1 nrows = len(hspace) + 1 + figure_width = figure_kwargs.pop("width", cfg.plotting.page_width) / 25.4 refaspect = figure_kwargs.pop("refaspect", cfg.plotting.linear_bars_aspect) cbar_width = figure_kwargs.pop("cbar_width", cfg.plotting.cbar_width) / 25.4 @@ -665,55 +682,32 @@ def linear_bars( nrows=nrows, ncols=ncols, refaspect=refaspect, width=figure_width, hspace=hspace ) axes_iter = iter(axes) - - if sort in ["ascending", "descending"]: - srt_groups = plot_data.groupby(level=bars_level, axis=1).groups - values = [plot_data.loc[:, grp].mean().mean() for grp in srt_groups.values()] - idx = np.argsort(values) - if sort == "descending": - idx = idx[::-1] - else: - idx = None - - for grp_name, items in groups.items(): - items = items.values[idx] if idx is not None else items.values - for i, item in enumerate(items): # these items need to be sorted - values = plot_data.xs(key=(*item[: plot_data.columns.nlevels - 1], field), axis=1) - rmin, rmax = values.index.min(), values.index.max() - extent = [rmin - 0.5, rmax + 0.5, 0, 1] - img = np.expand_dims(values, 0) - + y_edges = [0, 1] + for top_level, subdict in data.items(): + for i, (label, values) in enumerate(subdict.items()): ax = next(axes_iter) - label = item[bars_level] - from matplotlib.axes import Axes - - Axes.imshow( # TODO use proplot pcolor or related function - ax, - norm(img), - aspect="auto", - cmap=cmap, - vmin=0, - vmax=1, - interpolation="None", - extent=extent, - ) + rmin, rmax = values.index.min(), values.index.max() + r_edges = pplt.arange(rmin - 0.5, rmax + 0.5, 1) + ax.pcolormesh(r_edges, y_edges, values.to_numpy().reshape(1, -1), cmap=cmap, norm=norm) ax.format(yticks=[]) + ax.text( - 1.02, - 0.5, - label, - horizontalalignment="left", - verticalalignment="center", - transform=ax.transAxes, - ) + 1.02, + 0.5, + label, + horizontalalignment="left", + verticalalignment="center", + transform=ax.transAxes, + ) + if i == 0: - ax.format(title=grp_name) + ax.format(title=top_level) axes.format(xlabel=r_xlabel) cmap_norm = copy(norm) - cmap_norm.vmin *= sclf - cmap_norm.vmax *= sclf + cmap_norm.vmin *= cbar_sclf + cmap_norm.vmax *= cbar_sclf cbar = fig.colorbar(cmap, norm=cmap_norm, loc="b", width=cbar_width) diff --git a/pyhdx/web/controllers.py b/pyhdx/web/controllers.py index 1084379b..10aa9f8a 100644 --- a/pyhdx/web/controllers.py +++ b/pyhdx/web/controllers.py @@ -2487,6 +2487,7 @@ def _figure_updated(self, *events): self.aspect = cfg.plotting[f"{self.figure}_aspect"] self._excluded = ["ncols", "figure_selection"] + # scatter plot is DG only elif self.figure == "scatter": # move to function df = self.plot_data