forked from theislab/cellrank
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_external.py
104 lines (82 loc) · 3.51 KB
/
test_external.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from typing import Optional, Type

import numpy as np
import pandas as pd
import pytest
from anndata import AnnData
from scipy.sparse import csr_matrix

import cellrank.external as cre
from cellrank.kernels import ConnectivityKernel
def _wot_not_installed() -> bool:
    """Report whether the optional ``wot`` package is unavailable.

    Returns:
        ``True`` when ``import wot`` raises :class:`ImportError`,
        ``False`` when the import succeeds.
    """
    try:
        # The import is a pure availability probe; the module itself is not used.
        import wot  # noqa: F401
    except ImportError:
        return True
    else:
        return False
def _statot_not_installed() -> bool:
    """Report whether the optional ``statot`` package is unavailable.

    Returns:
        ``True`` when ``import statot`` raises :class:`ImportError`,
        ``False`` when the import succeeds.
    """
    try:
        # The import is a pure availability probe; the module itself is not used.
        import statot  # noqa: F401
    except ImportError:
        return True
    else:
        return False
# Skip marker for tests that require the optional ``wot`` package.
# ``_wot_not_installed()`` is called right here, so the condition is computed
# once, while this module is being imported.
wot_not_installed_skip = pytest.mark.skipif(
_wot_not_installed(), reason="WOT is not installed."
)
# Skip marker for tests that require the optional ``statot`` package.
# The condition is likewise computed once, at module import time.
statot_not_installed_skip = pytest.mark.skipif(
_statot_not_installed(), reason="statOT is not installed."
)
@statot_not_installed_skip
class TestOTKernel:
    """Tests for ``cre.kernels.StationaryOTKernel``.

    The whole class carries the ``statot_not_installed_skip`` marker, so it is
    skipped when ``statot`` cannot be imported.
    """

    @staticmethod
    def _make_kernel(adata: AnnData) -> "cre.kernels.StationaryOTKernel":
        """Build the kernel instance shared by the tests of this class.

        Cells whose ``obs["clusters"]`` value equals ``"Granule immature"`` are
        labelled ``"GI"`` in the terminal-state annotation; all other cells keep
        ``NaN``. The ``g`` argument is an all-ones float64 vector with one entry
        per observation.

        Args:
            adata: Annotated data object with a ``"clusters"`` column in ``obs``.

        Returns:
            A freshly constructed, not yet computed, kernel.
        """
        terminal_states = np.full((adata.n_obs,), fill_value=np.nan, dtype=object)
        ixs = np.where(adata.obs["clusters"] == "Granule immature")[0]
        terminal_states[ixs] = "GI"

        return cre.kernels.StationaryOTKernel(
            adata,
            terminal_states=pd.Series(terminal_states).astype("category"),
            g=np.ones((adata.n_obs,), dtype=np.float64),
        )

    def test_method_not_implemented(self, adata_large: AnnData):
        """Requesting ``method="unbal"`` raises ``NotImplementedError``."""
        ok = self._make_kernel(adata_large)

        with pytest.raises(
            NotImplementedError, match="Method `'unbal'` is not yet implemented."
        ):
            ok.compute_transition_matrix(1, 0.001, method="unbal")

    def test_no_terminal_states(self, adata_large: AnnData):
        """Constructing the kernel without ``terminal_states`` raises ``RuntimeError``."""
        with pytest.raises(RuntimeError, match="Unable to initialize the kernel."):
            cre.kernels.StationaryOTKernel(
                adata_large,
                g=np.ones((adata_large.n_obs,), dtype=np.float64),
            )

    def test_normal_run(self, adata_large: AnnData):
        """A plain run returns the kernel with a sparse, row-stochastic matrix."""
        ok = self._make_kernel(adata_large)
        ok = ok.compute_transition_matrix(1, 0.001)

        assert isinstance(ok, cre.kernels.StationaryOTKernel)
        assert isinstance(ok._transition_matrix, csr_matrix)
        # Every row of the transition matrix must sum to one.
        np.testing.assert_allclose(ok.transition_matrix.sum(1), 1.0)
        assert isinstance(ok.params, dict)

    @pytest.mark.parametrize("connectivity_kernel", (None, ConnectivityKernel))
    def test_compute_projection(
        self,
        adata_large: AnnData,
        connectivity_kernel: Optional[Type[ConnectivityKernel]],
    ):
        """Projection plotting works once combined with a connectivity kernel.

        Args:
            adata_large: Annotated data fixture.
            connectivity_kernel: Kernel *class* to combine with the OT kernel, or
                ``None`` to plot from the OT kernel alone, which is expected to
                raise ``RuntimeError``.
        """
        ok = self._make_kernel(adata_large)
        ok = ok.compute_transition_matrix(1, 0.001)

        if connectivity_kernel is not None:
            ck = connectivity_kernel(adata_large).compute_transition_matrix()
            combined_kernel = 0.9 * ok + 0.1 * ck
            combined_kernel.compute_transition_matrix()
            combined_kernel.plot_projection()
        else:
            with pytest.raises(RuntimeError, match=r"Unable to find connectivities"):
                ok.plot_projection()