Skip to content

Commit

Permalink
make it opt-in
Browse files Browse the repository at this point in the history
  • Loading branch information
flying-sheep committed Jan 21, 2025
1 parent 849aa19 commit bc060df
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 49 deletions.
13 changes: 8 additions & 5 deletions src/scanpy/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,8 +643,8 @@ def _handle_transformer(
use_dense_distances = (
kwds["metric"] == "euclidean" and self._adata.n_obs < 8192
) or not knn
shortcut = transformer is None and (
use_dense_distances or self._adata.n_obs < 4096
shortcut = transformer in {"sklearn", "sklearn-pairwise"} or (
transformer is None and (use_dense_distances or self._adata.n_obs < 4096)
)

# Coerce `method` to 'gauss' or 'umap'
Expand All @@ -668,16 +668,19 @@ def _handle_transformer(
raise ValueError(msg)

# Coerce `transformer` to an instance
if shortcut or transformer == "sklearn":
if shortcut:
from sklearn.neighbors import KNeighborsTransformer

assert transformer in {None, "sklearn"}
assert transformer in {None, "sklearn", "sklearn-pairwise"}
n_neighbors = self._adata.n_obs - 1
if knn: # only obey n_neighbors arg if knn set
n_neighbors = min(n_neighbors, kwds["n_neighbors"])

# sklearn-pairwise is opt-in, because it takes more memory
transformer_cls = (
PairwiseDistancesTransformer if shortcut else KNeighborsTransformer
PairwiseDistancesTransformer
if transformer == "sklearn-pairwise"
else KNeighborsTransformer
)

transformer = transformer_cls(
Expand Down
2 changes: 1 addition & 1 deletion src/scanpy/neighbors/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

# These two are used with get_literal_vals elsewhere
_Method = Literal["umap", "gauss"]
_KnownTransformer = Literal["pynndescent", "sklearn", "rapids"]
_KnownTransformer = Literal["pynndescent", "sklearn", "sklearn-pairwise", "rapids"]

# sphinx-autodoc-typehints can’t transitively import types from `if TYPE_CHECKING` blocks,
# so these four need to be importable
Expand Down
Binary file removed tests/_data/neighbors_shortcut_ref.h5ad
Binary file not shown.
43 changes: 0 additions & 43 deletions tests/test_neighbors.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from __future__ import annotations

import warnings
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
import pytest
from anndata import AnnData
from scipy import sparse
from scipy.sparse import csr_matrix, issparse
from sklearn.neighbors import KNeighborsTransformer

Expand All @@ -21,9 +19,6 @@
from pytest_mock import MockerFixture


DATA_DIR = Path(__file__).parent / "_data"


# the input data
X = [[1, 0], [3, 0], [5, 6], [0, 4]]
n_neighbors = 3 # includes data points themselves
Expand Down Expand Up @@ -242,41 +237,3 @@ def test_restore_n_neighbors(neigh: Neighbors, conv):
ad.uns["neighbors"] = dict(connectivities=conv(neigh.connectivities))
neigh_restored = Neighbors(ad)
assert neigh_restored.n_neighbors == 1


def test_regression_shortcut(monkeypatch: pytest.MonkeyPatch):
    """Compare freshly computed neighbors against a stored reference file.

    The reference ``.h5ad`` file supplies the ``obsm`` input as well as the
    expected ``distances`` and ``connectivities`` matrices in ``obsp``.
    """
    from scanpy.neighbors._backends import pairwise

    # Turn on the pairwise backend's debug flag for the duration of this test.
    monkeypatch.setattr(pairwise, "_DEBUG", True)
    reference = sc.read_h5ad(DATA_DIR / "neighbors_shortcut_ref.h5ad")

    computed = AnnData(shape=(100, 5), obsm=reference.obsm)
    sc.pp.neighbors(
        computed, use_rep="normalized_X", random_state=0, n_neighbors=20
    )

    # Densify every matrix pair up front, then compare them in a fixed order.
    dense_pairs = [
        (computed.obsp[key].toarray(), reference.obsp[key].toarray())
        for key in ("distances", "connectivities")
    ]
    for actual, expected in dense_pairs:
        assert_allclose(actual, expected, rtol=1e-7, atol=1e-7)


def assert_allclose(
    a: np.ndarray, b: np.ndarray, *, rtol: float = 1e-7, atol: float = 0
) -> None:
    """Assert that two arrays are element-wise close, listing mismatching entries.

    Thin wrapper around :func:`numpy.testing.assert_allclose` that attaches,
    as the error message, the per-axis count of differing entries and one line
    per differing entry with its index and the value of ``a - b``.

    Parameters
    ----------
    a, b
        Arrays (or array-likes) to compare.
    rtol
        Relative tolerance, forwarded to the NumPy comparison functions.
    atol
        Absolute tolerance, forwarded to the NumPy comparison functions.

    Raises
    ------
    AssertionError
        If the arrays differ beyond the tolerances or their shapes mismatch.
    """
    a = np.asarray(a)
    b = np.asarray(b)

    if a.shape != b.shape:
        # No element-wise listing is possible; let NumPy decide (it accepts
        # scalars and reports shape mismatches as an AssertionError).
        np.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
        return

    diff = a - b
    # `np.testing.assert_allclose` treats NaNs in matching positions as equal,
    # so the listing has to do the same to stay consistent with the assertion.
    diff[np.isclose(a, b, rtol=rtol, atol=atol, equal_nan=True)] = 0
    diff = sparse.coo_matrix(diff)
    diff.eliminate_zeros()

    msg_nnz = f"{diff.getnnz(0)=}\n{diff.getnnz(1)=}"
    # `row`/`col` are available on every SciPy version, unlike `coords`.
    msg_elems = "\n".join(
        f"(a-b)[{i:2}, {j:2}] = {d:.8f}"
        for i, j, d in zip(diff.row, diff.col, diff.data)
    )

    np.testing.assert_allclose(
        a, b, rtol=rtol, atol=atol, err_msg=f"{msg_nnz}\n{msg_elems}"
    )

0 comments on commit bc060df

Please sign in to comment.