From 0446f8b2de0fbaa1fbe8f3eff899b8affd39e57b Mon Sep 17 00:00:00 2001 From: gautam8387 Date: Thu, 24 Oct 2024 00:15:29 +0530 Subject: [PATCH 1/6] Added LiSi evaulation metric and helper methods in datastore --- scarf/datastore/datastore.py | 110 +++++++-- scarf/datastore/graph_datastore.py | 35 ++- scarf/metrics.py | 369 +++++++++++++++++++++++++++++ 3 files changed, 497 insertions(+), 17 deletions(-) create mode 100644 scarf/metrics.py diff --git a/scarf/datastore/datastore.py b/scarf/datastore/datastore.py index b803ae5..e8ae235 100644 --- a/scarf/datastore/datastore.py +++ b/scarf/datastore/datastore.py @@ -1,15 +1,15 @@ -from typing import Iterable, Optional, Union, List, Literal, Tuple +from typing import Iterable, List, Literal, Optional, Tuple, Union import numpy as np import pandas as pd from dask import array as daskarr from loguru import logger -from .mapping_datastore import MappingDatastore -from ..assay import Assay, RNAassay, ATACassay +from ..assay import Assay, ATACassay, RNAassay from ..feat_utils import hto_demux -from ..utils import tqdmbar, controlled_compute, ZARRLOC -from ..writers import create_zarr_obj_array, create_zarr_dataset +from ..utils import ZARRLOC, controlled_compute, tqdmbar +from ..writers import create_zarr_dataset, create_zarr_obj_array +from .mapping_datastore import MappingDatastore __all__ = ["DataStore"] @@ -75,10 +75,7 @@ def __init__( synchronizer=synchronizer, ) - def get_assay( - self, - assay_name: str - ) -> Assay: + def get_assay(self, assay_name: str) -> Assay: """Returns the assay object for the given assay name. Args: @@ -91,8 +88,7 @@ def get_assay( raise ValueError(f"ERROR: Assay {assay_name} not found in the Zarr file") else: return getattr(self, assay_name) - - + def filter_cells( self, attrs: Iterable[str], @@ -1496,7 +1492,7 @@ def plot_layout( # TODO: add support for providing a list of subselections, from_assay and cell_keys # TODO: add support for different kinds of point markers - from ..plots import shade_scatter, plot_scatter + from ..plots import plot_scatter, shade_scatter if from_assay is None: from_assay = self._defaultAssay @@ -1728,9 +1724,10 @@ def plot_cluster_tree( None """ - from ..plots import plot_cluster_hierarchy + from networkx import DiGraph, to_pandas_edgelist + from ..dendrogram import CoalesceTree, make_digraph - from networkx import to_pandas_edgelist, DiGraph + from ..plots import plot_cluster_hierarchy from_assay, cell_key, feat_key = self._get_latest_keys( from_assay, cell_key, feat_key @@ -2049,3 +2046,88 @@ def plot_pseudotime_heatmap( save_dpi=save_dpi, show_fig=show_fig, ) + + def metric_lisi( + self, + label_colnames: Iterable[str], + from_assay: Optional[str] = None, + cell_key: Optional[str] = None, + feat_key: Optional[str] = None, + dims: Optional[str] = None, + reduction_method: Optional[str] = None, + pca_cell_key: Optional[str] = None, + ann_metric: Optional[str] = None, + ann_efc: Optional[int] = None, + ann_ef: Optional[int] = None, + ann_m: Optional[int] = None, + rand_state: Optional[int] = 4466, + k: Optional[int] = None, + return_lisi: bool = False, + ): + """ + label_colnames: List of column names from cell metadata table that contains the ground truth labels. 
+ """ + if None in [from_assay, cell_key, feat_key, dims, k]: + knn_loc = self.get_latest_knn_loc(from_assay) + logger.info(f"Using the latest knn graph at location: {knn_loc}") + else: + if None in [ + reduction_method, + pca_cell_key, + ann_metric, + ann_efc, + ann_ef, + ann_m, + rand_state, + ]: + raise ValueError( + "Please provide values for all the parameters: reduction_method, pca_cell_key, ann_metric, ann_efc, ann_ef, ann_m, rand_state" + ) + normed_loc = f"{from_assay}/normed__{cell_key}__{feat_key}" + reduction_loc = ( + f"{normed_loc}/reduction__{reduction_method}__{dims}__{pca_cell_key}" + ) + ann_loc = f"{reduction_loc}/ann__{ann_metric}__{ann_efc}__{ann_ef}__{ann_m}__{rand_state}" + knn_loc = f"{ann_loc}/knn__{k}" + + if knn_loc not in self.zw: + raise ValueError(f"Could not find the knn graph at location: {knn_loc}") + logger.info(f"Using the knn graph at location: {knn_loc}") + knn = self.zw[knn_loc] + + distances = knn["distances"] + indices = knn["indices"] + try: + metadata = self.cells.to_pandas_dataframe(columns=label_colnames) + except KeyError: + raise KeyError( + f"Could not find the column(s) {label_colnames} in the cell metadata table." + ) + + from ..metrics import compute_lisi + + lisi_scores = compute_lisi(distances, indices, metadata, label_colnames) + # lisi_scores Shape -> (n_cells, n_labels) + if not return_lisi: + for col, vals in zip(label_colnames, lisi_scores.T): + col_name = f"lisi__{col}__{knn_loc.split('/')[-1]}" + self.cells.insert(column_name=col_name, values=vals, overwrite=True) + return list(zip(label_colnames, lisi_scores.T)) + + def metric_silhouette( + self, + from_assay: Optional[str] = None, + cell_key: Optional[str] = None, + feat_key: Optional[str] = None, + dims: Optional[str] = None, + reduction_method: Optional[str] = None, + pca_cell_key: Optional[str] = None, + ann_metric: Optional[str] = None, + ann_efc: Optional[int] = None, + ann_ef: Optional[int] = None, + ann_m: Optional[int] = None, + rand_state: Optional[int] = 4466, + k: Optional[int] = None, + return_silhouette: bool = False, + ): + pass diff --git a/scarf/datastore/graph_datastore.py b/scarf/datastore/graph_datastore.py index 5235f94..487d3d4 100644 --- a/scarf/datastore/graph_datastore.py +++ b/scarf/datastore/graph_datastore.py @@ -1,16 +1,16 @@ import os -from typing import Tuple, Optional, Union, List, Callable +from typing import Callable, List, Optional, Tuple, Union import numpy as np import pandas as pd from dask.array import from_zarr # type: ignore from loguru import logger -from scipy.sparse import csr_matrix, coo_matrix +from scipy.sparse import coo_matrix, csr_matrix -from .base_datastore import BaseDataStore from ..assay import Assay from ..utils import clean_array, show_dask_progress, system_call, tqdmbar from ..writers import create_zarr_dataset +from .base_datastore import BaseDataStore class GraphDataStore(BaseDataStore): @@ -396,6 +396,34 @@ def _get_latest_graph_loc( knn_loc = self.zw[ann_loc].attrs["latest_knn"] return self.zw[knn_loc].attrs["latest_graph"] + def get_latest_knn_loc(self, from_assay: str = None) -> str: + """Convenience function to identify location of the latest KNN graph in + the Zarr hierarchy. + + Args: + from_assay: Name of the assay. 
+ + Returns: + Path of KNN graph in the Zarr hierarchy + """ + if from_assay is None: + logger.info("Using default assay for KNN graph.") + from_assay = self._load_default_assay() + + if from_assay not in self.assay_names: + raise ValueError(f"ERROR: Assay {from_assay} does not exist") + + latest_cell_key = self.zw[from_assay].attrs["latest_cell_key"] + latest_feat_key = self.zw[from_assay].attrs["latest_feat_key"] + normed_loc = f"{from_assay}/normed__{latest_cell_key}__{latest_feat_key}" + reduction_loc = self.zw[normed_loc].attrs["latest_reduction"] + if "reduction" not in self.zw[reduction_loc]: + raise ValueError(f"ERROR: PCA Reduction not found in {reduction_loc}") + latest_ann = self.zw[reduction_loc].attrs["latest_ann"] + ann_loc = self.zw[latest_ann] + latest_knn = ann_loc.attrs["latest_knn"] + return latest_knn + def _get_ini_embed( self, from_assay: str, cell_key: str, feat_key: str, n_comps: int ) -> np.ndarray: @@ -414,6 +442,7 @@ def _get_ini_embed( Matrix with n_comps dimensions representing initial embedding of cells. """ from sklearn.decomposition import PCA + from ..utils import rescale_array normed_loc = f"{from_assay}/normed__{cell_key}__{feat_key}" diff --git a/scarf/metrics.py b/scarf/metrics.py new file mode 100644 index 0000000..bd7396c --- /dev/null +++ b/scarf/metrics.py @@ -0,0 +1,369 @@ +""" +Methods and classes for evluation +""" + +import math +import os +import re +from collections import Counter +from typing import Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +import polars as pl +import zarr +from dask.array import from_array +from dask.array.core import Array as daskArrayType +from scipy.io import mmwrite +from scipy.sparse import coo_matrix, csr_matrix +from sklearn.neighbors import NearestNeighbors +from zarr.core import Array as zarrArrayType + +from .assay import ( + ADTassay, + ATACassay, + RNAassay, +) +from .datastore.datastore import DataStore +from .metadata import MetaData +from .utils import ( + ZARRLOC, + controlled_compute, + load_zarr, + logger, + permute_into_chunks, + tqdmbar, +) +from .writers import create_zarr_count_assay, create_zarr_obj_array + + +# LISI - The Local Inverse Simpson Index +def compute_lisi( + # X: np.array, + distances: zarrArrayType, + indices: zarrArrayType, + metadata: pd.DataFrame, + label_colnames: Iterable[str], + perplexity: float = 30, +): + """Compute the Local Inverse Simpson Index (LISI) for each column in metadata. + + LISI is a statistic computed for each item (row) in the data matrix X. + + The following example may help to interpret the LISI values. + + Suppose one of the columns in metadata is a categorical variable with 3 categories. + + - If LISI is approximately equal to 3 for an item in the data matrix, + that means that the item is surrounded by neighbors from all 3 + categories. + + - If LISI is approximately equal to 1, then the item is surrounded by + neighbors from 1 category. + + The LISI statistic is useful to evaluate whether multiple datasets are + well-integrated by algorithms such as Harmony [1]. + + [1]: Korsunsky et al. 
2019 doi: 10.1038/s41592-019-0619-0 + """ + # # We need at least 3 * n_neigbhors to compute the perplexity + # knn = NearestNeighbors(n_neighbors = math.ceil(perplexity * 3), algorithm = 'kd_tree').fit(X) + # distances, indices = knn.kneighbors(X) + + n_cells = metadata.shape[0] + n_labels = len(label_colnames) + # Don't count yourself + indices = indices[:, 1:] + distances = distances[:, 1:] + # Save the result + lisi_df = np.zeros((n_cells, n_labels)) + for i, label in enumerate(label_colnames): + logger.info(f"Computing LISI for {label}") + labels = pd.Categorical(metadata[label]) + n_categories = len(labels.categories) + simpson = compute_simpson( + distances.T, indices.T, labels, n_categories, perplexity + ) + lisi_df[:, i] = 1 / simpson + # lisi_df = lisi_df.flatten() + return lisi_df + + +def compute_simpson( + distances: np.ndarray, + indices: np.ndarray, + labels: pd.Categorical, + n_categories: int, + perplexity: float, + tol: float = 1e-5, +): + n = distances.shape[1] + P = np.zeros(distances.shape[0]) + simpson = np.zeros(n) + logU = np.log(perplexity) + # Loop through each cell. + for i in range(n): + beta = 1 + betamin = -np.inf + betamax = np.inf + # Compute Hdiff + P = np.exp(-distances[:, i] * beta) + P_sum = np.sum(P) + if P_sum == 0: + H = 0 + P = np.zeros(distances.shape[0]) + else: + H = np.log(P_sum) + beta * np.sum(distances[:, i] * P) / P_sum + P = P / P_sum + Hdiff = H - logU + n_tries = 50 + for t in range(n_tries): + # Stop when we reach the tolerance + if abs(Hdiff) < tol: + break + # Update beta + if Hdiff > 0: + betamin = beta + if not np.isfinite(betamax): + beta *= 2 + else: + beta = (beta + betamax) / 2 + else: + betamax = beta + if not np.isfinite(betamin): + beta /= 2 + else: + beta = (beta + betamin) / 2 + # Compute Hdiff + P = np.exp(-distances[:, i] * beta) + P_sum = np.sum(P) + if P_sum == 0: + H = 0 + P = np.zeros(distances.shape[0]) + else: + H = np.log(P_sum) + beta * np.sum(distances[:, i] * P) / P_sum + P = P / P_sum + Hdiff = H - logU + # distancesefault value + if H == 0: + simpson[i] = -1 + # Simpson's index + for label_category in labels.categories: + ix = indices[:, i] + q = labels[ix] == label_category + if np.any(q): + P_sum = np.sum(P[q]) + simpson[i] += P_sum * P_sum + return simpson + + +# SILHOUETTE SCORE - The Silhouette Score +def knn_to_csr_matrix( + neighbor_indices: np.ndarray, neighbor_distances: np.ndarray +) -> csr_matrix: + """ + Convert k-nearest neighbors data to a Compressed Sparse Row (CSR) matrix. + + Parameters: + neighbor_indices : 2D array + Indices of k-nearest neighbors for each data point. + neighbor_distances : 2D array + Distances to k-nearest neighbors for each data point. + + Returns: + scipy.sparse.csr_matrix + A sparse matrix representation of the KNN graph. + """ + num_samples, num_neighbors = neighbor_indices.shape + row_indices = np.repeat(np.arange(num_samples), num_neighbors) + return csr_matrix( + (neighbor_distances[:].flatten(), (row_indices, neighbor_indices[:].flatten())), + shape=(num_samples, num_samples), + ) + + +def calculate_weighted_cluster_similarity( + knn_graph: csr_matrix, cluster_labels: np.ndarray +) -> np.ndarray: + """ + Calculate similarity between clusters based on shared weighted edges. + + Parameters: + - knn_graph: CSR matrix representing the KNN graph + - cluster_labels: 1D array with cluster/community index for each node. Contiguous and start from 1. 
+ + Returns: + - similarity_matrix: 2D numpy array with similarity scores between clusters + """ + unique_cluster_ids = np.unique(cluster_labels) + expected_cluster_ids = np.arange(0, len(unique_cluster_ids)) + assert np.array_equal( + unique_cluster_ids, expected_cluster_ids + ), "Cluster labels must be contiguous integers starting at 1" + + num_clusters = len(unique_cluster_ids) + inter_cluster_weights = np.zeros((num_clusters, num_clusters)) + + for cluster_id in unique_cluster_ids: + nodes_in_cluster = np.where(cluster_labels == cluster_id)[0] + neighbor_cluster_labels = cluster_labels[knn_graph[nodes_in_cluster].indices] + neighbor_edge_weights = knn_graph[nodes_in_cluster].data + + for neighbor_cluster, edge_weight in zip( + neighbor_cluster_labels, neighbor_edge_weights + ): + inter_cluster_weights[cluster_id, neighbor_cluster] += edge_weight + + assert inter_cluster_weights.sum() == knn_graph.data.sum() + + # Ensure symmetry + inter_cluster_weights = (inter_cluster_weights + inter_cluster_weights.T) / 2 + + # Calculate total weights for each cluster + total_cluster_weights = np.array( + [inter_cluster_weights[i - 1].sum() for i in unique_cluster_ids] + ) + + # Calculate similarity using weighted Jaccard index + similarity_matrix = np.zeros((num_clusters, num_clusters)) + + for i in range(num_clusters): + for j in range(i, num_clusters): + weight_union = ( + total_cluster_weights[i] + + total_cluster_weights[j] + - inter_cluster_weights[i, j] + ) + if weight_union > 0: + similarity = inter_cluster_weights[i, j] / weight_union + similarity_matrix[i, j] = similarity_matrix[j, i] = similarity + + # Set diagonal to 1 (self-similarity) + # np.fill_diagonal(similarity_matrix, 1.0) + + return similarity_matrix + + +def calculate_top_k_neighbor_distances( + matrix_a: np.ndarray, matrix_b: np.ndarray, k: int +) -> np.ndarray: + """ + Calculate the distances of the top k nearest neighbors from matrix_b for each point in matrix_a. 
+ + Parameters: + matrix_a : numpy.ndarray + First matrix of shape (m, d) + matrix_b : numpy.ndarray + Second matrix of shape (n, d) + k : int + Number of nearest neighbors to consider + + Returns: + numpy.ndarray + Array of shape (m, k) containing the distances of the k nearest neighbors + from matrix_b for each point in matrix_a + """ + # Check if the matrices have the same number of features (d) + assert ( + matrix_a.shape[1] == matrix_b.shape[1] + ), "Matrices must have the same number of features" + + # Ensure k is not larger than the number of points in matrix_b + k = min(k, matrix_b.shape[0]) + + # Calculate squared Euclidean distances + a_squared = np.sum(np.square(matrix_a), axis=1, keepdims=True) + b_squared = np.sum(np.square(matrix_b), axis=1) + + # Use broadcasting to compute pairwise distances + distances = a_squared + b_squared - 2 * np.dot(matrix_a, matrix_b.T) + + # Use np.maximum to avoid small negative values due to floating point errors + distances = np.maximum(distances, 0) + + # Find the k smallest distances for each point in matrix_a + top_k_distances = np.partition(distances, k, axis=1)[:, :k] + + # Calculate the square root to get Euclidean distances + return np.sqrt(top_k_distances) + + +def process_cluster(cluster_cells, hvg_data, ann_obj, k): + np.random.shuffle(cluster_cells) + data_cells = np.array( + [ann_obj.reducer(hvg_data[i]) for i in sorted(cluster_cells[:k])] + ) + data_cells_2 = np.array( + [ann_obj.reducer(hvg_data[i]) for i in sorted(cluster_cells[k : 2 * k])] + ) + return data_cells, data_cells_2 + + +def silhouette_scoring(ds, ann_obj, graph, hvg_data, res_label): + clusters = ds.cells.fetch(f"RNA_{res_label}") - 1 + cluster_similarity = calculate_weighted_cluster_similarity(graph, clusters) + + k = 11 + score = [] + + for n, i in enumerate(cluster_similarity): + this_cluster_cells = np.where(clusters == n)[0] + if len(this_cluster_cells) < 2 * k: + k = int(len(this_cluster_cells) / 2) + logger.warning( + f"Warning: Cluster {n} has fewer than 22 cells. Will adjust k to {k} instead" + ) + + for n, i in tqdmbar(enumerate(cluster_similarity)): + this_cluster_cells = np.where(clusters == n)[0] + + data_this_cells, data_this_cells_2 = process_cluster( + n, this_cluster_cells, hvg_data, ann_obj, k + ) + + if data_this_cells.size == 0 or data_this_cells_2.size == 0: + logger.warning(f"Warning: Reduced data for cluster {n} is empty. Skipping.") + score.append(np.nan) + continue + + k_neighbors = min(k - 1, data_this_cells_2.shape[0] - 1) + + if k_neighbors < 1: + logger.warning( + f"Warning: Not enough points in cluster {n} for comparison. Skipping." + ) + score.append(np.nan) + continue + + self_dist = calculate_top_k_neighbor_distances( + data_this_cells, data_this_cells_2, k - 1 + ).mean() + + nearest_cluster = np.argsort(i)[-1] + nearest_cluster_cells = np.where(clusters == nearest_cluster)[0] + + if len(nearest_cluster_cells) < k: + logger.warning( + f"Warning: Nearest cluster {nearest_cluster} has fewer than {k} cells. Skipping." + ) + score.append(np.nan) + continue + + data_nearest_cells, _ = process_cluster( + nearest_cluster, nearest_cluster_cells, hvg_data, ann_obj, k + ) + + if data_nearest_cells.size == 0: + logger.warning( + f"Warning: Reduced data for nearest cluster {nearest_cluster} is empty. Skipping." 
+ ) + score.append(np.nan) + continue + + other_dist = calculate_top_k_neighbor_distances( + data_this_cells, data_nearest_cells, k - 1 + ).mean() + + score.append((other_dist - self_dist) / max(self_dist, other_dist)) + + return np.array(score) From 7dfdf64998a656ca43d1cf0ea8d52e73801edba4 Mon Sep 17 00:00:00 2001 From: gautam8387 Date: Tue, 29 Oct 2024 04:05:35 +0530 Subject: [PATCH 2/6] datastore.py: added metrics methods for lisi, silhouette, and integration (adjusted rand score, normalized mutual information score). Uses the latest knn location when calculating default. Provide all parameter otherwise. metrics.py: function for computing all scores. graph_datastore.py: rename functions --- scarf/datastore/datastore.py | 111 ++++++++++++++++++++++++++--- scarf/datastore/graph_datastore.py | 2 +- scarf/metrics.py | 48 +++++++++++-- 3 files changed, 146 insertions(+), 15 deletions(-) diff --git a/scarf/datastore/datastore.py b/scarf/datastore/datastore.py index e8ae235..193e955 100644 --- a/scarf/datastore/datastore.py +++ b/scarf/datastore/datastore.py @@ -2050,6 +2050,7 @@ def plot_pseudotime_heatmap( def metric_lisi( self, label_colnames: Iterable[str], + use_latest_knn: bool = True, from_assay: Optional[str] = None, cell_key: Optional[str] = None, feat_key: Optional[str] = None, @@ -2062,16 +2063,22 @@ def metric_lisi( ann_m: Optional[int] = None, rand_state: Optional[int] = 4466, k: Optional[int] = None, + save_result: bool = True, return_lisi: bool = False, - ): + ) -> Union[None, List[Tuple[str, np.ndarray]]]: """ label_colnames: List of column names from cell metadata table that contains the ground truth labels. """ - if None in [from_assay, cell_key, feat_key, dims, k]: - knn_loc = self.get_latest_knn_loc(from_assay) + if use_latest_knn: + knn_loc = self._get_latest_knn_loc(from_assay) logger.info(f"Using the latest knn graph at location: {knn_loc}") else: if None in [ + from_assay, + cell_key, + feat_key, + dims, + k, reduction_method, pca_cell_key, ann_metric, @@ -2081,7 +2088,7 @@ def metric_lisi( rand_state, ]: raise ValueError( - "Please provide values for all the parameters: reduction_method, pca_cell_key, ann_metric, ann_efc, ann_ef, ann_m, rand_state" + "Please provide values for all the parameters: from_assay, cell_key, feat_key, dims, k, reduction_method, pca_cell_key, ann_metric, ann_efc, ann_ef, ann_m, rand_state" ) normed_loc = f"{from_assay}/normed__{cell_key}__{feat_key}" reduction_loc = ( @@ -2093,6 +2100,7 @@ def metric_lisi( if knn_loc not in self.zw: raise ValueError(f"Could not find the knn graph at location: {knn_loc}") logger.info(f"Using the knn graph at location: {knn_loc}") + knn = self.zw[knn_loc] distances = knn["distances"] @@ -2108,14 +2116,20 @@ def metric_lisi( lisi_scores = compute_lisi(distances, indices, metadata, label_colnames) # lisi_scores Shape -> (n_cells, n_labels) - if not return_lisi: + if save_result: for col, vals in zip(label_colnames, lisi_scores.T): col_name = f"lisi__{col}__{knn_loc.split('/')[-1]}" self.cells.insert(column_name=col_name, values=vals, overwrite=True) - return list(zip(label_colnames, lisi_scores.T)) + + if return_lisi: + return list(zip(label_colnames, lisi_scores.T)) + else: + return None def metric_silhouette( self, + use_latest_knn: bool = True, + res_label: str = "leiden_cluster", from_assay: Optional[str] = None, cell_key: Optional[str] = None, feat_key: Optional[str] = None, @@ -2128,6 +2142,87 @@ def metric_silhouette( ann_m: Optional[int] = None, rand_state: Optional[int] = 4466, k: Optional[int] = None, - 
return_silhouette: bool = False, ): - pass + """ + label_colnames: List of column names from cell metadata table that contains the ground truth labels. + """ + if use_latest_knn: + knn_loc = self._get_latest_knn_loc(from_assay) + from_assay = self._load_default_assay() + logger.info(f"Using the latest knn graph at location: {knn_loc}") + k = knn_loc.rsplit("/", 1)[-1].split("__")[-1] + dims = knn_loc.rsplit("/", 2)[0].split("__")[-2] + feat_key = knn_loc.split("/")[1].split("__")[-1] + + else: + if None in [ + from_assay, + cell_key, + feat_key, + dims, + k, + reduction_method, + pca_cell_key, + ann_metric, + ann_efc, + ann_ef, + ann_m, + rand_state, + ]: + raise ValueError( + "Please provide values for all the parameters: from_assay, cell_key, feat_key, dims, k, reduction_method, pca_cell_key, ann_metric, ann_efc, ann_ef, ann_m, rand_state" + ) + normed_loc = f"{from_assay}/normed__{cell_key}__{feat_key}" + reduction_loc = ( + f"{normed_loc}/reduction__{reduction_method}__{dims}__{pca_cell_key}" + ) + ann_loc = f"{reduction_loc}/ann__{ann_metric}__{ann_efc}__{ann_ef}__{ann_m}__{rand_state}" + knn_loc = f"{ann_loc}/knn__{k}" + + if knn_loc not in self.zw: + raise ValueError(f"Could not find the knn graph at location: {knn_loc}") + logger.info(f"Using the knn graph at location: {knn_loc}") + + from ..metrics import silhouette_scoring, knn_to_csr_matrix + + isHarmonized = self.zw[knn_loc.rsplit("/", 1)[0]].attrs["isHarmonized"] + batches = None + if isHarmonized: + batches = self.zw[knn_loc.rsplit("/", 2)[0] + "/harmonizedData"].attrs[ + "batches" + ] + + ann_obj = self.make_graph( + feat_key=feat_key, + dims=dims, + k=k, + return_ann_object=True, + harmonize=isHarmonized, + batch_columns=batches, + ) + graph = knn_to_csr_matrix(self.z[knn_loc].indices, self.z[knn_loc].distances) + # if isHarmonized and harmonize: + # logger.info("Using harmonized data for silhouette scoring") + # hvg_data = self.z[knn_loc.rsplit("/", 2)[0] + "/harmonizedData"] + # else: + # hvg_data = self.z[knn_loc.rsplit("/", 3)[0] + "/data"] + hvg_data = self.z[knn_loc.rsplit("/", 3)[0] + "/data"] + scores = silhouette_scoring( + self, ann_obj, graph, hvg_data, from_assay, res_label + ) + return scores + + + def metric_integration(self, batch_labels: List[str], metric: str = "ari"): + """ + label_colnames: List of column names from cell metadata table that contains the ground truth labels. + """ + from ..metrics import integration_score + + batch_labels_vals = [] + for batch in batch_labels: + vals = np.array(self.cells.fetch_all(batch)) + batch_labels_vals.append(vals) + + scores = integration_score(batch_labels_vals, metric) + return scores diff --git a/scarf/datastore/graph_datastore.py b/scarf/datastore/graph_datastore.py index 487d3d4..7ba0b1d 100644 --- a/scarf/datastore/graph_datastore.py +++ b/scarf/datastore/graph_datastore.py @@ -396,7 +396,7 @@ def _get_latest_graph_loc( knn_loc = self.zw[ann_loc].attrs["latest_knn"] return self.zw[knn_loc].attrs["latest_graph"] - def get_latest_knn_loc(self, from_assay: str = None) -> str: + def _get_latest_knn_loc(self, from_assay: str = None) -> str: """Convenience function to identify location of the latest KNN graph in the Zarr hierarchy. 
diff --git a/scarf/metrics.py b/scarf/metrics.py index bd7396c..9856037 100644 --- a/scarf/metrics.py +++ b/scarf/metrics.py @@ -299,8 +299,13 @@ def process_cluster(cluster_cells, hvg_data, ann_obj, k): return data_cells, data_cells_2 -def silhouette_scoring(ds, ann_obj, graph, hvg_data, res_label): - clusters = ds.cells.fetch(f"RNA_{res_label}") - 1 +def silhouette_scoring(ds, ann_obj, graph, hvg_data, assay_type, res_label): + try: + clusters = ds.cells.fetch(f"{assay_type}_{res_label}") - 1 # RNA_{res_label} + except KeyError: + logger.error(f"Cluster labels not found for {assay_type}_{res_label}") + return None + cluster_similarity = calculate_weighted_cluster_similarity(graph, clusters) k = 11 @@ -314,11 +319,15 @@ def silhouette_scoring(ds, ann_obj, graph, hvg_data, res_label): f"Warning: Cluster {n} has fewer than 22 cells. Will adjust k to {k} instead" ) - for n, i in tqdmbar(enumerate(cluster_similarity)): + for n, i in tqdmbar(enumerate(cluster_similarity), total=len(cluster_similarity)): this_cluster_cells = np.where(clusters == n)[0] - + np.random.shuffle(this_cluster_cells) data_this_cells, data_this_cells_2 = process_cluster( - n, this_cluster_cells, hvg_data, ann_obj, k + # n, + this_cluster_cells, + hvg_data, + ann_obj, + k, ) if data_this_cells.size == 0 or data_this_cells_2.size == 0: @@ -341,6 +350,7 @@ def silhouette_scoring(ds, ann_obj, graph, hvg_data, res_label): nearest_cluster = np.argsort(i)[-1] nearest_cluster_cells = np.where(clusters == nearest_cluster)[0] + np.random.shuffle(nearest_cluster_cells) if len(nearest_cluster_cells) < k: logger.warning( @@ -350,7 +360,11 @@ def silhouette_scoring(ds, ann_obj, graph, hvg_data, res_label): continue data_nearest_cells, _ = process_cluster( - nearest_cluster, nearest_cluster_cells, hvg_data, ann_obj, k + # nearest_cluster, + nearest_cluster_cells, + hvg_data, + ann_obj, + k, ) if data_nearest_cells.size == 0: @@ -367,3 +381,25 @@ def silhouette_scoring(ds, ann_obj, graph, hvg_data, res_label): score.append((other_dist - self_dist) / max(self_dist, other_dist)) return np.array(score) + + +def integration_score( + batch_labels: np.ndarray, + metric: str = 'ari' +): + from sklearn.metrics import adjusted_rand_score + from sklearn.metrics import normalized_mutual_info_score + # from sklearn.metrics import calinski_harabasz_score + # from sklearn.metrics import davies_bouldin_score + + if metric == 'ari': + return adjusted_rand_score(batch_labels[0], batch_labels[1]) + elif metric == 'nmi': + return normalized_mutual_info_score(batch_labels[0], batch_labels[1]) + # elif metric == 'calinski_harabasz': + # return calinski_harabasz_score(batch_labels[0], batch_labels[1]) + # elif metric == 'davies_bouldin': + # return davies_bouldin_score(batch_labels[0], batch_labels[1]) + else: + logger.error(f"Metric {metric} not recognized. 
Please choose from 'ari', 'nmi', 'calinski_harabasz', or 'davies_bouldin'.") + return None \ No newline at end of file From c90134cb5d863236dea7c806119c101fd58edfbd Mon Sep 17 00:00:00 2001 From: gautam8387 Date: Tue, 29 Oct 2024 05:42:28 +0530 Subject: [PATCH 3/6] All: - Added DocString and typing datastore.py: - lisi: filtered metadata as per 'I' - doc strings and typing metrics.py: - formatted & typing tests: - Added test for metrics --- scarf/datastore/datastore.py | 137 ++++++++++++++++---- scarf/metrics.py | 235 ++++++++++++++++++++++------------- scarf/tests/test_metrics.py | 37 ++++++ 3 files changed, 299 insertions(+), 110 deletions(-) create mode 100644 scarf/tests/test_metrics.py diff --git a/scarf/datastore/datastore.py b/scarf/datastore/datastore.py index 193e955..76dae9a 100644 --- a/scarf/datastore/datastore.py +++ b/scarf/datastore/datastore.py @@ -621,8 +621,8 @@ def get_markers( cell_key = "I" if group_key is None: raise ValueError( - f"ERROR: Please provide a value for group_key. " - f"This should be same as used for `run_marker_search`" + "ERROR: Please provide a value for group_key. " + "This should be same as used for `run_marker_search`" ) assay = self._get_assay(from_assay) try: @@ -702,8 +702,8 @@ def export_markers_to_csv( # Not testing the values of from_assay and cell_key because they will be tested in `get_markers` if group_key is None: raise ValueError( - f"ERROR: Please provide a value for group_key. " - f"This should be same as used for `run_marker_search`" + "ERROR: Please provide a value for group_key. " + "This should be same as used for `run_marker_search`" ) if csv_filename is None: raise ValueError( @@ -2008,8 +2008,8 @@ def plot_pseudotime_heatmap( else: if hashes != assay.z[location].attrs["hashes"]: raise ValueError( - f"ERROR: The values under one or more of these columns: `cell_key`, `feat_key` or/and " - f"`pseudotime_key have been updated after running `run_pseudotime_aggregation`" + "ERROR: The values under one or more of these columns: `cell_key`, `feat_key` or/and " + "`pseudotime_key have been updated after running `run_pseudotime_aggregation`" ) da = daskarr.from_zarr(assay.z[location + "/data"], inline_array=True) @@ -2065,12 +2065,49 @@ def metric_lisi( k: Optional[int] = None, save_result: bool = True, return_lisi: bool = False, - ) -> Union[None, List[Tuple[str, np.ndarray]]]: - """ - label_colnames: List of column names from cell metadata table that contains the ground truth labels. + ) -> Optional[List[Tuple[str, np.ndarray]]]: + """Calculate Local Inverse Simpson Index (LISI) scores for cell populations. + + LISI measures how well mixed different cell populations are in the local neighborhood + of each cell. Higher scores indicate better mixing of different populations. 
+ + Args: + label_colnames: Column names from cell metadata containing population labels + use_latest_knn: Whether to use the most recent KNN graph (default: True) + from_assay: Name of assay to use if not using latest KNN + cell_key: Cell filtering key for normalization + feat_key: Feature selection key for normalization + dims: Number of dimensions used for reduction + reduction_method: Name of dimensionality reduction method + pca_cell_key: Cell key used for PCA + ann_metric: Metric used for approximate nearest neighbors + ann_efc: Construction time/accuracy trade-off for ANN index + ann_ef: Query time/accuracy trade-off for ANN index + ann_m: Max number of connections in ANN graph + rand_state: Random seed for reproducibility (default: 4466) + k: Number of nearest neighbors + save_result: Whether to save LISI scores to cell metadata (default: True) + return_lisi: Whether to return LISI scores (default: False) + + Returns: + If return_lisi is True, returns list of tuples containing: + - Label column name + - numpy array of LISI scores for that label + If return_lisi is False, returns None + + Raises: + ValueError: If using custom KNN graph but required parameters are missing + KeyError: If label columns not found in cell metadata + + Notes: + LISI scores are computed for each label column separately. + Scores near 1 indicate cells grouped with similar labels. + Higher scores indicate more mixing between different labels. """ + if use_latest_knn: knn_loc = self._get_latest_knn_loc(from_assay) + cell_key = self.zw[self._load_default_assay()].attrs["latest_cell_key"] logger.info(f"Using the latest knn graph at location: {knn_loc}") else: if None in [ @@ -2106,7 +2143,10 @@ def metric_lisi( distances = knn["distances"] indices = knn["indices"] try: - metadata = self.cells.to_pandas_dataframe(columns=label_colnames) + metadata = self.cells.to_pandas_dataframe( + columns=label_colnames + [cell_key] + ) + metadata = metadata[metadata[cell_key]] except KeyError: raise KeyError( f"Could not find the column(s) {label_colnames} in the cell metadata table." @@ -2142,10 +2182,44 @@ def metric_silhouette( ann_m: Optional[int] = None, rand_state: Optional[int] = 4466, k: Optional[int] = None, - ): - """ - label_colnames: List of column names from cell metadata table that contains the ground truth labels. + ) -> Optional[np.ndarray]: + """Calculate modified silhouette scores for evaluating cluster separation. + + This implements a graph-based silhouette score that measures how similar cells + are to their own cluster compared to the nearest neighboring cluster. 
+ + Args: + use_latest_knn: Whether to use most recent KNN graph (default: True) + res_label: Column name containing cluster labels (default: "leiden_cluster") + from_assay: Name of assay to use if not using latest KNN + cell_key: Cell filtering key for normalization + feat_key: Feature selection key for normalization + dims: Number of dimensions used for reduction + reduction_method: Name of dimensionality reduction method + pca_cell_key: Cell key used for PCA + ann_metric: Metric used for approximate nearest neighbors + ann_efc: Construction time/accuracy trade-off for ANN index + ann_ef: Query time/accuracy trade-off for ANN index + ann_m: Max number of connections in ANN graph + rand_state: Random seed for reproducibility (default: 4466) + k: Number of nearest neighbors + + Returns: + numpy array of silhouette scores for each cluster, or None if computation fails + + Raises: + ValueError: If using custom KNN graph but required parameters are missing + + Notes: + Scores range from -1 to 1: + - Near 1: Cluster is well-separated from neighboring clusters + - Near 0: Cluster overlaps with neighboring clusters + - Near -1: Cluster may be incorrectly assigned + + Implementation uses sampling for efficiency with large datasets. + NaN values indicate clusters that couldn't be scored due to size constraints. """ + if use_latest_knn: knn_loc = self._get_latest_knn_loc(from_assay) from_assay = self._load_default_assay() @@ -2183,7 +2257,7 @@ def metric_silhouette( raise ValueError(f"Could not find the knn graph at location: {knn_loc}") logger.info(f"Using the knn graph at location: {knn_loc}") - from ..metrics import silhouette_scoring, knn_to_csr_matrix + from ..metrics import knn_to_csr_matrix, silhouette_scoring isHarmonized = self.zw[knn_loc.rsplit("/", 1)[0]].attrs["isHarmonized"] batches = None @@ -2201,21 +2275,38 @@ def metric_silhouette( batch_columns=batches, ) graph = knn_to_csr_matrix(self.z[knn_loc].indices, self.z[knn_loc].distances) - # if isHarmonized and harmonize: - # logger.info("Using harmonized data for silhouette scoring") - # hvg_data = self.z[knn_loc.rsplit("/", 2)[0] + "/harmonizedData"] - # else: - # hvg_data = self.z[knn_loc.rsplit("/", 3)[0] + "/data"] + hvg_data = self.z[knn_loc.rsplit("/", 3)[0] + "/data"] + scores = silhouette_scoring( self, ann_obj, graph, hvg_data, from_assay, res_label ) return scores + def metric_integration( + self, batch_labels: List[str], metric: Literal["ari", "nmi"] = "ari" + ) -> Optional[float]: + """Calculate integration score between different batch labels. - def metric_integration(self, batch_labels: List[str], metric: str = "ari"): - """ - label_colnames: List of column names from cell metadata table that contains the ground truth labels. + Measures how well aligned different batches are after integration by comparing + their cluster assignments using standard clustering metrics. + + Args: + batch_labels: List of column names containing batch labels to compare + metric: Metric to use for comparison (default: "ari") + - "ari": Adjusted Rand Index + - "nmi": Normalized Mutual Information + + Returns: + Integration score between 0 and 1, or None if metric is not recognized + Higher scores indicate better alignment between batches + + Notes: + ARI and NMI measure the agreement between different batch labelings: + - Score near 1: Batches are well integrated + - Score near 0: Batches show poor integration + + ARI is adjusted for chance and generally more stringent than NMI. 
""" from ..metrics import integration_score @@ -2223,6 +2314,6 @@ def metric_integration(self, batch_labels: List[str], metric: str = "ari"): for batch in batch_labels: vals = np.array(self.cells.fetch_all(batch)) batch_labels_vals.append(vals) - + scores = integration_score(batch_labels_vals, metric) return scores diff --git a/scarf/metrics.py b/scarf/metrics.py index 9856037..3809bbc 100644 --- a/scarf/metrics.py +++ b/scarf/metrics.py @@ -2,39 +2,19 @@ Methods and classes for evluation """ -import math -import os -import re -from collections import Counter -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Iterable, Optional, Sequence, Tuple, Union import numpy as np import pandas as pd -import polars as pl -import zarr -from dask.array import from_array -from dask.array.core import Array as daskArrayType -from scipy.io import mmwrite -from scipy.sparse import coo_matrix, csr_matrix -from sklearn.neighbors import NearestNeighbors +from scipy.sparse import csr_matrix from zarr.core import Array as zarrArrayType -from .assay import ( - ADTassay, - ATACassay, - RNAassay, -) +from .ann import AnnStream from .datastore.datastore import DataStore -from .metadata import MetaData from .utils import ( - ZARRLOC, - controlled_compute, - load_zarr, logger, - permute_into_chunks, tqdmbar, ) -from .writers import create_zarr_count_assay, create_zarr_obj_array # LISI - The Local Inverse Simpson Index @@ -45,30 +25,31 @@ def compute_lisi( metadata: pd.DataFrame, label_colnames: Iterable[str], perplexity: float = 30, -): +) -> np.ndarray: """Compute the Local Inverse Simpson Index (LISI) for each column in metadata. - LISI is a statistic computed for each item (row) in the data matrix X. - - The following example may help to interpret the LISI values. + LISI measures how well mixed different groups of cells are in the neighborhood of each cell. + Higher values indicate better mixing of different groups. - Suppose one of the columns in metadata is a categorical variable with 3 categories. + Args: + distances: Pre-computed distances between cells, stored in zarr array format + indices: Pre-computed nearest neighbor indices, stored in zarr array format + metadata: DataFrame containing categorical labels for each cell + label_colnames: Column names in metadata to compute LISI for + perplexity: Parameter controlling the effective number of neighbors (default: 30) - - If LISI is approximately equal to 3 for an item in the data matrix, - that means that the item is surrounded by neighbors from all 3 - categories. - - - If LISI is approximately equal to 1, then the item is surrounded by - neighbors from 1 category. + Returns: + np.ndarray: Matrix of LISI scores with shape (n_cells, n_labels) + Each column corresponds to LISI scores for one label column in metadata - The LISI statistic is useful to evaluate whether multiple datasets are - well-integrated by algorithms such as Harmony [1]. + Example: + For metadata with a 'batch' column having 3 categories: + - LISI ≈ 3: Cell has neighbors from all 3 batches (well mixed) + - LISI ≈ 1: Cell has neighbors from only 1 batch (poorly mixed) - [1]: Korsunsky et al. 2019 doi: 10.1038/s41592-019-0619-0 + References: + Korsunsky et al. 
2019 doi: 10.1038/s41592-019-0619-0 """ - # # We need at least 3 * n_neigbhors to compute the perplexity - # knn = NearestNeighbors(n_neighbors = math.ceil(perplexity * 3), algorithm = 'kd_tree').fit(X) - # distances, indices = knn.kneighbors(X) n_cells = metadata.shape[0] n_labels = len(label_colnames) @@ -96,7 +77,23 @@ def compute_simpson( n_categories: int, perplexity: float, tol: float = 1e-5, -): +) -> np.ndarray: + """Compute Simpson's diversity index with Gaussian kernel weighting. + + This function implements the core calculation for LISI, computing a diversity score + based on the distribution of categories in each cell's neighborhood. + + Args: + distances: Distance matrix between points, shape (n_neighbors, n_points) + indices: Index matrix for nearest neighbors, shape (n_neighbors, n_points) + labels: Categorical labels for each point + n_categories: Number of unique categories in labels + perplexity: Target perplexity for Gaussian kernel + tol: Convergence tolerance for perplexity calibration (default: 1e-5) + + Returns: + np.ndarray: Array of Simpson's diversity indices, one per point + """ n = distances.shape[1] P = np.zeros(distances.shape[0]) simpson = np.zeros(n) @@ -161,18 +158,18 @@ def compute_simpson( def knn_to_csr_matrix( neighbor_indices: np.ndarray, neighbor_distances: np.ndarray ) -> csr_matrix: - """ - Convert k-nearest neighbors data to a Compressed Sparse Row (CSR) matrix. + """Convert k-nearest neighbors data to a Compressed Sparse Row (CSR) matrix. + + Creates a sparse adjacency matrix representation of the k-nearest neighbors graph + where edge weights are the distances between points. - Parameters: - neighbor_indices : 2D array - Indices of k-nearest neighbors for each data point. - neighbor_distances : 2D array - Distances to k-nearest neighbors for each data point. + Args: + neighbor_indices: Indices matrix from k-nearest neighbors, shape (n_samples, k) + neighbor_distances: Distances matrix from k-nearest neighbors, shape (n_samples, k) Returns: - scipy.sparse.csr_matrix - A sparse matrix representation of the KNN graph. + scipy.sparse.csr_matrix: Sparse adjacency matrix of shape (n_samples, n_samples) + where non-zero entries represent distances between neighboring points """ num_samples, num_neighbors = neighbor_indices.shape row_indices = np.repeat(np.arange(num_samples), num_neighbors) @@ -185,15 +182,21 @@ def knn_to_csr_matrix( def calculate_weighted_cluster_similarity( knn_graph: csr_matrix, cluster_labels: np.ndarray ) -> np.ndarray: - """ - Calculate similarity between clusters based on shared weighted edges. + """Calculate similarity between clusters based on shared weighted edges. + + Uses a weighted Jaccard index to compute similarities between clusters in a KNN graph. - Parameters: - - knn_graph: CSR matrix representing the KNN graph - - cluster_labels: 1D array with cluster/community index for each node. Contiguous and start from 1. 
+ Args: + knn_graph: CSR matrix representing the KNN graph, shape (n_samples, n_samples) + cluster_labels: Cluster assignments for each node, must be contiguous integers + starting from 0 Returns: - - similarity_matrix: 2D numpy array with similarity scores between clusters + np.ndarray: Symmetric matrix of shape (n_clusters, n_clusters) containing + pairwise similarities between clusters + + Raises: + AssertionError: If cluster labels are not contiguous integers starting from 0 """ unique_cluster_ids = np.unique(cluster_labels) expected_cluster_ids = np.arange(0, len(unique_cluster_ids)) @@ -247,21 +250,22 @@ def calculate_weighted_cluster_similarity( def calculate_top_k_neighbor_distances( matrix_a: np.ndarray, matrix_b: np.ndarray, k: int ) -> np.ndarray: - """ - Calculate the distances of the top k nearest neighbors from matrix_b for each point in matrix_a. + """Calculate distances to k nearest neighbors between two sets of points. - Parameters: - matrix_a : numpy.ndarray - First matrix of shape (m, d) - matrix_b : numpy.ndarray - Second matrix of shape (n, d) - k : int - Number of nearest neighbors to consider + For each point in matrix_a, finds the k nearest neighbors in matrix_b + and returns their distances. + + Args: + matrix_a: First set of points, shape (m, d) + matrix_b: Second set of points, shape (n, d) + k: Number of nearest neighbors to find Returns: - numpy.ndarray - Array of shape (m, k) containing the distances of the k nearest neighbors - from matrix_b for each point in matrix_a + np.ndarray: Matrix of shape (m, k) containing the distances to the + k nearest neighbors in matrix_b for each point in matrix_a + + Raises: + AssertionError: If matrices don't have the same number of features """ # Check if the matrices have the same number of features (d) assert ( @@ -288,7 +292,26 @@ def calculate_top_k_neighbor_distances( return np.sqrt(top_k_distances) -def process_cluster(cluster_cells, hvg_data, ann_obj, k): +def process_cluster( + cluster_cells: np.ndarray, + hvg_data: Union[np.ndarray, zarrArrayType], + ann_obj: AnnStream, + k: int, +) -> Tuple[np.ndarray, np.ndarray]: + """Process a cluster of cells to prepare data for silhouette scoring. + + Randomly splits cluster cells into two groups and applies dimensionality reduction. + + Args: + cluster_cells: Indices of cells belonging to the cluster + hvg_data: Expression data for highly variable genes + ann_obj: Object containing dimensionality reduction method + k: Number of cells to sample from cluster + + Returns: + Tuple[np.ndarray, np.ndarray]: Two arrays containing reduced data for + different subsets of cells from the cluster + """ np.random.shuffle(cluster_cells) data_cells = np.array( [ann_obj.reducer(hvg_data[i]) for i in sorted(cluster_cells[:k])] @@ -299,9 +322,37 @@ def process_cluster(cluster_cells, hvg_data, ann_obj, k): return data_cells, data_cells_2 -def silhouette_scoring(ds, ann_obj, graph, hvg_data, assay_type, res_label): +def silhouette_scoring( + ds: DataStore, + ann_obj: AnnStream, + graph: csr_matrix, + hvg_data: Union[np.ndarray, zarrArrayType], + assay_type: str, + res_label: str, +) -> Optional[np.ndarray]: + """Compute modified silhouette scores for clusters in single-cell data. + + This implementation differs from the standard silhouette score by using + a graph-based approach and comparing clusters to their nearest neighbors. 
+ + Args: + ds: DataStore object containing cell metadata + ann_obj: Object containing dimensionality reduction method + graph: CSR matrix representing the KNN graph + hvg_data: Expression data for highly variable genes + assay_type: Type of assay (e.g., 'RNA', 'ATAC') + res_label: Label for clustering resolution + + Returns: + Optional[np.ndarray]: Array of silhouette scores for each cluster, + or None if cluster labels are not found + + Notes: + Scores are calculated using a sampling approach for efficiency. + NaN values indicate clusters that couldn't be scored due to size constraints. + """ try: - clusters = ds.cells.fetch(f"{assay_type}_{res_label}") - 1 # RNA_{res_label} + clusters = ds.cells.fetch(f"{assay_type}_{res_label}") - 1 # RNA_{res_label} except KeyError: logger.error(f"Cluster labels not found for {assay_type}_{res_label}") return None @@ -384,22 +435,32 @@ def silhouette_scoring(ds, ann_obj, graph, hvg_data, assay_type, res_label): def integration_score( - batch_labels: np.ndarray, - metric: str = 'ari' -): - from sklearn.metrics import adjusted_rand_score - from sklearn.metrics import normalized_mutual_info_score - # from sklearn.metrics import calinski_harabasz_score - # from sklearn.metrics import davies_bouldin_score - - if metric == 'ari': + batch_labels: Sequence[np.ndarray], metric: str = "ari" +) -> Optional[float]: + """Calculate integration score between two sets of batch labels. + + Args: + batch_labels: Sequence containing two arrays of batch labels to compare + metric: Metric to use for comparison, one of: + - 'ari': Adjusted Rand Index + - 'nmi': Normalized Mutual Information + + Returns: + Optional[float]: Integration score between 0 and 1, or None if metric + is not recognized + + Notes: + Higher scores indicate better agreement between batch labels, + suggesting more effective batch integration. + """ + from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score + + if metric == "ari": return adjusted_rand_score(batch_labels[0], batch_labels[1]) - elif metric == 'nmi': + elif metric == "nmi": return normalized_mutual_info_score(batch_labels[0], batch_labels[1]) - # elif metric == 'calinski_harabasz': - # return calinski_harabasz_score(batch_labels[0], batch_labels[1]) - # elif metric == 'davies_bouldin': - # return davies_bouldin_score(batch_labels[0], batch_labels[1]) else: - logger.error(f"Metric {metric} not recognized. Please choose from 'ari', 'nmi', 'calinski_harabasz', or 'davies_bouldin'.") - return None \ No newline at end of file + logger.error( + f"Metric {metric} not recognized. Please choose from 'ari', 'nmi', 'calinski_harabasz', or 'davies_bouldin'." 
+ ) + return None diff --git a/scarf/tests/test_metrics.py b/scarf/tests/test_metrics.py new file mode 100644 index 0000000..a2ddc46 --- /dev/null +++ b/scarf/tests/test_metrics.py @@ -0,0 +1,37 @@ +import numpy as np + + +def test_metric_lisi(datastore, make_graph): + # datastore.auto_filter_cells(show_qc_plots=False) + # datastore.mark_hvgs(top_n=100, show_plot=False) + # datastore.make_graph(feat_key="hvgs") + lables = np.random.randint(0, 2, datastore.cells.N) + datastore.cells.insert( + column_name="samples_id", + values=lables, + overwrite=True, + ) + lisi = datastore.metric_lisi( + label_colnames=["samples_id"], save_result=False, return_lisi=True + ) + assert len(lisi[0][1]) == len(datastore.cells.active_index("I")) + + +def test_metric_silhouette(datastore, make_graph, leiden_clustering): + _ = datastore.metric_silhouette() + + +def test_metric_integration(datastore, make_graph, leiden_clustering): + lables1 = np.random.randint(0, 2, datastore.cells.N) + lables2 = np.random.randint(0, 2, datastore.cells.N) + datastore.cells.insert( + column_name="lables1", + values=lables1, + overwrite=True, + ) + datastore.cells.insert( + column_name="lables2", + values=lables2, + overwrite=True, + ) + _ = datastore.metric_integration(batch_labels=["lables1", "lables2"], metric="ari") From 9fccff9b790864d005ed8175594ef4729fd57f02 Mon Sep 17 00:00:00 2001 From: gautam8387 Date: Tue, 29 Oct 2024 05:47:57 +0530 Subject: [PATCH 4/6] comment cleanup --- scarf/metrics.py | 3 --- scarf/tests/test_metrics.py | 20 ++++++++++++++++---- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/scarf/metrics.py b/scarf/metrics.py index 3809bbc..201bfa8 100644 --- a/scarf/metrics.py +++ b/scarf/metrics.py @@ -19,7 +19,6 @@ # LISI - The Local Inverse Simpson Index def compute_lisi( - # X: np.array, distances: zarrArrayType, indices: zarrArrayType, metadata: pd.DataFrame, @@ -56,7 +55,6 @@ def compute_lisi( # Don't count yourself indices = indices[:, 1:] distances = distances[:, 1:] - # Save the result lisi_df = np.zeros((n_cells, n_labels)) for i, label in enumerate(label_colnames): logger.info(f"Computing LISI for {label}") @@ -66,7 +64,6 @@ def compute_lisi( distances.T, indices.T, labels, n_categories, perplexity ) lisi_df[:, i] = 1 / simpson - # lisi_df = lisi_df.flatten() return lisi_df diff --git a/scarf/tests/test_metrics.py b/scarf/tests/test_metrics.py index a2ddc46..6cbb277 100644 --- a/scarf/tests/test_metrics.py +++ b/scarf/tests/test_metrics.py @@ -2,9 +2,6 @@ def test_metric_lisi(datastore, make_graph): - # datastore.auto_filter_cells(show_qc_plots=False) - # datastore.mark_hvgs(top_n=100, show_plot=False) - # datastore.make_graph(feat_key="hvgs") lables = np.random.randint(0, 2, datastore.cells.N) datastore.cells.insert( column_name="samples_id", @@ -21,7 +18,7 @@ def test_metric_silhouette(datastore, make_graph, leiden_clustering): _ = datastore.metric_silhouette() -def test_metric_integration(datastore, make_graph, leiden_clustering): +def test_metric_integration_ari(datastore, make_graph, leiden_clustering): lables1 = np.random.randint(0, 2, datastore.cells.N) lables2 = np.random.randint(0, 2, datastore.cells.N) datastore.cells.insert( @@ -35,3 +32,18 @@ def test_metric_integration(datastore, make_graph, leiden_clustering): overwrite=True, ) _ = datastore.metric_integration(batch_labels=["lables1", "lables2"], metric="ari") + +def test_metric_integration_nmi(datastore, make_graph, leiden_clustering): + lables1 = np.random.randint(0, 2, datastore.cells.N) + lables2 = np.random.randint(0, 
2, datastore.cells.N) + datastore.cells.insert( + column_name="lables1", + values=lables1, + overwrite=True, + ) + datastore.cells.insert( + column_name="lables2", + values=lables2, + overwrite=True, + ) + _ = datastore.metric_integration(batch_labels=["lables1", "lables2"], metric="nmi") \ No newline at end of file From da1f41490ac5bfa422fa0d96a208912b72ce903c Mon Sep 17 00:00:00 2001 From: gautam8387 Date: Thu, 7 Nov 2024 04:26:28 +0530 Subject: [PATCH 5/6] Added progress bar on Lisi --- scarf/datastore/datastore.py | 4 ++-- scarf/metrics.py | 10 ++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/scarf/datastore/datastore.py b/scarf/datastore/datastore.py index 76dae9a..7970cf0 100644 --- a/scarf/datastore/datastore.py +++ b/scarf/datastore/datastore.py @@ -2063,8 +2063,8 @@ def metric_lisi( ann_m: Optional[int] = None, rand_state: Optional[int] = 4466, k: Optional[int] = None, - save_result: bool = True, - return_lisi: bool = False, + save_result: bool = False, + return_lisi: bool = True, ) -> Optional[List[Tuple[str, np.ndarray]]]: """Calculate Local Inverse Simpson Index (LISI) scores for cell populations. diff --git a/scarf/metrics.py b/scarf/metrics.py index 201bfa8..655491e 100644 --- a/scarf/metrics.py +++ b/scarf/metrics.py @@ -71,7 +71,6 @@ def compute_simpson( distances: np.ndarray, indices: np.ndarray, labels: pd.Categorical, - n_categories: int, perplexity: float, tol: float = 1e-5, ) -> np.ndarray: @@ -84,7 +83,6 @@ def compute_simpson( distances: Distance matrix between points, shape (n_neighbors, n_points) indices: Index matrix for nearest neighbors, shape (n_neighbors, n_points) labels: Categorical labels for each point - n_categories: Number of unique categories in labels perplexity: Target perplexity for Gaussian kernel tol: Convergence tolerance for perplexity calibration (default: 1e-5) @@ -96,7 +94,7 @@ def compute_simpson( simpson = np.zeros(n) logU = np.log(perplexity) # Loop through each cell. - for i in range(n): + for i in tqdmbar(range(n), desc="Computing Simpson's Diversity Index"): beta = 1 betamin = -np.inf betamax = np.inf @@ -367,7 +365,11 @@ def silhouette_scoring( f"Warning: Cluster {n} has fewer than 22 cells. Will adjust k to {k} instead" ) - for n, i in tqdmbar(enumerate(cluster_similarity), total=len(cluster_similarity)): + for n, i in tqdmbar( + enumerate(cluster_similarity), + total=len(cluster_similarity), + desc="Calculating Silhouette Scores", + ): this_cluster_cells = np.where(clusters == n)[0] np.random.shuffle(this_cluster_cells) data_this_cells, data_this_cells_2 = process_cluster( From 60ae0b15d9eee8f11e737d4a40ab3a27f82190cf Mon Sep 17 00:00:00 2001 From: gautam8387 Date: Tue, 14 Jan 2025 04:28:09 +0530 Subject: [PATCH 6/6] datastore.py: updated metric_lisi, metric_silhouette, and metric_integration to use latest KNN location and option to provide KNN location as input; assay.py and metrics.py: ruff formatting --- scarf/assay.py | 18 +++-- scarf/datastore/datastore.py | 132 ++++++++--------------------------- scarf/metrics.py | 14 ++-- 3 files changed, 45 insertions(+), 119 deletions(-) diff --git a/scarf/assay.py b/scarf/assay.py index 83cf143..10c5aca 100644 --- a/scarf/assay.py +++ b/scarf/assay.py @@ -22,8 +22,6 @@ from .metadata import MetaData from .utils import controlled_compute, logger, show_dask_progress -zarrGroup = z_hierarchy.Group - __all__ = ["Assay", "RNAassay", "ATACassay", "ADTassay"] @@ -102,7 +100,7 @@ class Assay: for later KNN graph construction. 
Args: - z (zarrGroup): Zarr hierarchy where raw data is located + z (z_hierarchy.Group): Zarr hierarchy where raw data is located name (str): A label/name for assay. cell_data: Metadata class object for the cell attributes. nthreads: number for threads to use for dask parallel computations @@ -122,7 +120,7 @@ class Assay: def __init__( self, - z: zarrGroup, + z: z_hierarchy.Group, workspace: Union[str, None], name: str, # FIXME change to assay_name cell_data: MetaData, @@ -757,7 +755,7 @@ class RNAassay(Assay): normalization of scRNA-Seq data. Args: - z (zarrGroup): Zarr hierarchy where raw data is located + z (z_hierarchy.Group): Zarr hierarchy where raw data is located name (str): A label/name for assay. cell_data: Metadata class object for the cell attributes. **kwargs: kwargs to be passed to the Assay class @@ -769,7 +767,7 @@ class RNAassay(Assay): It is set to None until normed method is called. """ - def __init__(self, z: zarrGroup, name: str, cell_data: MetaData, **kwargs): + def __init__(self, z: z_hierarchy.Group, name: str, cell_data: MetaData, **kwargs): super().__init__(z=z, name=name, cell_data=cell_data, **kwargs) self.normMethod = norm_lib_size if "size_factor" in self.attrs: @@ -1076,12 +1074,12 @@ class ATACassay(Assay): """This subclass of Assay is designed for feature selection and normalization of scATAC-Seq data.""" - def __init__(self, z: zarrGroup, name: str, cell_data: MetaData, **kwargs): + def __init__(self, z: z_hierarchy.Group, name: str, cell_data: MetaData, **kwargs): """This Assay subclass is designed for feature selection and normalization of scATAC-Seq data. Args: - z (zarrGroup): Zarr hierarchy where raw data is located + z (z_hierarchy.Group): Zarr hierarchy where raw data is located name (str): A label/name for assay. cell_data: Metadata class object for the cell attributes. **kwargs: @@ -1208,7 +1206,7 @@ class ADTassay(Assay): (feature-barcodes library) data from CITE-Seq experiments. Args: - z (zarrGroup): Zarr hierarchy where raw data is located + z (z_hierarchy.Group): Zarr hierarchy where raw data is located name (str): A label/name for assay. cell_data: Metadata class object for the cell attributes. **kwargs: @@ -1217,7 +1215,7 @@ class ADTassay(Assay): normMethod: Pointer to the function to be used for normalization of the raw data """ - def __init__(self, z: zarrGroup, name: str, cell_data: MetaData, **kwargs): + def __init__(self, z: z_hierarchy.Group, name: str, cell_data: MetaData, **kwargs): """This subclass of Assay is designed for normalization of ADT/HTO (feature-barcodes library) data from CITE-Seq experiments.""" super().__init__(z=z, name=name, cell_data=cell_data, **kwargs) diff --git a/scarf/datastore/datastore.py b/scarf/datastore/datastore.py index 7970cf0..bbefeb8 100644 --- a/scarf/datastore/datastore.py +++ b/scarf/datastore/datastore.py @@ -278,7 +278,7 @@ def mark_hvgs( if cell_key is None: cell_key = "I" assay = self._get_assay(from_assay) - if type(assay) != RNAassay: + if type(assay) != RNAassay: # noqa: E721 raise TypeError( f"ERROR: This method of feature selection can only be applied to RNAassay type of assay. " f"The provided assay is {type(assay)} type" @@ -334,7 +334,7 @@ def mark_prevalent_peaks( if cell_key is None: cell_key = "I" assay = self._get_assay(from_assay) - if type(assay) != ATACassay: + if type(assay) != ATACassay: # noqa: E721 raise TypeError( f"ERROR: This method of feature selection can only be applied to ATACassay type of assay. 
" f"The provided assay is {type(assay)} type" @@ -1286,7 +1286,7 @@ def plot_cells_dists( pass if cols is not None: - if type(cols) != list: + if type(cols) != list: # noqa: E721 raise ValueError("ERROR: 'cols' argument must be of type list") plot_cols = [] for i in cols: @@ -2052,17 +2052,7 @@ def metric_lisi( label_colnames: Iterable[str], use_latest_knn: bool = True, from_assay: Optional[str] = None, - cell_key: Optional[str] = None, - feat_key: Optional[str] = None, - dims: Optional[str] = None, - reduction_method: Optional[str] = None, - pca_cell_key: Optional[str] = None, - ann_metric: Optional[str] = None, - ann_efc: Optional[int] = None, - ann_ef: Optional[int] = None, - ann_m: Optional[int] = None, - rand_state: Optional[int] = 4466, - k: Optional[int] = None, + knn_loc: Optional[str] = None, save_result: bool = False, return_lisi: bool = True, ) -> Optional[List[Tuple[str, np.ndarray]]]: @@ -2075,17 +2065,7 @@ def metric_lisi( label_colnames: Column names from cell metadata containing population labels use_latest_knn: Whether to use the most recent KNN graph (default: True) from_assay: Name of assay to use if not using latest KNN - cell_key: Cell filtering key for normalization - feat_key: Feature selection key for normalization - dims: Number of dimensions used for reduction - reduction_method: Name of dimensionality reduction method - pca_cell_key: Cell key used for PCA - ann_metric: Metric used for approximate nearest neighbors - ann_efc: Construction time/accuracy trade-off for ANN index - ann_ef: Query time/accuracy trade-off for ANN index - ann_m: Max number of connections in ANN graph - rand_state: Random seed for reproducibility (default: 4466) - k: Number of nearest neighbors + knn_loc: Location of KNN graph if not using latest (default: None) save_result: Whether to save LISI scores to cell metadata (default: True) return_lisi: Whether to return LISI scores (default: False) @@ -2105,43 +2085,24 @@ def metric_lisi( Higher scores indicate more mixing between different labels. 
""" - if use_latest_knn: + if use_latest_knn and knn_loc is None: knn_loc = self._get_latest_knn_loc(from_assay) cell_key = self.zw[self._load_default_assay()].attrs["latest_cell_key"] logger.info(f"Using the latest knn graph at location: {knn_loc}") - else: - if None in [ - from_assay, - cell_key, - feat_key, - dims, - k, - reduction_method, - pca_cell_key, - ann_metric, - ann_efc, - ann_ef, - ann_m, - rand_state, - ]: - raise ValueError( - "Please provide values for all the parameters: from_assay, cell_key, feat_key, dims, k, reduction_method, pca_cell_key, ann_metric, ann_efc, ann_ef, ann_m, rand_state" - ) - normed_loc = f"{from_assay}/normed__{cell_key}__{feat_key}" - reduction_loc = ( - f"{normed_loc}/reduction__{reduction_method}__{dims}__{pca_cell_key}" - ) - ann_loc = f"{reduction_loc}/ann__{ann_metric}__{ann_efc}__{ann_ef}__{ann_m}__{rand_state}" - knn_loc = f"{ann_loc}/knn__{k}" + else: + if knn_loc is None: + raise ValueError("Please provide values for the KNN graph location.") if knn_loc not in self.zw: raise ValueError(f"Could not find the knn graph at location: {knn_loc}") + logger.info(f"Using the knn graph at location: {knn_loc}") knn = self.zw[knn_loc] distances = knn["distances"] indices = knn["indices"] + try: metadata = self.cells.to_pandas_dataframe( columns=label_colnames + [cell_key] @@ -2171,17 +2132,7 @@ def metric_silhouette( use_latest_knn: bool = True, res_label: str = "leiden_cluster", from_assay: Optional[str] = None, - cell_key: Optional[str] = None, - feat_key: Optional[str] = None, - dims: Optional[str] = None, - reduction_method: Optional[str] = None, - pca_cell_key: Optional[str] = None, - ann_metric: Optional[str] = None, - ann_efc: Optional[int] = None, - ann_ef: Optional[int] = None, - ann_m: Optional[int] = None, - rand_state: Optional[int] = 4466, - k: Optional[int] = None, + knn_loc: Optional[str] = None, ) -> Optional[np.ndarray]: """Calculate modified silhouette scores for evaluating cluster separation. @@ -2191,18 +2142,8 @@ def metric_silhouette( Args: use_latest_knn: Whether to use most recent KNN graph (default: True) res_label: Column name containing cluster labels (default: "leiden_cluster") - from_assay: Name of assay to use if not using latest KNN - cell_key: Cell filtering key for normalization - feat_key: Feature selection key for normalization - dims: Number of dimensions used for reduction - reduction_method: Name of dimensionality reduction method - pca_cell_key: Cell key used for PCA - ann_metric: Metric used for approximate nearest neighbors - ann_efc: Construction time/accuracy trade-off for ANN index - ann_ef: Query time/accuracy trade-off for ANN index - ann_m: Max number of connections in ANN graph - rand_state: Random seed for reproducibility (default: 4466) - k: Number of nearest neighbors + from_assay: Name of assay to use if not using latest KNN (default: None) + knn_loc: Location of KNN graph if not using latest (default: None) Returns: numpy array of silhouette scores for each cluster, or None if computation fails @@ -2220,46 +2161,34 @@ def metric_silhouette( NaN values indicate clusters that couldn't be scored due to size constraints. 
""" - if use_latest_knn: - knn_loc = self._get_latest_knn_loc(from_assay) - from_assay = self._load_default_assay() - logger.info(f"Using the latest knn graph at location: {knn_loc}") + def compute_graph_feats(knn_loc: str): k = knn_loc.rsplit("/", 1)[-1].split("__")[-1] dims = knn_loc.rsplit("/", 2)[0].split("__")[-2] feat_key = knn_loc.split("/")[1].split("__")[-1] + return k, dims, feat_key - else: - if None in [ - from_assay, - cell_key, - feat_key, - dims, - k, - reduction_method, - pca_cell_key, - ann_metric, - ann_efc, - ann_ef, - ann_m, - rand_state, - ]: - raise ValueError( - "Please provide values for all the parameters: from_assay, cell_key, feat_key, dims, k, reduction_method, pca_cell_key, ann_metric, ann_efc, ann_ef, ann_m, rand_state" - ) - normed_loc = f"{from_assay}/normed__{cell_key}__{feat_key}" - reduction_loc = ( - f"{normed_loc}/reduction__{reduction_method}__{dims}__{pca_cell_key}" + if from_assay is None: + from_assay = self._load_default_assay() + + if use_latest_knn and knn_loc is None: + knn_loc = self._get_latest_knn_loc(from_assay) + k, dims, feat_key = compute_graph_feats(knn_loc) + logger.info( + f"Using the latest knn graph at location: {knn_loc} for assay: {from_assay}" ) - ann_loc = f"{reduction_loc}/ann__{ann_metric}__{ann_efc}__{ann_ef}__{ann_m}__{rand_state}" - knn_loc = f"{ann_loc}/knn__{k}" + else: + if knn_loc is None: + raise ValueError("Please provide values for the KNN graph location.") if knn_loc not in self.zw: raise ValueError(f"Could not find the knn graph at location: {knn_loc}") + k, dims, feat_key, from_assay = compute_graph_feats(knn_loc) logger.info(f"Using the knn graph at location: {knn_loc}") from ..metrics import knn_to_csr_matrix, silhouette_scoring isHarmonized = self.zw[knn_loc.rsplit("/", 1)[0]].attrs["isHarmonized"] + batches = None if isHarmonized: batches = self.zw[knn_loc.rsplit("/", 2)[0] + "/harmonizedData"].attrs[ @@ -2274,10 +2203,9 @@ def metric_silhouette( harmonize=isHarmonized, batch_columns=batches, ) - graph = knn_to_csr_matrix(self.z[knn_loc].indices, self.z[knn_loc].distances) + graph = knn_to_csr_matrix(self.z[knn_loc].indices, self.z[knn_loc].distances) hvg_data = self.z[knn_loc.rsplit("/", 3)[0] + "/data"] - scores = silhouette_scoring( self, ann_obj, graph, hvg_data, from_assay, res_label ) diff --git a/scarf/metrics.py b/scarf/metrics.py index 655491e..fe91a20 100644 --- a/scarf/metrics.py +++ b/scarf/metrics.py @@ -195,9 +195,9 @@ def calculate_weighted_cluster_similarity( """ unique_cluster_ids = np.unique(cluster_labels) expected_cluster_ids = np.arange(0, len(unique_cluster_ids)) - assert np.array_equal( - unique_cluster_ids, expected_cluster_ids - ), "Cluster labels must be contiguous integers starting at 1" + assert np.array_equal(unique_cluster_ids, expected_cluster_ids), ( + "Cluster labels must be contiguous integers starting at 1" + ) num_clusters = len(unique_cluster_ids) inter_cluster_weights = np.zeros((num_clusters, num_clusters)) @@ -212,7 +212,7 @@ def calculate_weighted_cluster_similarity( ): inter_cluster_weights[cluster_id, neighbor_cluster] += edge_weight - assert inter_cluster_weights.sum() == knn_graph.data.sum() + # assert inter_cluster_weights.sum() == knn_graph.data.sum() # Ensure symmetry inter_cluster_weights = (inter_cluster_weights + inter_cluster_weights.T) / 2 @@ -263,9 +263,9 @@ def calculate_top_k_neighbor_distances( AssertionError: If matrices don't have the same number of features """ # Check if the matrices have the same number of features (d) - assert ( - matrix_a.shape[1] 
== matrix_b.shape[1] - ), "Matrices must have the same number of features" + assert matrix_a.shape[1] == matrix_b.shape[1], ( + "Matrices must have the same number of features" + ) # Ensure k is not larger than the number of points in matrix_b k = min(k, matrix_b.shape[0])
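
The snippet below is an illustrative usage sketch of the metric helpers added in this
series; it is not part of the patch. It assumes a Scarf DataStore backed by a Zarr
store ("pbmc_10k.zarr" is a placeholder) on which filtering, normalization, KNN graph
construction and Leiden clustering have already been run, and whose cell metadata
contains a hypothetical "sample" column with batch labels.

    import scarf

    ds = scarf.DataStore("pbmc_10k.zarr")  # placeholder Zarr location

    # LISI: per-cell mixing scores computed on the latest KNN graph. With the
    # defaults from this series (save_result=False, return_lisi=True) the scores
    # are returned as a list of (column_name, per-cell array) tuples.
    lisi_scores = ds.metric_lisi(label_colnames=["sample"])

    # Modified silhouette: one score per Leiden cluster on the latest KNN graph;
    # NaN marks clusters that were too small to be scored.
    silhouette_scores = ds.metric_silhouette(res_label="leiden_cluster")

    # Integration metric between two labellings, as exercised in the NMI test above.
    nmi_score = ds.metric_integration(batch_labels=["sample", "leiden_cluster"], metric="nmi")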