From 3dafc2ca9711190721beaf9fefccf8492c5389dd Mon Sep 17 00:00:00 2001 From: Jochem Smit Date: Thu, 11 Jan 2024 10:18:29 +0100 Subject: [PATCH] typing and formatting --- pyhdx/plot.py | 119 ++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 91 insertions(+), 28 deletions(-) diff --git a/pyhdx/plot.py b/pyhdx/plot.py index 1fc1a990..da907562 100644 --- a/pyhdx/plot.py +++ b/pyhdx/plot.py @@ -60,7 +60,7 @@ def peptide_coverage_figure( data: pd.DataFrame, - wrap: int = None, + wrap: Optional[int] = None, cmap: Union[pplt.Colormap, mpl.colors.Colormap, str, tuple, dict] = "turbo", norm: Type[mpl.colors.Normalize] = None, color_field: str = "rfu", @@ -70,7 +70,9 @@ def peptide_coverage_figure( **figure_kwargs, ) -> tuple: subplot_values = data[subplot_field].unique() - sub_dfs = {value: data.query(f"`{subplot_field}` == {value}") for value in subplot_values} + sub_dfs = { + value: data.query(f"`{subplot_field}` == {value}") for value in subplot_values + } n_subplots = len(subplot_values) @@ -86,11 +88,18 @@ def peptide_coverage_figure( start_field, end_field = rect_fields if wrap is None: wrap = max( - [autowrap(sub_df[start_field], sub_df[end_field]) for sub_df in sub_dfs.values()] + [ + autowrap(sub_df[start_field], sub_df[end_field]) + for sub_df in sub_dfs.values() + ] ) fig, axes = pplt.subplots( - ncols=ncols, nrows=nrows, width=figure_width, refaspect=refaspect, **figure_kwargs + ncols=ncols, + nrows=nrows, + width=figure_width, + refaspect=refaspect, + **figure_kwargs, ) rect_kwargs = rect_kwargs or {} axes_iter = iter(axes) @@ -159,13 +168,17 @@ def peptide_coverage( color = cmap(norm(elem[color_field])) width = elem[end_field] - elem[start_field] - rect = Rectangle((elem[start_field] - 0.5, i), width, 1, facecolor=color, **rect_kwargs) + rect = Rectangle( + (elem[start_field] - 0.5, i), width, 1, facecolor=color, **rect_kwargs + ) ax.add_patch(rect) if labels: rx, ry = rect.get_xy() cy = ry cx = rx - ax.annotate(str(p_num), (cx, cy), color="k", fontsize=6, va="bottom", ha="right") + ax.annotate( + str(p_num), (cx, cy), color="k", fontsize=6, va="bottom", ha="right" + ) i -= 1 ax.set_ylim(-wrap, 0) @@ -244,7 +257,9 @@ def residue_time_scatter_figure( return fig, axes, cbars -def residue_time_scatter(ax, hdx_tp, field="rfu", cmap="turbo", norm=None, cbar=True, **kwargs): +def residue_time_scatter( + ax, hdx_tp, field="rfu", cmap="turbo", norm=None, cbar=True, **kwargs +): # update cmap, norm defaults cmap = pplt.Colormap(cmap) # todo allow None as cmap norm = norm or pplt.Norm("linear", vmin=0, vmax=1) @@ -285,13 +300,19 @@ def residue_scatter_figure( tps = np.unique(np.concatenate([hdxm.timepoints for hdxm in hdxm_set])) fig, axes = pplt.subplots( - ncols=ncols, nrows=nrows, width=figure_width, refaspect=refaspect, **figure_kwargs + ncols=ncols, + nrows=nrows, + width=figure_width, + refaspect=refaspect, + **figure_kwargs, ) axes_iter = iter(axes) scatter_kwargs = scatter_kwargs or {} for hdxm in hdxm_set: ax = next(axes_iter) - residue_scatter(ax, hdxm, cmap=cmap, norm=norm, field=field, cbar=False, **scatter_kwargs) + residue_scatter( + ax, hdxm, cmap=cmap, norm=norm, field=field, cbar=False, **scatter_kwargs + ) ax.format(title=f"{hdxm.name}") for ax in axes_iter: @@ -310,7 +331,9 @@ def residue_scatter_figure( # todo allow colorbar_scatter to take rfus -def residue_scatter(ax, hdxm, field="rfu", cmap="viridis", norm=None, cbar=True, **kwargs): +def residue_scatter( + ax, hdxm, field="rfu", cmap="viridis", norm=None, cbar=True, **kwargs +): cmap = pplt.Colormap(cmap) tps = hdxm.timepoints[np.nonzero(hdxm.timepoints)] norm = norm or pplt.Norm("log", tps.min(), tps.max()) @@ -380,7 +403,9 @@ def dG_scatter_figure( # Set global ylims ylims = [lim for ax in axes if ax.axison for lim in ax.get_ylim()] - axes.format(ylim=(np.max(ylims), np.min(ylims)), yticklabelloc="none", ytickloc="none") + axes.format( + ylim=(np.max(ylims), np.min(ylims)), yticklabelloc="none", ytickloc="none" + ) cbar_kwargs = cbar_kwargs or {} cbars = [] @@ -417,7 +442,9 @@ def ddG_scatter_figure( dG_test = data.xs("dG", axis=1, level=1).drop(reference_state, axis=1) dG_ref = data[reference_state, "dG"] ddG = dG_test.subtract(dG_ref, axis=0) - ddG.columns = pd.MultiIndex.from_product([ddG.columns, ["ddG"]], names=["State", "quantity"]) + ddG.columns = pd.MultiIndex.from_product( + [ddG.columns, ["ddG"]], names=["State", "quantity"] + ) cov_test = data.xs("covariance", axis=1, level=1).drop(reference_state, axis=1) ** 2 cov_ref = data[reference_state, "covariance"] ** 2 @@ -490,7 +517,9 @@ def ddG_scatter_figure( return fig, axes, cbars -def peptide_mse_figure(peptide_mse, cmap=None, norm=None, rect_kwargs=None, **figure_kwargs): +def peptide_mse_figure( + peptide_mse, cmap=None, norm=None, rect_kwargs=None, **figure_kwargs +): n_subplots = len(peptide_mse.columns.unique(level=0)) ncols = figure_kwargs.pop("ncols", min(cfg.plotting.ncols, n_subplots)) nrows = figure_kwargs.pop("nrows", int(np.ceil(n_subplots / ncols))) @@ -500,7 +529,11 @@ def peptide_mse_figure(peptide_mse, cmap=None, norm=None, rect_kwargs=None, **fi cmap = cmap or CMAP_NORM_DEFAULTS["mse"][0] fig, axes = pplt.subplots( - ncols=ncols, nrows=nrows, width=figure_width, refaspect=refaspect, **figure_kwargs + ncols=ncols, + nrows=nrows, + width=figure_width, + refaspect=refaspect, + **figure_kwargs, ) axes_iter = iter(axes) cbars = [] @@ -535,7 +568,11 @@ def loss_figure(fit_result, **figure_kwargs): ) # todo loss aspect also in config? fig, ax = pplt.subplots( - ncols=ncols, nrows=nrows, width=figure_width, refaspect=refaspect, **figure_kwargs + ncols=ncols, + nrows=nrows, + width=figure_width, + refaspect=refaspect, + **figure_kwargs, ) fit_result.losses.plot(ax=ax) # ax.plot(fit_result.losses, legend='t') # altnernative proplot plotting @@ -839,7 +876,9 @@ def rainbowclouds( strip_kwargs = _strip_kwargs.update(strip_kwargs) if strip_kwargs else _strip_kwargs kde_kwargs = _kde_kwargs.update(strip_kwargs) if kde_kwargs else _kde_kwargs - boxplot_kwargs = _boxplot_kwargs.update(strip_kwargs) if boxplot_kwargs else _boxplot_kwargs + boxplot_kwargs = ( + _boxplot_kwargs.update(strip_kwargs) if boxplot_kwargs else _boxplot_kwargs + ) stripplot(f_data, ax=ax, **strip_kwargs) kdeplot(f_data, ax=ax, **kde_kwargs) @@ -853,7 +892,9 @@ def rainbowclouds( ytickloc="left", ylim=ylim, ) - format_kwargs = _format_kwargs.update(format_kwargs) if format_kwargs else _format_kwargs + format_kwargs = ( + _format_kwargs.update(format_kwargs) if format_kwargs else _format_kwargs + ) ax.format(**format_kwargs) @@ -1055,7 +1096,9 @@ def add_mse_panels( if cbar: if fig is None: - raise ValueError("Must pass 'fig' keyword argument to add a global colorbar") + raise ValueError( + "Must pass 'fig' keyword argument to add a global colorbar" + ) cbar_kwargs = cbar_kwargs or {} cbar_kwargs = { "width": CBAR_KWARGS["width"], @@ -1135,13 +1178,17 @@ def __init__(self): } colors = ["#6EA72A", "#DAD853", "#FFA842", "#A22D46", "#5D0496"][::-1] - cmap_redundancy = pplt.Colormap(colors, discrete=True, N=len(colors), listmode="discrete") + cmap_redundancy = pplt.Colormap( + colors, discrete=True, N=len(colors), listmode="discrete" + ) cmap_redundancy.set_over("#0E4A21") cmap_redundancy.set_bad(NO_COVERAGE) self.cmaps["redundancy"] = cmap_redundancy colors = ["#008832", "#72D100", "#FFFF04", "#FFB917", "#FF8923"] - cmap_redundancy = pplt.Colormap(colors, discrete=True, N=len(colors), listmode="discrete") + cmap_redundancy = pplt.Colormap( + colors, discrete=True, N=len(colors), listmode="discrete" + ) cmap_redundancy.set_over("#FE2B2E") cmap_redundancy.set_bad(NO_COVERAGE) self.cmaps["resolution"] = cmap_redundancy @@ -1216,7 +1263,9 @@ def pymol_figures( values = values.reindex(pd.RangeIndex(rmin, rmax + 1, name="r_number")) colors = apply_cmap(values, cmap, norm) name = ( - f"pymol_ddG_{state}_ref_{reference_state}" if reference_state else f"pymol_dG_{state}" + f"pymol_ddG_{state}_ref_{reference_state}" + if reference_state + else f"pymol_dG_{state}" ) name += name_suffix pymol_render( @@ -1341,7 +1390,9 @@ def stripplot( for i, (d, color) in enumerate(zip(data, color_list)): jitter_offsets = (np.random.rand(d.size) - 0.5) * jitter - cat_var = i * np.ones_like(d) + jitter_offsets + offset # categorical axis variable + cat_var = ( + i * np.ones_like(d) + jitter_offsets + offset + ) # categorical axis variable if orientation == "vertical": ax.scatter(cat_var, d, color=color, **scatter_kwargs) elif orientation == "horizontal": @@ -1443,7 +1494,9 @@ def kdeplot( color=color, ) elif orientation == "vertical": - ax.fill_betweenx(kde_x, len(data) - cat_var, len(data) - cat_var_zero, color=color) + ax.fill_betweenx( + kde_x, len(data) - cat_var, len(data) - cat_var_zero, color=color + ) if fill_cmap: fill_norm = fill_norm or pplt.Norm("linear") @@ -1539,7 +1592,9 @@ def _make_figure(self, figure_name, **kwargs): # return dictionary # keys: either protein state name (hdxm.name) or 'All states' - figures_dict = {name: function(arg, **kwargs) for name, arg in args_dict.items()} + figures_dict = { + name: function(arg, **kwargs) for name, arg in args_dict.items() + } return figures_dict def make_figure(self, figure_name, **kwargs): @@ -1550,7 +1605,9 @@ def make_figure(self, figure_name, **kwargs): return figures_dict def get_fit_timepoints(self): - all_timepoints = np.concatenate([hdxm.timepoints for hdxm in self.fit_result.hdxm_set]) + all_timepoints = np.concatenate( + [hdxm.timepoints for hdxm in self.fit_result.hdxm_set] + ) # x_axis_type = self.settings.get('fit_time_axis', 'Log') x_axis_type = "Log" # todo configureable @@ -1622,7 +1679,9 @@ def save_figure(self, fig_name, ext=".png", **kwargs): figures_dict = self._make_figure(fig_name, **kwargs) if self.output_path is None: - raise ValueError(f"No output path given when `FitResultPlot` object as initialized") + raise ValueError( + f"No output path given when `FitResultPlot` object as initialized" + ) for name, fig_tup in figures_dict.items(): fig = fig_tup if isinstance(fig_tup, plt.Figure) else fig_tup[0] @@ -1670,7 +1729,9 @@ def plot_fitresults( """ - raise DeprecationWarning("This function is deprecated, use FitResultPlot.plot_all instead") + raise DeprecationWarning( + "This function is deprecated, use FitResultPlot.plot_all instead" + ) # batch results only history_path = fitresult_path / "model_history.csv" output_path = output_path or fitresult_path @@ -1759,7 +1820,9 @@ def plot_fitresults( plt.close(fig) if "dG_scatter" in plots: - fig, axes, cbars = dG_scatter_figure(fitresult.output.df, cmap=dG_cmap, norm=dG_norm) + fig, axes, cbars = dG_scatter_figure( + fitresult.output.df, cmap=dG_cmap, norm=dG_norm + ) for ext in output_type: f_out = output_path / (f"dG_scatter" + ext) plt.savefig(f_out)