Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add BasePlot HeatMap classes #106

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ numpy>=1.18.1<=2.10.0
opencv-python>=4.5.4.60
pandana
pandas>=0.25.1
PATSY>=0.5.1
plotly>=5.1.0
POT>=0.8.1
pynndescent>=0.4.8
Expand Down
4 changes: 2 additions & 2 deletions spateo/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,8 +570,8 @@ def set_figure_params(
):
"""Set resolution/size, styling and format of figures.
This function is adapted from: https://github.com/theislab/scanpy/blob/f539870d7484675876281eb1c475595bf4a69bdb/scanpy/_settings.py
Arguments
---------

Args:
spateo: `bool` (default: `True`)
Init default values for :obj:`matplotlib.rcParams` suited for spateo.
background: `str` (default: `white`)
Expand Down
182 changes: 182 additions & 0 deletions spateo/plotting/static/baseplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
from anndata import AnnData
from matplotlib import rcParams
from matplotlib.colors import to_hex

from spateo.tools.utils import update_dict

from ...configuration import SKM
from .utils import _select_font_color, save_fig


@SKM.check_adata_is_type(SKM.ADATA_AGG_TYPE)
class BasePlot:
def __init__(
self,
adata: AnnData,
color: Union[str, list] = "ntr",
layer: Union[str, list] = "X",
basis: Union[str, list] = "umap",
slices: Union[str, list] = None,
slices_split: bool = False,
slices_key: str = "slices",
stack_colors=False,
stack_colors_threshold=0.001,
stack_colors_title="stacked colors",
stack_colors_legend_size=2,
stack_colors_cmaps=None,
ncols: int = 4,
aspect: str = "auto",
axis_on: bool = False,
background: Optional[str] = None,
dpi: int = 100,
figsize: tuple = (6, 4),
gridspec: bool = True,
pointsize: Optional[int] = None,
save_show_or_return: str = "show",
save_kwargs: Optional[dict] = None,
show_legend="on data",
theme: Optional[str] = None,
):
self.adata = adata.copy()
self.stack_colors = stack_colors
self.stack_colors_threshold = stack_colors_threshold
self.stack_colors_title = stack_colors_title
self.stack_colors_legend_size = stack_colors_legend_size
self.show_legend = show_legend
self.aspect = aspect
self.axis_on = axis_on
self.dpi = dpi
self.figsize = figsize
self.save_show_or_return = save_show_or_return
self.save_kwargs = save_kwargs
self.slices_split = slices_split
self.slices_key = slices_key
self.theme = theme
self.basis = self._check_iterable(basis)
self.color = self._check_iterable(color)
self.layer = self._check_iterable(layer)
if slices is None and slices_split:
self.slices = self.adata.obs[self.slices_key].unique().tolist()
self.slices = self._check_iterable(slices)
self.prefix = "baseplot"

if background is None:
_background = rcParams.get("figure.facecolor")
self._background = to_hex(_background) if type(_background) is tuple else _background
else:
self._background = background
self.font_color = _select_font_color(self._background)

if stack_colors and stack_colors_cmaps is None:
self.stack_colors_cmaps = [
"Greys",
"Purples",
"Blues",
"Greens",
"Oranges",
"Reds",
"YlOrBr",
"YlOrRd",
"OrRd",
"PuRd",
"RdPu",
"BuPu",
"GnBu",
"PuBu",
"YlGnBu",
"PuBuGn",
"BuGn",
"YlGn",
]
self.stack_legend_handles = []
if stack_colors:
self.color_key = None

n_s = len(self.slices) if slices_split else 1
n_c = len(self.color) if not stack_colors else 1
n_l = len(self.layer)
n_b = len(self.basis)
total_panels, ncols = (
n_s * n_c * n_l * n_b,
min(max([n_s, n_c, n_l, n_b]), ncols),
)
nrow, ncol = int(np.ceil(total_panels / ncols)), ncols

if pointsize is None:
self.pointsize = 16000.0 / np.sqrt(adata.shape[0])
else:
self.pointsize = 16000.0 / np.sqrt(adata.shape[0]) * pointsize

if gridspec:
if total_panels > 1:
self.fig = plt.figure(
None,
(figsize[0] * ncol, figsize[1] * nrow),
facecolor=self._background,
dpi=self.dpi,
)
self.gs = plt.GridSpec(nrow, ncol, wspace=0.12)
else:
self.fig, ax = plt.subplots(figsize=figsize)
self.gs = [ax]
self.ax_index = 0

def plot(self):
if self.slices_split:
for cur_s in self.slices:
adata = self.adata[self.adata.obs[self.slices_key] == cur_s, :]
for cur_b in self.basis:
for cur_l in self.layer:
for cur_c in self.color:
self._plot_basis_layer(adata, cur_c, cur_b, cur_l)
if not self.stack_colors:
self.ax_index += 1
if self.stack_colors:
self.ax_index += 1

else:
for cur_b in self.basis:
for cur_l in self.layer:
for cur_c in self.color:
self._plot_basis_layer(self.adata, cur_c, cur_b, cur_l)
if not self.stack_colors:
self.ax_index += 1
if self.stack_colors:
self.ax_index += 1

clf = self._save_show_or_return()
return clf

def _plot_basis_layer(self, *args, **kwargs):
raise NotImplementedError

def _save_show_or_return(self):
if self.save_show_or_return in ["save", "both", "all"]:
s_kwargs = {
"path": None,
"prefix": self.prefix,
"dpi": self.dpi,
"ext": "pdf",
"transparent": True,
"close": True,
"verbose": True,
}
s_kwargs = update_dict(s_kwargs, self.save_kwargs)

save_fig(**s_kwargs)
elif self.save_show_or_return in ["show", "both", "all"]:
if self.show_legend:
plt.subplots_adjust(right=0.85)
plt.show()
elif self.save_show_or_return in ["return", "all"]:
return plt.clf()

def _check_iterable(self, arg):
if arg is None or isinstance(arg, str):
return [arg]
else:
return list(arg)
169 changes: 169 additions & 0 deletions spateo/plotting/static/heatmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from anndata import AnnData

from ...configuration import SKM
from .baseplot import BasePlot
from .utils import _to_hex


@SKM.check_adata_is_type(SKM.ADATA_AGG_TYPE)
class HeatMap(BasePlot):
def __init__(
self,
adata: AnnData,
markers: list,
group: Optional[str] = None,
group_mean: bool = False,
group_cmap: str = "tab20",
col_cluster: bool = False,
row_cluster: bool = False,
layer: Union[str, list] = "X",
slices: Union[str, list] = None,
slices_split: bool = False,
slices_key: str = "slices",
background: Optional[str] = None,
dpi: int = 100,
figsize: tuple = (11, 5),
save_show_or_return: str = "show",
save_kwargs: Optional[dict] = None,
swap_axis: bool = False,
cbar_pos: Optional[tuple] = None,
theme: Optional[str] = None,
cmap: str = "viridis",
**kwargs
):
super().__init__(
adata=adata,
color=[markers],
basis=group,
layer=layer,
slices=slices,
slices_split=slices_split,
slices_key=slices_key,
background=background,
dpi=dpi,
figsize=figsize,
save_show_or_return=save_show_or_return,
save_kwargs=save_kwargs,
theme=theme,
gridspec=False,
)
self.group_mean = group_mean
self.group_cmap = group_cmap
self.col_cluster = col_cluster
self.row_cluster = row_cluster
self.cmap = cmap
self.cbar_pos = cbar_pos
self.swap_axis = swap_axis
self.kwargs = kwargs

def _plot_basis_layer(self, adata: AnnData, markers, cells_group, cur_l):
value_df, colors = self._fetch_data(adata, markers, cells_group, cur_l)

if self.swap_axis:
value_df = value_df.T

heatmap_kwargs = dict(
xticklabels=1,
yticklabels=False,
col_colors=colors if self.swap_axis else None,
row_colors=None if self.swap_axis else colors,
row_linkage=None,
col_linkage=None,
method="average",
metric="euclidean",
z_score=None,
standard_scale=None,
cbar_pos=self.cbar_pos,
)
if self.kwargs is not None:
heatmap_kwargs.update(self.kwargs)

sns_heatmap = sns.clustermap(
value_df,
col_cluster=self.col_cluster,
row_cluster=self.row_cluster,
cmap=self.cmap,
figsize=self.figsize,
**heatmap_kwargs,
)

# if not self.show_legend:
# sns_heatmap.cax.set_visible(False)

def _fetch_data(self, adata: AnnData, markers, cells_group, cur_l):
layer = None if cur_l == "X" else cur_l
value_df = pd.DataFrame()
for i, marker in enumerate(markers):
v = adata.obs_vector(marker, layer=layer)
value_df[marker] = v
value_df.index = adata.obs.index
colors = None
if cells_group is not None:
value_df[cells_group] = adata.obs_vector(cells_group, layer=layer)
value_df = value_df.sort_values(cells_group)
if self.group_mean:
value_df = value_df.groupby(cells_group, as_index=False).mean()
num_labels = len(value_df[cells_group].unique())

color_key = _to_hex(plt.get_cmap(self.group_cmap)(np.linspace(0, 1, num_labels)))
cell_lut = dict(zip(value_df[cells_group].unique().tolist(), color_key))
colors = value_df[cells_group].map(cell_lut)
value_df = value_df.drop(cells_group, axis=1)

return value_df, colors


@SKM.check_adata_is_type(SKM.ADATA_AGG_TYPE)
def heatmap(
adata: AnnData,
markers: list,
group: Optional[str] = None,
group_mean: bool = False,
group_cmap: str = "tab20",
col_cluster: bool = False,
row_cluster: bool = False,
layer: Union[str, list] = "X",
slices: Union[str, list] = None,
slices_split: bool = False,
slices_key: str = "slices",
background: Optional[str] = None,
dpi: int = 100,
figsize: tuple = (11, 5),
save_show_or_return: str = "show",
save_kwargs: Optional[dict] = None,
swap_axis: bool = False,
cbar_pos: Optional[tuple] = None,
theme: Optional[str] = None,
cmap: str = "viridis",
**kwargs
):
hm = HeatMap(
adata=adata,
markers=markers,
group=group,
group_mean=group_mean,
group_cmap=group_cmap,
col_cluster=col_cluster,
row_cluster=row_cluster,
layer=layer,
slices=slices,
slices_split=slices_split,
slices_key=slices_key,
background=background,
dpi=dpi,
figsize=figsize,
save_show_or_return=save_show_or_return,
save_kwargs=save_kwargs,
swap_axis=swap_axis,
cbar_pos=cbar_pos,
theme=theme,
cmap=cmap,
**kwargs,
)
return hm.plot()
11 changes: 5 additions & 6 deletions spateo/plotting/static/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,9 @@ def space(
*args,
**kwargs
):
"""\
Scatter plot for physical coordinates of each cell.
Parameters
----------
"""Scatter plot for physical coordinates of each cell.

Args:
adata:
an Annodata object that contain the physical coordinates for each bin/cell, etc.
genes:
Expand Down Expand Up @@ -83,8 +82,8 @@ def space(
ps_sample_num: `int`
The number of bins / cells that will be sampled to estimate the distance between different bin / cells.
%(scatters.parameters.no_adata|basis|figsize)s
Returns
-------

Returns:
plots gene or cell feature of the adata object on the physical spatial coordinates.
"""
# main_info("Plotting spatial info on adata")
Expand Down