diff --git a/scarf/assay.py b/scarf/assay.py index 83cf143..10c5aca 100644 --- a/scarf/assay.py +++ b/scarf/assay.py @@ -22,8 +22,6 @@ from .metadata import MetaData from .utils import controlled_compute, logger, show_dask_progress -zarrGroup = z_hierarchy.Group - __all__ = ["Assay", "RNAassay", "ATACassay", "ADTassay"] @@ -102,7 +100,7 @@ class Assay: for later KNN graph construction. Args: - z (zarrGroup): Zarr hierarchy where raw data is located + z (z_hierarchy.Group): Zarr hierarchy where raw data is located name (str): A label/name for assay. cell_data: Metadata class object for the cell attributes. nthreads: number for threads to use for dask parallel computations @@ -122,7 +120,7 @@ class Assay: def __init__( self, - z: zarrGroup, + z: z_hierarchy.Group, workspace: Union[str, None], name: str, # FIXME change to assay_name cell_data: MetaData, @@ -757,7 +755,7 @@ class RNAassay(Assay): normalization of scRNA-Seq data. Args: - z (zarrGroup): Zarr hierarchy where raw data is located + z (z_hierarchy.Group): Zarr hierarchy where raw data is located name (str): A label/name for assay. cell_data: Metadata class object for the cell attributes. **kwargs: kwargs to be passed to the Assay class @@ -769,7 +767,7 @@ class RNAassay(Assay): It is set to None until normed method is called. """ - def __init__(self, z: zarrGroup, name: str, cell_data: MetaData, **kwargs): + def __init__(self, z: z_hierarchy.Group, name: str, cell_data: MetaData, **kwargs): super().__init__(z=z, name=name, cell_data=cell_data, **kwargs) self.normMethod = norm_lib_size if "size_factor" in self.attrs: @@ -1076,12 +1074,12 @@ class ATACassay(Assay): """This subclass of Assay is designed for feature selection and normalization of scATAC-Seq data.""" - def __init__(self, z: zarrGroup, name: str, cell_data: MetaData, **kwargs): + def __init__(self, z: z_hierarchy.Group, name: str, cell_data: MetaData, **kwargs): """This Assay subclass is designed for feature selection and normalization of scATAC-Seq data. Args: - z (zarrGroup): Zarr hierarchy where raw data is located + z (z_hierarchy.Group): Zarr hierarchy where raw data is located name (str): A label/name for assay. cell_data: Metadata class object for the cell attributes. **kwargs: @@ -1208,7 +1206,7 @@ class ADTassay(Assay): (feature-barcodes library) data from CITE-Seq experiments. Args: - z (zarrGroup): Zarr hierarchy where raw data is located + z (z_hierarchy.Group): Zarr hierarchy where raw data is located name (str): A label/name for assay. cell_data: Metadata class object for the cell attributes. 
**kwargs: @@ -1217,7 +1215,7 @@ class ADTassay(Assay): normMethod: Pointer to the function to be used for normalization of the raw data """ - def __init__(self, z: zarrGroup, name: str, cell_data: MetaData, **kwargs): + def __init__(self, z: z_hierarchy.Group, name: str, cell_data: MetaData, **kwargs): """This subclass of Assay is designed for normalization of ADT/HTO (feature-barcodes library) data from CITE-Seq experiments.""" super().__init__(z=z, name=name, cell_data=cell_data, **kwargs) diff --git a/scarf/datastore/datastore.py b/scarf/datastore/datastore.py index b803ae5..bbefeb8 100644 --- a/scarf/datastore/datastore.py +++ b/scarf/datastore/datastore.py @@ -1,15 +1,15 @@ -from typing import Iterable, Optional, Union, List, Literal, Tuple +from typing import Iterable, List, Literal, Optional, Tuple, Union import numpy as np import pandas as pd from dask import array as daskarr from loguru import logger -from .mapping_datastore import MappingDatastore -from ..assay import Assay, RNAassay, ATACassay +from ..assay import Assay, ATACassay, RNAassay from ..feat_utils import hto_demux -from ..utils import tqdmbar, controlled_compute, ZARRLOC -from ..writers import create_zarr_obj_array, create_zarr_dataset +from ..utils import ZARRLOC, controlled_compute, tqdmbar +from ..writers import create_zarr_dataset, create_zarr_obj_array +from .mapping_datastore import MappingDatastore __all__ = ["DataStore"] @@ -75,10 +75,7 @@ def __init__( synchronizer=synchronizer, ) - def get_assay( - self, - assay_name: str - ) -> Assay: + def get_assay(self, assay_name: str) -> Assay: """Returns the assay object for the given assay name. Args: @@ -91,8 +88,7 @@ def get_assay( raise ValueError(f"ERROR: Assay {assay_name} not found in the Zarr file") else: return getattr(self, assay_name) - - + def filter_cells( self, attrs: Iterable[str], @@ -282,7 +278,7 @@ def mark_hvgs( if cell_key is None: cell_key = "I" assay = self._get_assay(from_assay) - if type(assay) != RNAassay: + if type(assay) != RNAassay: # noqa: E721 raise TypeError( f"ERROR: This method of feature selection can only be applied to RNAassay type of assay. " f"The provided assay is {type(assay)} type" @@ -338,7 +334,7 @@ def mark_prevalent_peaks( if cell_key is None: cell_key = "I" assay = self._get_assay(from_assay) - if type(assay) != ATACassay: + if type(assay) != ATACassay: # noqa: E721 raise TypeError( f"ERROR: This method of feature selection can only be applied to ATACassay type of assay. " f"The provided assay is {type(assay)} type" @@ -625,8 +621,8 @@ def get_markers( cell_key = "I" if group_key is None: raise ValueError( - f"ERROR: Please provide a value for group_key. " - f"This should be same as used for `run_marker_search`" + "ERROR: Please provide a value for group_key. " + "This should be same as used for `run_marker_search`" ) assay = self._get_assay(from_assay) try: @@ -706,8 +702,8 @@ def export_markers_to_csv( # Not testing the values of from_assay and cell_key because they will be tested in `get_markers` if group_key is None: raise ValueError( - f"ERROR: Please provide a value for group_key. " - f"This should be same as used for `run_marker_search`" + "ERROR: Please provide a value for group_key. 
" + "This should be same as used for `run_marker_search`" ) if csv_filename is None: raise ValueError( @@ -1290,7 +1286,7 @@ def plot_cells_dists( pass if cols is not None: - if type(cols) != list: + if type(cols) != list: # noqa: E721 raise ValueError("ERROR: 'cols' argument must be of type list") plot_cols = [] for i in cols: @@ -1496,7 +1492,7 @@ def plot_layout( # TODO: add support for providing a list of subselections, from_assay and cell_keys # TODO: add support for different kinds of point markers - from ..plots import shade_scatter, plot_scatter + from ..plots import plot_scatter, shade_scatter if from_assay is None: from_assay = self._defaultAssay @@ -1728,9 +1724,10 @@ def plot_cluster_tree( None """ - from ..plots import plot_cluster_hierarchy + from networkx import DiGraph, to_pandas_edgelist + from ..dendrogram import CoalesceTree, make_digraph - from networkx import to_pandas_edgelist, DiGraph + from ..plots import plot_cluster_hierarchy from_assay, cell_key, feat_key = self._get_latest_keys( from_assay, cell_key, feat_key @@ -2011,8 +2008,8 @@ def plot_pseudotime_heatmap( else: if hashes != assay.z[location].attrs["hashes"]: raise ValueError( - f"ERROR: The values under one or more of these columns: `cell_key`, `feat_key` or/and " - f"`pseudotime_key have been updated after running `run_pseudotime_aggregation`" + "ERROR: The values under one or more of these columns: `cell_key`, `feat_key` or/and " + "`pseudotime_key have been updated after running `run_pseudotime_aggregation`" ) da = daskarr.from_zarr(assay.z[location + "/data"], inline_array=True) @@ -2049,3 +2046,202 @@ def plot_pseudotime_heatmap( save_dpi=save_dpi, show_fig=show_fig, ) + + def metric_lisi( + self, + label_colnames: Iterable[str], + use_latest_knn: bool = True, + from_assay: Optional[str] = None, + knn_loc: Optional[str] = None, + save_result: bool = False, + return_lisi: bool = True, + ) -> Optional[List[Tuple[str, np.ndarray]]]: + """Calculate Local Inverse Simpson Index (LISI) scores for cell populations. + + LISI measures how well mixed different cell populations are in the local neighborhood + of each cell. Higher scores indicate better mixing of different populations. + + Args: + label_colnames: Column names from cell metadata containing population labels + use_latest_knn: Whether to use the most recent KNN graph (default: True) + from_assay: Name of assay to use if not using latest KNN + knn_loc: Location of KNN graph if not using latest (default: None) + save_result: Whether to save LISI scores to cell metadata (default: True) + return_lisi: Whether to return LISI scores (default: False) + + Returns: + If return_lisi is True, returns list of tuples containing: + - Label column name + - numpy array of LISI scores for that label + If return_lisi is False, returns None + + Raises: + ValueError: If using custom KNN graph but required parameters are missing + KeyError: If label columns not found in cell metadata + + Notes: + LISI scores are computed for each label column separately. + Scores near 1 indicate cells grouped with similar labels. + Higher scores indicate more mixing between different labels. 
+ """ + + if use_latest_knn and knn_loc is None: + knn_loc = self._get_latest_knn_loc(from_assay) + cell_key = self.zw[self._load_default_assay()].attrs["latest_cell_key"] + logger.info(f"Using the latest knn graph at location: {knn_loc}") + + else: + if knn_loc is None: + raise ValueError("Please provide values for the KNN graph location.") + if knn_loc not in self.zw: + raise ValueError(f"Could not find the knn graph at location: {knn_loc}") + + logger.info(f"Using the knn graph at location: {knn_loc}") + + knn = self.zw[knn_loc] + + distances = knn["distances"] + indices = knn["indices"] + + try: + metadata = self.cells.to_pandas_dataframe( + columns=label_colnames + [cell_key] + ) + metadata = metadata[metadata[cell_key]] + except KeyError: + raise KeyError( + f"Could not find the column(s) {label_colnames} in the cell metadata table." + ) + + from ..metrics import compute_lisi + + lisi_scores = compute_lisi(distances, indices, metadata, label_colnames) + # lisi_scores Shape -> (n_cells, n_labels) + if save_result: + for col, vals in zip(label_colnames, lisi_scores.T): + col_name = f"lisi__{col}__{knn_loc.split('/')[-1]}" + self.cells.insert(column_name=col_name, values=vals, overwrite=True) + + if return_lisi: + return list(zip(label_colnames, lisi_scores.T)) + else: + return None + + def metric_silhouette( + self, + use_latest_knn: bool = True, + res_label: str = "leiden_cluster", + from_assay: Optional[str] = None, + knn_loc: Optional[str] = None, + ) -> Optional[np.ndarray]: + """Calculate modified silhouette scores for evaluating cluster separation. + + This implements a graph-based silhouette score that measures how similar cells + are to their own cluster compared to the nearest neighboring cluster. + + Args: + use_latest_knn: Whether to use most recent KNN graph (default: True) + res_label: Column name containing cluster labels (default: "leiden_cluster") + from_assay: Name of assay to use if not using latest KNN (default: None) + knn_loc: Location of KNN graph if not using latest (default: None) + + Returns: + numpy array of silhouette scores for each cluster, or None if computation fails + + Raises: + ValueError: If using custom KNN graph but required parameters are missing + + Notes: + Scores range from -1 to 1: + - Near 1: Cluster is well-separated from neighboring clusters + - Near 0: Cluster overlaps with neighboring clusters + - Near -1: Cluster may be incorrectly assigned + + Implementation uses sampling for efficiency with large datasets. + NaN values indicate clusters that couldn't be scored due to size constraints. 
+ """ + + def compute_graph_feats(knn_loc: str): + k = knn_loc.rsplit("/", 1)[-1].split("__")[-1] + dims = knn_loc.rsplit("/", 2)[0].split("__")[-2] + feat_key = knn_loc.split("/")[1].split("__")[-1] + return k, dims, feat_key + + if from_assay is None: + from_assay = self._load_default_assay() + + if use_latest_knn and knn_loc is None: + knn_loc = self._get_latest_knn_loc(from_assay) + k, dims, feat_key = compute_graph_feats(knn_loc) + logger.info( + f"Using the latest knn graph at location: {knn_loc} for assay: {from_assay}" + ) + + else: + if knn_loc is None: + raise ValueError("Please provide values for the KNN graph location.") + if knn_loc not in self.zw: + raise ValueError(f"Could not find the knn graph at location: {knn_loc}") + k, dims, feat_key, from_assay = compute_graph_feats(knn_loc) + logger.info(f"Using the knn graph at location: {knn_loc}") + + from ..metrics import knn_to_csr_matrix, silhouette_scoring + + isHarmonized = self.zw[knn_loc.rsplit("/", 1)[0]].attrs["isHarmonized"] + + batches = None + if isHarmonized: + batches = self.zw[knn_loc.rsplit("/", 2)[0] + "/harmonizedData"].attrs[ + "batches" + ] + + ann_obj = self.make_graph( + feat_key=feat_key, + dims=dims, + k=k, + return_ann_object=True, + harmonize=isHarmonized, + batch_columns=batches, + ) + + graph = knn_to_csr_matrix(self.z[knn_loc].indices, self.z[knn_loc].distances) + hvg_data = self.z[knn_loc.rsplit("/", 3)[0] + "/data"] + scores = silhouette_scoring( + self, ann_obj, graph, hvg_data, from_assay, res_label + ) + return scores + + def metric_integration( + self, batch_labels: List[str], metric: Literal["ari", "nmi"] = "ari" + ) -> Optional[float]: + """Calculate integration score between different batch labels. + + Measures how well aligned different batches are after integration by comparing + their cluster assignments using standard clustering metrics. + + Args: + batch_labels: List of column names containing batch labels to compare + metric: Metric to use for comparison (default: "ari") + - "ari": Adjusted Rand Index + - "nmi": Normalized Mutual Information + + Returns: + Integration score between 0 and 1, or None if metric is not recognized + Higher scores indicate better alignment between batches + + Notes: + ARI and NMI measure the agreement between different batch labelings: + - Score near 1: Batches are well integrated + - Score near 0: Batches show poor integration + + ARI is adjusted for chance and generally more stringent than NMI. 
+ """ + from ..metrics import integration_score + + batch_labels_vals = [] + for batch in batch_labels: + vals = np.array(self.cells.fetch_all(batch)) + batch_labels_vals.append(vals) + + scores = integration_score(batch_labels_vals, metric) + return scores diff --git a/scarf/datastore/graph_datastore.py b/scarf/datastore/graph_datastore.py index 5235f94..7ba0b1d 100644 --- a/scarf/datastore/graph_datastore.py +++ b/scarf/datastore/graph_datastore.py @@ -1,16 +1,16 @@ import os -from typing import Tuple, Optional, Union, List, Callable +from typing import Callable, List, Optional, Tuple, Union import numpy as np import pandas as pd from dask.array import from_zarr # type: ignore from loguru import logger -from scipy.sparse import csr_matrix, coo_matrix +from scipy.sparse import coo_matrix, csr_matrix -from .base_datastore import BaseDataStore from ..assay import Assay from ..utils import clean_array, show_dask_progress, system_call, tqdmbar from ..writers import create_zarr_dataset +from .base_datastore import BaseDataStore class GraphDataStore(BaseDataStore): @@ -396,6 +396,34 @@ def _get_latest_graph_loc( knn_loc = self.zw[ann_loc].attrs["latest_knn"] return self.zw[knn_loc].attrs["latest_graph"] + def _get_latest_knn_loc(self, from_assay: str = None) -> str: + """Convenience function to identify location of the latest KNN graph in + the Zarr hierarchy. + + Args: + from_assay: Name of the assay. + + Returns: + Path of KNN graph in the Zarr hierarchy + """ + if from_assay is None: + logger.info("Using default assay for KNN graph.") + from_assay = self._load_default_assay() + + if from_assay not in self.assay_names: + raise ValueError(f"ERROR: Assay {from_assay} does not exist") + + latest_cell_key = self.zw[from_assay].attrs["latest_cell_key"] + latest_feat_key = self.zw[from_assay].attrs["latest_feat_key"] + normed_loc = f"{from_assay}/normed__{latest_cell_key}__{latest_feat_key}" + reduction_loc = self.zw[normed_loc].attrs["latest_reduction"] + if "reduction" not in self.zw[reduction_loc]: + raise ValueError(f"ERROR: PCA Reduction not found in {reduction_loc}") + latest_ann = self.zw[reduction_loc].attrs["latest_ann"] + ann_loc = self.zw[latest_ann] + latest_knn = ann_loc.attrs["latest_knn"] + return latest_knn + def _get_ini_embed( self, from_assay: str, cell_key: str, feat_key: str, n_comps: int ) -> np.ndarray: @@ -414,6 +442,7 @@ def _get_ini_embed( Matrix with n_comps dimensions representing initial embedding of cells. """ from sklearn.decomposition import PCA + from ..utils import rescale_array normed_loc = f"{from_assay}/normed__{cell_key}__{feat_key}" diff --git a/scarf/metrics.py b/scarf/metrics.py new file mode 100644 index 0000000..fe91a20 --- /dev/null +++ b/scarf/metrics.py @@ -0,0 +1,465 @@ +""" +Methods and classes for evluation +""" + +from typing import Iterable, Optional, Sequence, Tuple, Union + +import numpy as np +import pandas as pd +from scipy.sparse import csr_matrix +from zarr.core import Array as zarrArrayType + +from .ann import AnnStream +from .datastore.datastore import DataStore +from .utils import ( + logger, + tqdmbar, +) + + +# LISI - The Local Inverse Simpson Index +def compute_lisi( + distances: zarrArrayType, + indices: zarrArrayType, + metadata: pd.DataFrame, + label_colnames: Iterable[str], + perplexity: float = 30, +) -> np.ndarray: + """Compute the Local Inverse Simpson Index (LISI) for each column in metadata. + + LISI measures how well mixed different groups of cells are in the neighborhood of each cell. 
+    Higher values indicate better mixing of different groups.
+
+    Args:
+        distances: Pre-computed distances between cells, stored in zarr array format
+        indices: Pre-computed nearest neighbor indices, stored in zarr array format
+        metadata: DataFrame containing categorical labels for each cell
+        label_colnames: Column names in metadata to compute LISI for
+        perplexity: Parameter controlling the effective number of neighbors (default: 30)
+
+    Returns:
+        np.ndarray: Matrix of LISI scores with shape (n_cells, n_labels)
+            Each column corresponds to LISI scores for one label column in metadata
+
+    Example:
+        For metadata with a 'batch' column having 3 categories:
+        - LISI ≈ 3: Cell has neighbors from all 3 batches (well mixed)
+        - LISI ≈ 1: Cell has neighbors from only 1 batch (poorly mixed)
+
+    References:
+        Korsunsky et al. 2019 doi: 10.1038/s41592-019-0619-0
+    """
+
+    n_cells = metadata.shape[0]
+    n_labels = len(label_colnames)
+    # Don't count yourself
+    indices = indices[:, 1:]
+    distances = distances[:, 1:]
+    lisi_df = np.zeros((n_cells, n_labels))
+    for i, label in enumerate(label_colnames):
+        logger.info(f"Computing LISI for {label}")
+        labels = pd.Categorical(metadata[label])
+        simpson = compute_simpson(distances.T, indices.T, labels, perplexity)
+        lisi_df[:, i] = 1 / simpson
+    return lisi_df
+
+
+def compute_simpson(
+    distances: np.ndarray,
+    indices: np.ndarray,
+    labels: pd.Categorical,
+    perplexity: float,
+    tol: float = 1e-5,
+) -> np.ndarray:
+    """Compute Simpson's diversity index with Gaussian kernel weighting.
+
+    This function implements the core calculation for LISI, computing a diversity score
+    based on the distribution of categories in each cell's neighborhood.
+
+    Args:
+        distances: Distance matrix between points, shape (n_neighbors, n_points)
+        indices: Index matrix for nearest neighbors, shape (n_neighbors, n_points)
+        labels: Categorical labels for each point
+        perplexity: Target perplexity for Gaussian kernel
+        tol: Convergence tolerance for perplexity calibration (default: 1e-5)
+
+    Returns:
+        np.ndarray: Array of Simpson's diversity indices, one per point
+    """
+    n = distances.shape[1]
+    P = np.zeros(distances.shape[0])
+    simpson = np.zeros(n)
+    logU = np.log(perplexity)
+    # Loop through each cell.
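+    # For each cell, beta (the precision of the Gaussian kernel over neighbor
+    # distances) is calibrated by a binary search so that the entropy H of the
+    # neighbor weights matches log(perplexity). The calibrated weights P are then
+    # accumulated into Simpson's index: the sum over label categories of
+    # (sum of P over neighbors carrying that label) squared.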
+ for i in tqdmbar(range(n), desc="Computing Simpson's Diversity Index"): + beta = 1 + betamin = -np.inf + betamax = np.inf + # Compute Hdiff + P = np.exp(-distances[:, i] * beta) + P_sum = np.sum(P) + if P_sum == 0: + H = 0 + P = np.zeros(distances.shape[0]) + else: + H = np.log(P_sum) + beta * np.sum(distances[:, i] * P) / P_sum + P = P / P_sum + Hdiff = H - logU + n_tries = 50 + for t in range(n_tries): + # Stop when we reach the tolerance + if abs(Hdiff) < tol: + break + # Update beta + if Hdiff > 0: + betamin = beta + if not np.isfinite(betamax): + beta *= 2 + else: + beta = (beta + betamax) / 2 + else: + betamax = beta + if not np.isfinite(betamin): + beta /= 2 + else: + beta = (beta + betamin) / 2 + # Compute Hdiff + P = np.exp(-distances[:, i] * beta) + P_sum = np.sum(P) + if P_sum == 0: + H = 0 + P = np.zeros(distances.shape[0]) + else: + H = np.log(P_sum) + beta * np.sum(distances[:, i] * P) / P_sum + P = P / P_sum + Hdiff = H - logU + # distancesefault value + if H == 0: + simpson[i] = -1 + # Simpson's index + for label_category in labels.categories: + ix = indices[:, i] + q = labels[ix] == label_category + if np.any(q): + P_sum = np.sum(P[q]) + simpson[i] += P_sum * P_sum + return simpson + + +# SILHOUETTE SCORE - The Silhouette Score +def knn_to_csr_matrix( + neighbor_indices: np.ndarray, neighbor_distances: np.ndarray +) -> csr_matrix: + """Convert k-nearest neighbors data to a Compressed Sparse Row (CSR) matrix. + + Creates a sparse adjacency matrix representation of the k-nearest neighbors graph + where edge weights are the distances between points. + + Args: + neighbor_indices: Indices matrix from k-nearest neighbors, shape (n_samples, k) + neighbor_distances: Distances matrix from k-nearest neighbors, shape (n_samples, k) + + Returns: + scipy.sparse.csr_matrix: Sparse adjacency matrix of shape (n_samples, n_samples) + where non-zero entries represent distances between neighboring points + """ + num_samples, num_neighbors = neighbor_indices.shape + row_indices = np.repeat(np.arange(num_samples), num_neighbors) + return csr_matrix( + (neighbor_distances[:].flatten(), (row_indices, neighbor_indices[:].flatten())), + shape=(num_samples, num_samples), + ) + + +def calculate_weighted_cluster_similarity( + knn_graph: csr_matrix, cluster_labels: np.ndarray +) -> np.ndarray: + """Calculate similarity between clusters based on shared weighted edges. + + Uses a weighted Jaccard index to compute similarities between clusters in a KNN graph. 
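+    For clusters i and j, with w_ij the summed weight of edges between them and
+    W_i, W_j the total edge weight incident to each cluster, the similarity
+    computed below is w_ij / (W_i + W_j - w_ij).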
+
+    Args:
+        knn_graph: CSR matrix representing the KNN graph, shape (n_samples, n_samples)
+        cluster_labels: Cluster assignments for each node, must be contiguous integers
+            starting from 0
+
+    Returns:
+        np.ndarray: Symmetric matrix of shape (n_clusters, n_clusters) containing
+            pairwise similarities between clusters
+
+    Raises:
+        AssertionError: If cluster labels are not contiguous integers starting from 0
+    """
+    unique_cluster_ids = np.unique(cluster_labels)
+    expected_cluster_ids = np.arange(0, len(unique_cluster_ids))
+    assert np.array_equal(unique_cluster_ids, expected_cluster_ids), (
+        "Cluster labels must be contiguous integers starting at 0"
+    )
+
+    num_clusters = len(unique_cluster_ids)
+    inter_cluster_weights = np.zeros((num_clusters, num_clusters))
+
+    for cluster_id in unique_cluster_ids:
+        nodes_in_cluster = np.where(cluster_labels == cluster_id)[0]
+        neighbor_cluster_labels = cluster_labels[knn_graph[nodes_in_cluster].indices]
+        neighbor_edge_weights = knn_graph[nodes_in_cluster].data
+
+        for neighbor_cluster, edge_weight in zip(
+            neighbor_cluster_labels, neighbor_edge_weights
+        ):
+            inter_cluster_weights[cluster_id, neighbor_cluster] += edge_weight
+
+    # assert inter_cluster_weights.sum() == knn_graph.data.sum()
+
+    # Ensure symmetry
+    inter_cluster_weights = (inter_cluster_weights + inter_cluster_weights.T) / 2
+
+    # Calculate total weights for each cluster
+    total_cluster_weights = np.array(
+        [inter_cluster_weights[i].sum() for i in unique_cluster_ids]
+    )
+
+    # Calculate similarity using weighted Jaccard index
+    similarity_matrix = np.zeros((num_clusters, num_clusters))
+
+    for i in range(num_clusters):
+        for j in range(i, num_clusters):
+            weight_union = (
+                total_cluster_weights[i]
+                + total_cluster_weights[j]
+                - inter_cluster_weights[i, j]
+            )
+            if weight_union > 0:
+                similarity = inter_cluster_weights[i, j] / weight_union
+                similarity_matrix[i, j] = similarity_matrix[j, i] = similarity
+
+    # Set diagonal to 1 (self-similarity)
+    # np.fill_diagonal(similarity_matrix, 1.0)
+
+    return similarity_matrix
+
+
+def calculate_top_k_neighbor_distances(
+    matrix_a: np.ndarray, matrix_b: np.ndarray, k: int
+) -> np.ndarray:
+    """Calculate distances to k nearest neighbors between two sets of points.
+
+    For each point in matrix_a, finds the k nearest neighbors in matrix_b
+    and returns their distances.
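+    Squared Euclidean distances are computed via the expansion
+    ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b, clipped at zero to absorb floating
+    point error, and np.partition is used to select the k smallest per row.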
+ + Args: + matrix_a: First set of points, shape (m, d) + matrix_b: Second set of points, shape (n, d) + k: Number of nearest neighbors to find + + Returns: + np.ndarray: Matrix of shape (m, k) containing the distances to the + k nearest neighbors in matrix_b for each point in matrix_a + + Raises: + AssertionError: If matrices don't have the same number of features + """ + # Check if the matrices have the same number of features (d) + assert matrix_a.shape[1] == matrix_b.shape[1], ( + "Matrices must have the same number of features" + ) + + # Ensure k is not larger than the number of points in matrix_b + k = min(k, matrix_b.shape[0]) + + # Calculate squared Euclidean distances + a_squared = np.sum(np.square(matrix_a), axis=1, keepdims=True) + b_squared = np.sum(np.square(matrix_b), axis=1) + + # Use broadcasting to compute pairwise distances + distances = a_squared + b_squared - 2 * np.dot(matrix_a, matrix_b.T) + + # Use np.maximum to avoid small negative values due to floating point errors + distances = np.maximum(distances, 0) + + # Find the k smallest distances for each point in matrix_a + top_k_distances = np.partition(distances, k, axis=1)[:, :k] + + # Calculate the square root to get Euclidean distances + return np.sqrt(top_k_distances) + + +def process_cluster( + cluster_cells: np.ndarray, + hvg_data: Union[np.ndarray, zarrArrayType], + ann_obj: AnnStream, + k: int, +) -> Tuple[np.ndarray, np.ndarray]: + """Process a cluster of cells to prepare data for silhouette scoring. + + Randomly splits cluster cells into two groups and applies dimensionality reduction. + + Args: + cluster_cells: Indices of cells belonging to the cluster + hvg_data: Expression data for highly variable genes + ann_obj: Object containing dimensionality reduction method + k: Number of cells to sample from cluster + + Returns: + Tuple[np.ndarray, np.ndarray]: Two arrays containing reduced data for + different subsets of cells from the cluster + """ + np.random.shuffle(cluster_cells) + data_cells = np.array( + [ann_obj.reducer(hvg_data[i]) for i in sorted(cluster_cells[:k])] + ) + data_cells_2 = np.array( + [ann_obj.reducer(hvg_data[i]) for i in sorted(cluster_cells[k : 2 * k])] + ) + return data_cells, data_cells_2 + + +def silhouette_scoring( + ds: DataStore, + ann_obj: AnnStream, + graph: csr_matrix, + hvg_data: Union[np.ndarray, zarrArrayType], + assay_type: str, + res_label: str, +) -> Optional[np.ndarray]: + """Compute modified silhouette scores for clusters in single-cell data. + + This implementation differs from the standard silhouette score by using + a graph-based approach and comparing clusters to their nearest neighbors. + + Args: + ds: DataStore object containing cell metadata + ann_obj: Object containing dimensionality reduction method + graph: CSR matrix representing the KNN graph + hvg_data: Expression data for highly variable genes + assay_type: Type of assay (e.g., 'RNA', 'ATAC') + res_label: Label for clustering resolution + + Returns: + Optional[np.ndarray]: Array of silhouette scores for each cluster, + or None if cluster labels are not found + + Notes: + Scores are calculated using a sampling approach for efficiency. + NaN values indicate clusters that couldn't be scored due to size constraints. 
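+        For each cluster, the mean distance `a` between two sampled halves of the
+        cluster and the mean distance `b` from the cluster to its most similar
+        cluster (as ranked by the weighted cluster similarity matrix) are estimated,
+        and the score is (b - a) / max(a, b), mirroring the classical silhouette
+        formulation.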
+ """ + try: + clusters = ds.cells.fetch(f"{assay_type}_{res_label}") - 1 # RNA_{res_label} + except KeyError: + logger.error(f"Cluster labels not found for {assay_type}_{res_label}") + return None + + cluster_similarity = calculate_weighted_cluster_similarity(graph, clusters) + + k = 11 + score = [] + + for n, i in enumerate(cluster_similarity): + this_cluster_cells = np.where(clusters == n)[0] + if len(this_cluster_cells) < 2 * k: + k = int(len(this_cluster_cells) / 2) + logger.warning( + f"Warning: Cluster {n} has fewer than 22 cells. Will adjust k to {k} instead" + ) + + for n, i in tqdmbar( + enumerate(cluster_similarity), + total=len(cluster_similarity), + desc="Calculating Silhouette Scores", + ): + this_cluster_cells = np.where(clusters == n)[0] + np.random.shuffle(this_cluster_cells) + data_this_cells, data_this_cells_2 = process_cluster( + # n, + this_cluster_cells, + hvg_data, + ann_obj, + k, + ) + + if data_this_cells.size == 0 or data_this_cells_2.size == 0: + logger.warning(f"Warning: Reduced data for cluster {n} is empty. Skipping.") + score.append(np.nan) + continue + + k_neighbors = min(k - 1, data_this_cells_2.shape[0] - 1) + + if k_neighbors < 1: + logger.warning( + f"Warning: Not enough points in cluster {n} for comparison. Skipping." + ) + score.append(np.nan) + continue + + self_dist = calculate_top_k_neighbor_distances( + data_this_cells, data_this_cells_2, k - 1 + ).mean() + + nearest_cluster = np.argsort(i)[-1] + nearest_cluster_cells = np.where(clusters == nearest_cluster)[0] + np.random.shuffle(nearest_cluster_cells) + + if len(nearest_cluster_cells) < k: + logger.warning( + f"Warning: Nearest cluster {nearest_cluster} has fewer than {k} cells. Skipping." + ) + score.append(np.nan) + continue + + data_nearest_cells, _ = process_cluster( + # nearest_cluster, + nearest_cluster_cells, + hvg_data, + ann_obj, + k, + ) + + if data_nearest_cells.size == 0: + logger.warning( + f"Warning: Reduced data for nearest cluster {nearest_cluster} is empty. Skipping." + ) + score.append(np.nan) + continue + + other_dist = calculate_top_k_neighbor_distances( + data_this_cells, data_nearest_cells, k - 1 + ).mean() + + score.append((other_dist - self_dist) / max(self_dist, other_dist)) + + return np.array(score) + + +def integration_score( + batch_labels: Sequence[np.ndarray], metric: str = "ari" +) -> Optional[float]: + """Calculate integration score between two sets of batch labels. + + Args: + batch_labels: Sequence containing two arrays of batch labels to compare + metric: Metric to use for comparison, one of: + - 'ari': Adjusted Rand Index + - 'nmi': Normalized Mutual Information + + Returns: + Optional[float]: Integration score between 0 and 1, or None if metric + is not recognized + + Notes: + Higher scores indicate better agreement between batch labels, + suggesting more effective batch integration. + """ + from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score + + if metric == "ari": + return adjusted_rand_score(batch_labels[0], batch_labels[1]) + elif metric == "nmi": + return normalized_mutual_info_score(batch_labels[0], batch_labels[1]) + else: + logger.error( + f"Metric {metric} not recognized. Please choose from 'ari', 'nmi', 'calinski_harabasz', or 'davies_bouldin'." 
+        )
+        return None
diff --git a/scarf/tests/test_metrics.py b/scarf/tests/test_metrics.py
new file mode 100644
index 0000000..6cbb277
--- /dev/null
+++ b/scarf/tests/test_metrics.py
@@ -0,0 +1,50 @@
+import numpy as np
+
+
+def test_metric_lisi(datastore, make_graph):
+    labels = np.random.randint(0, 2, datastore.cells.N)
+    datastore.cells.insert(
+        column_name="samples_id",
+        values=labels,
+        overwrite=True,
+    )
+    lisi = datastore.metric_lisi(
+        label_colnames=["samples_id"], save_result=False, return_lisi=True
+    )
+    assert len(lisi[0][1]) == len(datastore.cells.active_index("I"))
+
+
+def test_metric_silhouette(datastore, make_graph, leiden_clustering):
+    _ = datastore.metric_silhouette()
+
+
+def test_metric_integration_ari(datastore, make_graph, leiden_clustering):
+    labels1 = np.random.randint(0, 2, datastore.cells.N)
+    labels2 = np.random.randint(0, 2, datastore.cells.N)
+    datastore.cells.insert(
+        column_name="labels1",
+        values=labels1,
+        overwrite=True,
+    )
+    datastore.cells.insert(
+        column_name="labels2",
+        values=labels2,
+        overwrite=True,
+    )
+    _ = datastore.metric_integration(batch_labels=["labels1", "labels2"], metric="ari")
+
+
+def test_metric_integration_nmi(datastore, make_graph, leiden_clustering):
+    labels1 = np.random.randint(0, 2, datastore.cells.N)
+    labels2 = np.random.randint(0, 2, datastore.cells.N)
+    datastore.cells.insert(
+        column_name="labels1",
+        values=labels1,
+        overwrite=True,
+    )
+    datastore.cells.insert(
+        column_name="labels2",
+        values=labels2,
+        overwrite=True,
+    )
+    _ = datastore.metric_integration(batch_labels=["labels1", "labels2"], metric="nmi")