From 89a5149603712cb86b0fcfd143c8ba51fced9048 Mon Sep 17 00:00:00 2001
From: Gautam Ahuja <goutamahuja8387@gmail.com>
Date: Thu, 16 Jan 2025 23:39:08 +0530
Subject: [PATCH] Addition of Metric Calculation Methods (#131)

* Added LISI evaluation metric and helper methods in datastore

* datastore.py: added metric methods for LISI, silhouette, and integration (adjusted Rand score, normalized mutual information score). Uses the latest KNN location by default; provide all parameters explicitly otherwise.

metrics.py: function for computing all scores.

graph_datastore.py: renamed functions

* All:
	- Added docstrings and typing
datastore.py:
	- lisi: filtered metadata as per 'I'
	- docstrings and typing
metrics.py:
	- formatting & typing
tests:
	- Added tests for metrics

* comment cleanup

* Added progress bar to LISI computation

* datastore.py: updated metric_lisi, metric_silhouette, and metric_integration to use the latest KNN location by default, with the option to provide a KNN location as input; assay.py and metrics.py: ruff formatting
---
 scarf/assay.py                     |  18 +-
 scarf/datastore/datastore.py       | 242 +++++++++++++--
 scarf/datastore/graph_datastore.py |  35 ++-
 scarf/metrics.py                   | 465 +++++++++++++++++++++++++++++
 scarf/tests/test_metrics.py        |  49 +++
 5 files changed, 773 insertions(+), 36 deletions(-)
 create mode 100644 scarf/metrics.py
 create mode 100644 scarf/tests/test_metrics.py

diff --git a/scarf/assay.py b/scarf/assay.py
index 83cf143..10c5aca 100644
--- a/scarf/assay.py
+++ b/scarf/assay.py
@@ -22,8 +22,6 @@
 from .metadata import MetaData
 from .utils import controlled_compute, logger, show_dask_progress
 
-zarrGroup = z_hierarchy.Group
-
 __all__ = ["Assay", "RNAassay", "ATACassay", "ADTassay"]
 
 
@@ -102,7 +100,7 @@ class Assay:
     for later KNN graph construction.
 
     Args:
-        z (zarrGroup): Zarr hierarchy where raw data is located
+        z (z_hierarchy.Group): Zarr hierarchy where raw data is located
         name (str): A label/name for assay.
         cell_data: Metadata class object for the cell attributes.
         nthreads: number for threads to use for dask parallel computations
@@ -122,7 +120,7 @@ class Assay:
 
     def __init__(
         self,
-        z: zarrGroup,
+        z: z_hierarchy.Group,
         workspace: Union[str, None],
         name: str,  # FIXME change to assay_name
         cell_data: MetaData,
@@ -757,7 +755,7 @@ class RNAassay(Assay):
     normalization of scRNA-Seq data.
 
     Args:
-        z (zarrGroup): Zarr hierarchy where raw data is located
+        z (z_hierarchy.Group): Zarr hierarchy where raw data is located
         name (str): A label/name for assay.
         cell_data: Metadata class object for the cell attributes.
         **kwargs: kwargs to be passed to the Assay class
@@ -769,7 +767,7 @@ class RNAassay(Assay):
                 It is set to None until normed method is called.
     """
 
-    def __init__(self, z: zarrGroup, name: str, cell_data: MetaData, **kwargs):
+    def __init__(self, z: z_hierarchy.Group, name: str, cell_data: MetaData, **kwargs):
         super().__init__(z=z, name=name, cell_data=cell_data, **kwargs)
         self.normMethod = norm_lib_size
         if "size_factor" in self.attrs:
@@ -1076,12 +1074,12 @@ class ATACassay(Assay):
     """This subclass of Assay is designed for feature selection and
     normalization of scATAC-Seq data."""
 
-    def __init__(self, z: zarrGroup, name: str, cell_data: MetaData, **kwargs):
+    def __init__(self, z: z_hierarchy.Group, name: str, cell_data: MetaData, **kwargs):
         """This Assay subclass is designed for feature selection and
         normalization of scATAC-Seq data.
 
         Args:
-            z (zarrGroup): Zarr hierarchy where raw data is located
+            z (z_hierarchy.Group): Zarr hierarchy where raw data is located
             name (str): A label/name for assay.
             cell_data: Metadata class object for the cell attributes.
             **kwargs:
@@ -1208,7 +1206,7 @@ class ADTassay(Assay):
     (feature-barcodes library) data from CITE-Seq experiments.
 
     Args:
-        z (zarrGroup): Zarr hierarchy where raw data is located
+        z (z_hierarchy.Group): Zarr hierarchy where raw data is located
         name (str): A label/name for assay.
         cell_data: Metadata class object for the cell attributes.
         **kwargs:
@@ -1217,7 +1215,7 @@ class ADTassay(Assay):
         normMethod: Pointer to the function to be used for normalization of the raw data
     """
 
-    def __init__(self, z: zarrGroup, name: str, cell_data: MetaData, **kwargs):
+    def __init__(self, z: z_hierarchy.Group, name: str, cell_data: MetaData, **kwargs):
         """This subclass of Assay is designed for normalization of ADT/HTO
         (feature-barcodes library) data from CITE-Seq experiments."""
         super().__init__(z=z, name=name, cell_data=cell_data, **kwargs)
diff --git a/scarf/datastore/datastore.py b/scarf/datastore/datastore.py
index b803ae5..bbefeb8 100644
--- a/scarf/datastore/datastore.py
+++ b/scarf/datastore/datastore.py
@@ -1,15 +1,15 @@
-from typing import Iterable, Optional, Union, List, Literal, Tuple
+from typing import Iterable, List, Literal, Optional, Tuple, Union
 
 import numpy as np
 import pandas as pd
 from dask import array as daskarr
 from loguru import logger
 
-from .mapping_datastore import MappingDatastore
-from ..assay import Assay, RNAassay, ATACassay
+from ..assay import Assay, ATACassay, RNAassay
 from ..feat_utils import hto_demux
-from ..utils import tqdmbar, controlled_compute, ZARRLOC
-from ..writers import create_zarr_obj_array, create_zarr_dataset
+from ..utils import ZARRLOC, controlled_compute, tqdmbar
+from ..writers import create_zarr_dataset, create_zarr_obj_array
+from .mapping_datastore import MappingDatastore
 
 __all__ = ["DataStore"]
 
@@ -75,10 +75,7 @@ def __init__(
             synchronizer=synchronizer,
         )
 
-    def get_assay(
-        self,
-        assay_name: str
-    ) -> Assay:
+    def get_assay(self, assay_name: str) -> Assay:
         """Returns the assay object for the given assay name.
 
         Args:
@@ -91,8 +88,7 @@ def get_assay(
             raise ValueError(f"ERROR: Assay {assay_name} not found in the Zarr file")
         else:
             return getattr(self, assay_name)
-        
-    
+
     def filter_cells(
         self,
         attrs: Iterable[str],
@@ -282,7 +278,7 @@ def mark_hvgs(
         if cell_key is None:
             cell_key = "I"
         assay = self._get_assay(from_assay)
-        if type(assay) != RNAassay:
+        if type(assay) != RNAassay:  # noqa: E721
             raise TypeError(
                 f"ERROR: This method of feature selection can only be applied to RNAassay type of assay. "
                 f"The provided assay is {type(assay)} type"
@@ -338,7 +334,7 @@ def mark_prevalent_peaks(
         if cell_key is None:
             cell_key = "I"
         assay = self._get_assay(from_assay)
-        if type(assay) != ATACassay:
+        if type(assay) != ATACassay:  # noqa: E721
             raise TypeError(
                 f"ERROR: This method of feature selection can only be applied to ATACassay type of assay. "
                 f"The provided assay is {type(assay)} type"
@@ -625,8 +621,8 @@ def get_markers(
             cell_key = "I"
         if group_key is None:
             raise ValueError(
-                f"ERROR: Please provide a value for group_key. "
-                f"This should be same as used for `run_marker_search`"
+                "ERROR: Please provide a value for group_key. "
+                "This should be same as used for `run_marker_search`"
             )
         assay = self._get_assay(from_assay)
         try:
@@ -706,8 +702,8 @@ def export_markers_to_csv(
         # Not testing the values of from_assay and cell_key because they will be tested in `get_markers`
         if group_key is None:
             raise ValueError(
-                f"ERROR: Please provide a value for group_key. "
-                f"This should be same as used for `run_marker_search`"
+                "ERROR: Please provide a value for group_key. "
+                "This should be same as used for `run_marker_search`"
             )
         if csv_filename is None:
             raise ValueError(
@@ -1290,7 +1286,7 @@ def plot_cells_dists(
             pass
 
         if cols is not None:
-            if type(cols) != list:
+            if type(cols) != list:  # noqa: E721
                 raise ValueError("ERROR: 'cols' argument must be of type list")
             plot_cols = []
             for i in cols:
@@ -1496,7 +1492,7 @@ def plot_layout(
         # TODO: add support for providing a list of subselections, from_assay and cell_keys
         # TODO: add support for different kinds of point markers
 
-        from ..plots import shade_scatter, plot_scatter
+        from ..plots import plot_scatter, shade_scatter
 
         if from_assay is None:
             from_assay = self._defaultAssay
@@ -1728,9 +1724,10 @@ def plot_cluster_tree(
             None
         """
 
-        from ..plots import plot_cluster_hierarchy
+        from networkx import DiGraph, to_pandas_edgelist
+
         from ..dendrogram import CoalesceTree, make_digraph
-        from networkx import to_pandas_edgelist, DiGraph
+        from ..plots import plot_cluster_hierarchy
 
         from_assay, cell_key, feat_key = self._get_latest_keys(
             from_assay, cell_key, feat_key
@@ -2011,8 +2008,8 @@ def plot_pseudotime_heatmap(
         else:
             if hashes != assay.z[location].attrs["hashes"]:
                 raise ValueError(
-                    f"ERROR: The values under one or more of these columns: `cell_key`, `feat_key` or/and "
-                    f"`pseudotime_key have been updated after running `run_pseudotime_aggregation`"
+                    "ERROR: The values under one or more of these columns: `cell_key`, `feat_key` or/and "
+                    "`pseudotime_key have been updated after running `run_pseudotime_aggregation`"
                 )
 
         da = daskarr.from_zarr(assay.z[location + "/data"], inline_array=True)
@@ -2049,3 +2046,202 @@ def plot_pseudotime_heatmap(
             save_dpi=save_dpi,
             show_fig=show_fig,
         )
+
+    def metric_lisi(
+        self,
+        label_colnames: Iterable[str],
+        use_latest_knn: bool = True,
+        from_assay: Optional[str] = None,
+        knn_loc: Optional[str] = None,
+        save_result: bool = False,
+        return_lisi: bool = True,
+    ) -> Optional[List[Tuple[str, np.ndarray]]]:
+        """Calculate Local Inverse Simpson Index (LISI) scores for cell populations.
+
+        LISI measures how well mixed different cell populations are in the local neighborhood
+        of each cell. Higher scores indicate better mixing of different populations.
+
+        Args:
+            label_colnames: Column names from cell metadata containing population labels
+            use_latest_knn: Whether to use the most recent KNN graph (default: True)
+            from_assay: Name of assay to use if not using latest KNN
+            knn_loc: Location of KNN graph if not using latest (default: None)
+            save_result: Whether to save LISI scores to cell metadata (default: False)
+            return_lisi: Whether to return LISI scores (default: True)
+
+        Returns:
+            If return_lisi is True, returns list of tuples containing:
+            - Label column name
+            - numpy array of LISI scores for that label
+            If return_lisi is False, returns None
+
+        Raises:
+            ValueError: If using custom KNN graph but required parameters are missing
+            KeyError: If label columns not found in cell metadata
+
+        Notes:
+            LISI scores are computed for each label column separately.
+            Scores near 1 indicate cells grouped with similar labels.
+            Higher scores indicate more mixing between different labels.
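+
+        Example:
+            A minimal usage sketch (illustrative; assumes a DataStore ``ds`` with a
+            computed KNN graph and a hypothetical cell-metadata column ``sample_id``)::
+
+                res = ds.metric_lisi(label_colnames=["sample_id"], save_result=False)
+                label, scores = res[0]  # scores: one LISI value per cell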
+        """
+
+        if use_latest_knn and knn_loc is None:
+            knn_loc = self._get_latest_knn_loc(from_assay)
+            logger.info(f"Using the latest knn graph at location: {knn_loc}")
+        else:
+            if knn_loc is None:
+                raise ValueError("Please provide values for the KNN graph location.")
+            if knn_loc not in self.zw:
+                raise ValueError(f"Could not find the knn graph at location: {knn_loc}")
+            logger.info(f"Using the knn graph at location: {knn_loc}")
+
+        # cell_key is needed in both branches to filter metadata to the active cells
+        cell_key = self.zw[self._load_default_assay()].attrs["latest_cell_key"]
+        knn = self.zw[knn_loc]
+
+        distances = knn["distances"]
+        indices = knn["indices"]
+
+        try:
+            metadata = self.cells.to_pandas_dataframe(
+                columns=label_colnames + [cell_key]
+            )
+            metadata = metadata[metadata[cell_key]]
+        except KeyError:
+            raise KeyError(
+                f"Could not find the column(s) {label_colnames} in the cell metadata table."
+            )
+
+        from ..metrics import compute_lisi
+
+        lisi_scores = compute_lisi(distances, indices, metadata, label_colnames)
+        # lisi_scores Shape -> (n_cells, n_labels)
+        if save_result:
+            for col, vals in zip(label_colnames, lisi_scores.T):
+                col_name = f"lisi__{col}__{knn_loc.split('/')[-1]}"
+                self.cells.insert(column_name=col_name, values=vals, overwrite=True)
+
+        if return_lisi:
+            return list(zip(label_colnames, lisi_scores.T))
+        else:
+            return None
+
+    def metric_silhouette(
+        self,
+        use_latest_knn: bool = True,
+        res_label: str = "leiden_cluster",
+        from_assay: Optional[str] = None,
+        knn_loc: Optional[str] = None,
+    ) -> Optional[np.ndarray]:
+        """Calculate modified silhouette scores for evaluating cluster separation.
+
+        This implements a graph-based silhouette score that measures how similar cells
+        are to their own cluster compared to the nearest neighboring cluster.
+
+        Args:
+            use_latest_knn: Whether to use most recent KNN graph (default: True)
+            res_label: Column name containing cluster labels (default: "leiden_cluster")
+            from_assay: Name of assay to use if not using latest KNN (default: None)
+            knn_loc: Location of KNN graph if not using latest (default: None)
+
+        Returns:
+            numpy array of silhouette scores for each cluster, or None if computation fails
+
+        Raises:
+            ValueError: If using custom KNN graph but required parameters are missing
+
+        Notes:
+            Scores range from -1 to 1:
+            - Near 1: Cluster is well-separated from neighboring clusters
+            - Near 0: Cluster overlaps with neighboring clusters
+            - Near -1: Cluster may be incorrectly assigned
+
+            Implementation uses sampling for efficiency with large datasets.
+            NaN values indicate clusters that couldn't be scored due to size constraints.
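+
+        Example:
+            A minimal usage sketch (illustrative; assumes a DataStore ``ds`` where the
+            latest KNN graph and Leiden clustering have been computed, so that a
+            ``<assay>_leiden_cluster`` column exists)::
+
+                scores = ds.metric_silhouette()  # one score per cluster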
+        """
+
+        def compute_graph_feats(knn_loc: str):
+            k = knn_loc.rsplit("/", 1)[-1].split("__")[-1]
+            dims = knn_loc.rsplit("/", 2)[0].split("__")[-2]
+            feat_key = knn_loc.split("/")[1].split("__")[-1]
+            return k, dims, feat_key
+
+        if from_assay is None:
+            from_assay = self._load_default_assay()
+
+        if use_latest_knn and knn_loc is None:
+            knn_loc = self._get_latest_knn_loc(from_assay)
+            k, dims, feat_key = compute_graph_feats(knn_loc)
+            logger.info(
+                f"Using the latest knn graph at location: {knn_loc} for assay: {from_assay}"
+            )
+
+        else:
+            if knn_loc is None:
+                raise ValueError("Please provide values for the KNN graph location.")
+            if knn_loc not in self.zw:
+                raise ValueError(f"Could not find the knn graph at location: {knn_loc}")
+            k, dims, feat_key = compute_graph_feats(knn_loc)
+            from_assay = knn_loc.split("/")[0]
+            logger.info(f"Using the knn graph at location: {knn_loc}")
+
+        from ..metrics import knn_to_csr_matrix, silhouette_scoring
+
+        isHarmonized = self.zw[knn_loc.rsplit("/", 1)[0]].attrs["isHarmonized"]
+
+        batches = None
+        if isHarmonized:
+            batches = self.zw[knn_loc.rsplit("/", 2)[0] + "/harmonizedData"].attrs[
+                "batches"
+            ]
+
+        ann_obj = self.make_graph(
+            feat_key=feat_key,
+            dims=dims,
+            k=k,
+            return_ann_object=True,
+            harmonize=isHarmonized,
+            batch_columns=batches,
+        )
+
+        graph = knn_to_csr_matrix(self.z[knn_loc].indices, self.z[knn_loc].distances)
+        hvg_data = self.z[knn_loc.rsplit("/", 3)[0] + "/data"]
+        scores = silhouette_scoring(
+            self, ann_obj, graph, hvg_data, from_assay, res_label
+        )
+        return scores
+
+    def metric_integration(
+        self, batch_labels: List[str], metric: Literal["ari", "nmi"] = "ari"
+    ) -> Optional[float]:
+        """Calculate integration score between different batch labels.
+
+        Measures how well aligned different batches are after integration by comparing
+        their cluster assignments using standard clustering metrics.
+
+        Args:
+            batch_labels: List of column names containing batch labels to compare
+            metric: Metric to use for comparison (default: "ari")
+                - "ari": Adjusted Rand Index
+                - "nmi": Normalized Mutual Information
+
+        Returns:
+            Integration score, or None if the metric is not recognized.
+            Higher scores indicate better alignment between batches.
+
+        Notes:
+            ARI and NMI measure the agreement between different batch labelings:
+            - Score near 1: Batches are well integrated
+            - Score near 0: Batches show poor integration
+
+            ARI is adjusted for chance and generally more stringent than NMI.
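+
+        Example:
+            A minimal usage sketch (illustrative; the column names below are
+            hypothetical and must exist in the cell metadata of ``ds``)::
+
+                score = ds.metric_integration(
+                    batch_labels=["RNA_leiden_cluster", "sample_id"], metric="nmi"
+                )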
+        """
+        from ..metrics import integration_score
+
+        batch_labels_vals = []
+        for batch in batch_labels:
+            vals = np.array(self.cells.fetch_all(batch))
+            batch_labels_vals.append(vals)
+
+        scores = integration_score(batch_labels_vals, metric)
+        return scores
diff --git a/scarf/datastore/graph_datastore.py b/scarf/datastore/graph_datastore.py
index 5235f94..7ba0b1d 100644
--- a/scarf/datastore/graph_datastore.py
+++ b/scarf/datastore/graph_datastore.py
@@ -1,16 +1,16 @@
 import os
-from typing import Tuple, Optional, Union, List, Callable
+from typing import Callable, List, Optional, Tuple, Union
 
 import numpy as np
 import pandas as pd
 from dask.array import from_zarr  # type: ignore
 from loguru import logger
-from scipy.sparse import csr_matrix, coo_matrix
+from scipy.sparse import coo_matrix, csr_matrix
 
-from .base_datastore import BaseDataStore
 from ..assay import Assay
 from ..utils import clean_array, show_dask_progress, system_call, tqdmbar
 from ..writers import create_zarr_dataset
+from .base_datastore import BaseDataStore
 
 
 class GraphDataStore(BaseDataStore):
@@ -396,6 +396,34 @@ def _get_latest_graph_loc(
         knn_loc = self.zw[ann_loc].attrs["latest_knn"]
         return self.zw[knn_loc].attrs["latest_graph"]
 
+    def _get_latest_knn_loc(self, from_assay: Optional[str] = None) -> str:
+        """Convenience function to identify location of the latest KNN graph in
+        the Zarr hierarchy.
+
+        Args:
+            from_assay: Name of the assay.
+
+        Returns:
+            Path of KNN graph in the Zarr hierarchy
+        """
+        if from_assay is None:
+            logger.info("Using default assay for KNN graph.")
+            from_assay = self._load_default_assay()
+
+        if from_assay not in self.assay_names:
+            raise ValueError(f"ERROR: Assay {from_assay} does not exist")
+
+        latest_cell_key = self.zw[from_assay].attrs["latest_cell_key"]
+        latest_feat_key = self.zw[from_assay].attrs["latest_feat_key"]
+        normed_loc = f"{from_assay}/normed__{latest_cell_key}__{latest_feat_key}"
+        reduction_loc = self.zw[normed_loc].attrs["latest_reduction"]
+        if "reduction" not in self.zw[reduction_loc]:
+            raise ValueError(f"ERROR: PCA Reduction not found in {reduction_loc}")
+        latest_ann = self.zw[reduction_loc].attrs["latest_ann"]
+        ann_loc = self.zw[latest_ann]
+        latest_knn = ann_loc.attrs["latest_knn"]
+        return latest_knn
+
     def _get_ini_embed(
         self, from_assay: str, cell_key: str, feat_key: str, n_comps: int
     ) -> np.ndarray:
@@ -414,6 +442,7 @@ def _get_ini_embed(
             Matrix with n_comps dimensions representing initial embedding of cells.
         """
         from sklearn.decomposition import PCA
+
         from ..utils import rescale_array
 
         normed_loc = f"{from_assay}/normed__{cell_key}__{feat_key}"
diff --git a/scarf/metrics.py b/scarf/metrics.py
new file mode 100644
index 0000000..fe91a20
--- /dev/null
+++ b/scarf/metrics.py
@@ -0,0 +1,465 @@
+"""
+Methods and classes for evaluation metrics.
+"""
+
+from typing import Iterable, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import pandas as pd
+from scipy.sparse import csr_matrix
+from zarr.core import Array as zarrArrayType
+
+from .ann import AnnStream
+from .datastore.datastore import DataStore
+from .utils import (
+    logger,
+    tqdmbar,
+)
+
+
+# LISI - The Local Inverse Simpson Index
+def compute_lisi(
+    distances: zarrArrayType,
+    indices: zarrArrayType,
+    metadata: pd.DataFrame,
+    label_colnames: Iterable[str],
+    perplexity: float = 30,
+) -> np.ndarray:
+    """Compute the Local Inverse Simpson Index (LISI) for each column in metadata.
+
+    LISI measures how well mixed different groups of cells are in the neighborhood of each cell.
+    Higher values indicate better mixing of different groups.
+
+    Args:
+        distances: Pre-computed distances between cells, stored in zarr array format
+        indices: Pre-computed nearest neighbor indices, stored in zarr array format
+        metadata: DataFrame containing categorical labels for each cell
+        label_colnames: Column names in metadata to compute LISI for
+        perplexity: Parameter controlling the effective number of neighbors (default: 30)
+
+    Returns:
+        np.ndarray: Matrix of LISI scores with shape (n_cells, n_labels)
+        Each column corresponds to LISI scores for one label column in metadata
+
+    Example:
+        For metadata with a 'batch' column having 3 categories:
+        - LISI ≈ 3: Cell has neighbors from all 3 batches (well mixed)
+        - LISI ≈ 1: Cell has neighbors from only 1 batch (poorly mixed)
+
+    References:
+        Korsunsky et al. 2019 doi: 10.1038/s41592-019-0619-0
+    """
+
+    n_cells = metadata.shape[0]
+    n_labels = len(label_colnames)
+    # Don't count yourself
+    indices = indices[:, 1:]
+    distances = distances[:, 1:]
+    lisi_df = np.zeros((n_cells, n_labels))
+    for i, label in enumerate(label_colnames):
+        logger.info(f"Computing LISI for {label}")
+        labels = pd.Categorical(metadata[label])
+        simpson = compute_simpson(distances.T, indices.T, labels, perplexity)
+        lisi_df[:, i] = 1 / simpson
+    return lisi_df
+
+
+def compute_simpson(
+    distances: np.ndarray,
+    indices: np.ndarray,
+    labels: pd.Categorical,
+    perplexity: float,
+    tol: float = 1e-5,
+) -> np.ndarray:
+    """Compute Simpson's diversity index with Gaussian kernel weighting.
+
+    This function implements the core calculation for LISI, computing a diversity score
+    based on the distribution of categories in each cell's neighborhood.
+
+    Args:
+        distances: Distance matrix between points, shape (n_neighbors, n_points)
+        indices: Index matrix for nearest neighbors, shape (n_neighbors, n_points)
+        labels: Categorical labels for each point
+        perplexity: Target perplexity for Gaussian kernel
+        tol: Convergence tolerance for perplexity calibration (default: 1e-5)
+
+    Returns:
+        np.ndarray: Array of Simpson's diversity indices, one per point
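+
+    Notes:
+        For each point, neighbor weights follow a Gaussian kernel,
+        P_j proportional to exp(-beta * d_j), where beta is calibrated by bisection
+        so that the entropy of P matches log(perplexity). The Simpson index is then
+        the sum over categories c of (sum of P_j over neighbors j in c) squared;
+        `compute_lisi` reports its reciprocal.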
+    """
+    n = distances.shape[1]
+    P = np.zeros(distances.shape[0])
+    simpson = np.zeros(n)
+    logU = np.log(perplexity)
+    # Loop through each cell.
+    for i in tqdmbar(range(n), desc="Computing Simpson's Diversity Index"):
+        beta = 1
+        betamin = -np.inf
+        betamax = np.inf
+        # Compute Hdiff
+        P = np.exp(-distances[:, i] * beta)
+        P_sum = np.sum(P)
+        if P_sum == 0:
+            H = 0
+            P = np.zeros(distances.shape[0])
+        else:
+            H = np.log(P_sum) + beta * np.sum(distances[:, i] * P) / P_sum
+            P = P / P_sum
+        Hdiff = H - logU
+        n_tries = 50
+        for t in range(n_tries):
+            # Stop when we reach the tolerance
+            if abs(Hdiff) < tol:
+                break
+            # Update beta
+            if Hdiff > 0:
+                betamin = beta
+                if not np.isfinite(betamax):
+                    beta *= 2
+                else:
+                    beta = (beta + betamax) / 2
+            else:
+                betamax = beta
+                if not np.isfinite(betamin):
+                    beta /= 2
+                else:
+                    beta = (beta + betamin) / 2
+            # Compute Hdiff
+            P = np.exp(-distances[:, i] * beta)
+            P_sum = np.sum(P)
+            if P_sum == 0:
+                H = 0
+                P = np.zeros(distances.shape[0])
+            else:
+                H = np.log(P_sum) + beta * np.sum(distances[:, i] * P) / P_sum
+                P = P / P_sum
+            Hdiff = H - logU
+        # Default value
+        if H == 0:
+            simpson[i] = -1
+        # Simpson's index
+        for label_category in labels.categories:
+            ix = indices[:, i]
+            q = labels[ix] == label_category
+            if np.any(q):
+                P_sum = np.sum(P[q])
+                simpson[i] += P_sum * P_sum
+    return simpson
+
+
+# SILHOUETTE SCORE - The Silhouette Score
+def knn_to_csr_matrix(
+    neighbor_indices: np.ndarray, neighbor_distances: np.ndarray
+) -> csr_matrix:
+    """Convert k-nearest neighbors data to a Compressed Sparse Row (CSR) matrix.
+
+    Creates a sparse adjacency matrix representation of the k-nearest neighbors graph
+    where edge weights are the distances between points.
+
+    Args:
+        neighbor_indices: Indices matrix from k-nearest neighbors, shape (n_samples, k)
+        neighbor_distances: Distances matrix from k-nearest neighbors, shape (n_samples, k)
+
+    Returns:
+        scipy.sparse.csr_matrix: Sparse adjacency matrix of shape (n_samples, n_samples)
+        where non-zero entries represent distances between neighboring points
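+
+    Example:
+        A small illustrative case with two samples and two neighbors each::
+
+            idx = np.array([[0, 1], [1, 0]])
+            dist = np.array([[0.0, 0.5], [0.0, 0.7]])
+            knn_to_csr_matrix(idx, dist).toarray()
+            # array([[0. , 0.5],
+            #        [0.7, 0. ]])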
+    """
+    num_samples, num_neighbors = neighbor_indices.shape
+    row_indices = np.repeat(np.arange(num_samples), num_neighbors)
+    return csr_matrix(
+        (neighbor_distances[:].flatten(), (row_indices, neighbor_indices[:].flatten())),
+        shape=(num_samples, num_samples),
+    )
+
+
+def calculate_weighted_cluster_similarity(
+    knn_graph: csr_matrix, cluster_labels: np.ndarray
+) -> np.ndarray:
+    """Calculate similarity between clusters based on shared weighted edges.
+
+    Uses a weighted Jaccard index to compute similarities between clusters in a KNN graph.
+
+    Args:
+        knn_graph: CSR matrix representing the KNN graph, shape (n_samples, n_samples)
+        cluster_labels: Cluster assignments for each node, must be contiguous integers
+            starting from 0
+
+    Returns:
+        np.ndarray: Symmetric matrix of shape (n_clusters, n_clusters) containing
+        pairwise similarities between clusters
+
+    Raises:
+        AssertionError: If cluster labels are not contiguous integers starting from 0
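+
+    Notes:
+        With w_ij the summed edge weight between clusters i and j and W_i the
+        total edge weight incident to cluster i, the similarity is the weighted
+        Jaccard index w_ij / (W_i + W_j - w_ij).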
+    """
+    unique_cluster_ids = np.unique(cluster_labels)
+    expected_cluster_ids = np.arange(0, len(unique_cluster_ids))
+    assert np.array_equal(unique_cluster_ids, expected_cluster_ids), (
+        "Cluster labels must be contiguous integers starting at 1"
+    )
+
+    num_clusters = len(unique_cluster_ids)
+    inter_cluster_weights = np.zeros((num_clusters, num_clusters))
+
+    for cluster_id in unique_cluster_ids:
+        nodes_in_cluster = np.where(cluster_labels == cluster_id)[0]
+        neighbor_cluster_labels = cluster_labels[knn_graph[nodes_in_cluster].indices]
+        neighbor_edge_weights = knn_graph[nodes_in_cluster].data
+
+        for neighbor_cluster, edge_weight in zip(
+            neighbor_cluster_labels, neighbor_edge_weights
+        ):
+            inter_cluster_weights[cluster_id, neighbor_cluster] += edge_weight
+
+    # assert inter_cluster_weights.sum() == knn_graph.data.sum()
+
+    # Ensure symmetry
+    inter_cluster_weights = (inter_cluster_weights + inter_cluster_weights.T) / 2
+
+    # Calculate total weights for each cluster
+    total_cluster_weights = np.array(
+        [inter_cluster_weights[i].sum() for i in unique_cluster_ids]
+    )
+
+    # Calculate similarity using weighted Jaccard index
+    similarity_matrix = np.zeros((num_clusters, num_clusters))
+
+    for i in range(num_clusters):
+        for j in range(i, num_clusters):
+            weight_union = (
+                total_cluster_weights[i]
+                + total_cluster_weights[j]
+                - inter_cluster_weights[i, j]
+            )
+            if weight_union > 0:
+                similarity = inter_cluster_weights[i, j] / weight_union
+                similarity_matrix[i, j] = similarity_matrix[j, i] = similarity
+
+    # Set diagonal to 1 (self-similarity)
+    # np.fill_diagonal(similarity_matrix, 1.0)
+
+    return similarity_matrix
+
+
+def calculate_top_k_neighbor_distances(
+    matrix_a: np.ndarray, matrix_b: np.ndarray, k: int
+) -> np.ndarray:
+    """Calculate distances to k nearest neighbors between two sets of points.
+
+    For each point in matrix_a, finds the k nearest neighbors in matrix_b
+    and returns their distances.
+
+    Args:
+        matrix_a: First set of points, shape (m, d)
+        matrix_b: Second set of points, shape (n, d)
+        k: Number of nearest neighbors to find
+
+    Returns:
+        np.ndarray: Matrix of shape (m, k) containing the distances to the
+        k nearest neighbors in matrix_b for each point in matrix_a
+
+    Raises:
+        AssertionError: If matrices don't have the same number of features
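+
+    Notes:
+        Squared Euclidean distances are computed via the identity
+        ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b and clipped at zero to guard
+        against small negative values from floating point error.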
+    """
+    # Check if the matrices have the same number of features (d)
+    assert matrix_a.shape[1] == matrix_b.shape[1], (
+        "Matrices must have the same number of features"
+    )
+
+    # Ensure k is not larger than the number of points in matrix_b
+    k = min(k, matrix_b.shape[0])
+
+    # Calculate squared Euclidean distances
+    a_squared = np.sum(np.square(matrix_a), axis=1, keepdims=True)
+    b_squared = np.sum(np.square(matrix_b), axis=1)
+
+    # Use broadcasting to compute pairwise distances
+    distances = a_squared + b_squared - 2 * np.dot(matrix_a, matrix_b.T)
+
+    # Use np.maximum to avoid small negative values due to floating point errors
+    distances = np.maximum(distances, 0)
+
+    # Find the k smallest distances for each point in matrix_a
+    top_k_distances = np.partition(distances, k, axis=1)[:, :k]
+
+    # Calculate the square root to get Euclidean distances
+    return np.sqrt(top_k_distances)
+
+
+def process_cluster(
+    cluster_cells: np.ndarray,
+    hvg_data: Union[np.ndarray, zarrArrayType],
+    ann_obj: AnnStream,
+    k: int,
+) -> Tuple[np.ndarray, np.ndarray]:
+    """Process a cluster of cells to prepare data for silhouette scoring.
+
+    Randomly splits cluster cells into two groups and applies dimensionality reduction.
+
+    Args:
+        cluster_cells: Indices of cells belonging to the cluster
+        hvg_data: Expression data for highly variable genes
+        ann_obj: Object containing dimensionality reduction method
+        k: Number of cells to sample from cluster
+
+    Returns:
+        Tuple[np.ndarray, np.ndarray]: Two arrays containing reduced data for
+        different subsets of cells from the cluster
+    """
+    np.random.shuffle(cluster_cells)
+    data_cells = np.array(
+        [ann_obj.reducer(hvg_data[i]) for i in sorted(cluster_cells[:k])]
+    )
+    data_cells_2 = np.array(
+        [ann_obj.reducer(hvg_data[i]) for i in sorted(cluster_cells[k : 2 * k])]
+    )
+    return data_cells, data_cells_2
+
+
+def silhouette_scoring(
+    ds: DataStore,
+    ann_obj: AnnStream,
+    graph: csr_matrix,
+    hvg_data: Union[np.ndarray, zarrArrayType],
+    assay_type: str,
+    res_label: str,
+) -> Optional[np.ndarray]:
+    """Compute modified silhouette scores for clusters in single-cell data.
+
+    This implementation differs from the standard silhouette score by using
+    a graph-based approach and comparing clusters to their nearest neighbors.
+
+    Args:
+        ds: DataStore object containing cell metadata
+        ann_obj: Object containing dimensionality reduction method
+        graph: CSR matrix representing the KNN graph
+        hvg_data: Expression data for highly variable genes
+        assay_type: Type of assay (e.g., 'RNA', 'ATAC')
+        res_label: Label for clustering resolution
+
+    Returns:
+        Optional[np.ndarray]: Array of silhouette scores for each cluster,
+        or None if cluster labels are not found
+
+    Notes:
+        Scores are calculated using a sampling approach for efficiency.
+        NaN values indicate clusters that couldn't be scored due to size constraints.
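+        For each cluster, the within-cluster distance a is the mean distance between
+        two sampled halves of the cluster, the between-cluster distance b is the mean
+        distance to a sample from the nearest (most similar) cluster, and the reported
+        score is (b - a) / max(a, b), as in the classic silhouette formulation.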
+    """
+    try:
+        clusters = ds.cells.fetch(f"{assay_type}_{res_label}") - 1  # RNA_{res_label}
+    except KeyError:
+        logger.error(f"Cluster labels not found for {assay_type}_{res_label}")
+        return None
+
+    cluster_similarity = calculate_weighted_cluster_similarity(graph, clusters)
+
+    k = 11
+    score = []
+
+    for n, i in enumerate(cluster_similarity):
+        this_cluster_cells = np.where(clusters == n)[0]
+        if len(this_cluster_cells) < 2 * k:
+            logger.warning(
+                f"Warning: Cluster {n} has fewer than {2 * k} cells. "
+                f"Will adjust k to {len(this_cluster_cells) // 2} instead"
+            )
+            k = len(this_cluster_cells) // 2
+
+    for n, i in tqdmbar(
+        enumerate(cluster_similarity),
+        total=len(cluster_similarity),
+        desc="Calculating Silhouette Scores",
+    ):
+        this_cluster_cells = np.where(clusters == n)[0]
+        np.random.shuffle(this_cluster_cells)
+        data_this_cells, data_this_cells_2 = process_cluster(
+            # n,
+            this_cluster_cells,
+            hvg_data,
+            ann_obj,
+            k,
+        )
+
+        if data_this_cells.size == 0 or data_this_cells_2.size == 0:
+            logger.warning(f"Warning: Reduced data for cluster {n} is empty. Skipping.")
+            score.append(np.nan)
+            continue
+
+        k_neighbors = min(k - 1, data_this_cells_2.shape[0] - 1)
+
+        if k_neighbors < 1:
+            logger.warning(
+                f"Warning: Not enough points in cluster {n} for comparison. Skipping."
+            )
+            score.append(np.nan)
+            continue
+
+        self_dist = calculate_top_k_neighbor_distances(
+            data_this_cells, data_this_cells_2, k - 1
+        ).mean()
+
+        # Exclude the cluster itself so the comparison is against the nearest *other* cluster
+        similarity_to_others = np.array(i, dtype=float)
+        similarity_to_others[n] = -np.inf
+        nearest_cluster = int(np.argmax(similarity_to_others))
+        nearest_cluster_cells = np.where(clusters == nearest_cluster)[0]
+        np.random.shuffle(nearest_cluster_cells)
+
+        if len(nearest_cluster_cells) < k:
+            logger.warning(
+                f"Warning: Nearest cluster {nearest_cluster} has fewer than {k} cells. Skipping."
+            )
+            score.append(np.nan)
+            continue
+
+        data_nearest_cells, _ = process_cluster(
+            # nearest_cluster,
+            nearest_cluster_cells,
+            hvg_data,
+            ann_obj,
+            k,
+        )
+
+        if data_nearest_cells.size == 0:
+            logger.warning(
+                f"Warning: Reduced data for nearest cluster {nearest_cluster} is empty. Skipping."
+            )
+            score.append(np.nan)
+            continue
+
+        other_dist = calculate_top_k_neighbor_distances(
+            data_this_cells, data_nearest_cells, k - 1
+        ).mean()
+
+        score.append((other_dist - self_dist) / max(self_dist, other_dist))
+
+    return np.array(score)
+
+
+def integration_score(
+    batch_labels: Sequence[np.ndarray], metric: str = "ari"
+) -> Optional[float]:
+    """Calculate integration score between two sets of batch labels.
+
+    Args:
+        batch_labels: Sequence containing two arrays of batch labels to compare
+        metric: Metric to use for comparison, one of:
+            - 'ari': Adjusted Rand Index
+            - 'nmi': Normalized Mutual Information
+
+    Returns:
+        Optional[float]: Integration score, or None if the metric is not recognized
+
+    Notes:
+        Higher scores indicate better agreement between batch labels,
+        suggesting more effective batch integration.
+    """
+    from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score
+
+    if metric == "ari":
+        return adjusted_rand_score(batch_labels[0], batch_labels[1])
+    elif metric == "nmi":
+        return normalized_mutual_info_score(batch_labels[0], batch_labels[1])
+    else:
+        logger.error(
+            f"Metric {metric} not recognized. Please choose from 'ari', 'nmi', 'calinski_harabasz', or 'davies_bouldin'."
+        )
+        return None
diff --git a/scarf/tests/test_metrics.py b/scarf/tests/test_metrics.py
new file mode 100644
index 0000000..6cbb277
--- /dev/null
+++ b/scarf/tests/test_metrics.py
@@ -0,0 +1,49 @@
+import numpy as np
+
+
+def test_metric_lisi(datastore, make_graph):
+    labels = np.random.randint(0, 2, datastore.cells.N)
+    datastore.cells.insert(
+        column_name="samples_id",
+        values=labels,
+        overwrite=True,
+    )
+    lisi = datastore.metric_lisi(
+        label_colnames=["samples_id"], save_result=False, return_lisi=True
+    )
+    assert len(lisi[0][1]) == len(datastore.cells.active_index("I"))
+
+
+def test_metric_silhouette(datastore, make_graph, leiden_clustering):
+    _ = datastore.metric_silhouette()
+
+
+def test_metric_integration_ari(datastore, make_graph, leiden_clustering):
+    labels1 = np.random.randint(0, 2, datastore.cells.N)
+    labels2 = np.random.randint(0, 2, datastore.cells.N)
+    datastore.cells.insert(
+        column_name="labels1",
+        values=labels1,
+        overwrite=True,
+    )
+    datastore.cells.insert(
+        column_name="labels2",
+        values=labels2,
+        overwrite=True,
+    )
+    _ = datastore.metric_integration(batch_labels=["labels1", "labels2"], metric="ari")
+
+
+def test_metric_integration_nmi(datastore, make_graph, leiden_clustering):
+    labels1 = np.random.randint(0, 2, datastore.cells.N)
+    labels2 = np.random.randint(0, 2, datastore.cells.N)
+    datastore.cells.insert(
+        column_name="labels1",
+        values=labels1,
+        overwrite=True,
+    )
+    datastore.cells.insert(
+        column_name="labels2",
+        values=labels2,
+        overwrite=True,
+    )
+    _ = datastore.metric_integration(batch_labels=["labels1", "labels2"], metric="nmi")
\ No newline at end of file