Merge pull request #552 from Sichao25/neighbors
Update neighbors
Xiaojieqiu authored Oct 19, 2023
2 parents a662048 + aa36f2d commit a991f52
Showing 12 changed files with 294 additions and 320 deletions.
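Both files below follow the same pattern: hand-rolled pynndescent / scikit-learn `NearestNeighbors` branches are replaced with a single call to `k_nearest_neighbors` from `dynamo/tools/connectivity.py`. A rough usage sketch follows; the parameter names (`exclude_self`, `pynn_rand_state`, `return_nbrs`, `metric_kwads`, `query_X`) are taken from the call sites in this diff rather than from the helper's definition, so treat it as illustrative only.

```python
import numpy as np
from dynamo.tools.connectivity import k_nearest_neighbors

X = np.random.default_rng(0).normal(size=(1000, 2))  # e.g. a 2-D embedding

# Build an index on X and query X against itself. The call sites pass k as the
# number of neighbors excluding the query point and exclude_self=False, so
# asking for 30 neighbors (self included) means k=29.
knn, distances, nbrs, alg = k_nearest_neighbors(
    X,
    k=29,
    exclude_self=False,
    pynn_rand_state=19491001,  # random seed forwarded to the pynndescent backend
    n_jobs=-1,
    return_nbrs=True,          # also return the fitted index and the backend name
)
# `alg` tells downstream code which backend built the index: "pynn"
# (pynndescent), "hnswlib", or a scikit-learn tree otherwise (those else
# branches call kneighbors), so follow-up queries can be dispatched correctly.
```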
80 changes: 44 additions & 36 deletions dynamo/prediction/fate.py
@@ -16,6 +16,7 @@
main_info_insert_adata,
main_warning,
)
from ..tools.connectivity import correct_hnsw_neighbors, k_nearest_neighbors
from ..tools.utils import fetch_states, getTseq
from ..vectorfield import vector_field_function
from ..vectorfield.utils import vecfld_from_adata, vector_transformation
@@ -406,24 +407,17 @@ def fate_bias(

X = adata.obsm[basis_key] if basis_key != "X" else adata.X

if X.shape[0] > 5000 and X.shape[1] > 2:
alg = "NNDescent"
from pynndescent import NNDescent

nbrs = NNDescent(
X,
metric=metric,
metric_kwds=metric_kwds,
n_neighbors=30,
n_jobs=cores,
random_state=seed,
**kwargs,
)
knn, distances = nbrs.query(X, k=30)
else:
alg = "ball_tree" if X.shape[1] > 10 else "kd_tree"
nbrs = NearestNeighbors(n_neighbors=30, algorithm=alg, n_jobs=cores).fit(X)
distances, knn = nbrs.kneighbors(X)
knn, distances, nbrs, alg = k_nearest_neighbors(
X,
k=29,
metric=metric,
metric_kwads=metric_kwds,
exclude_self=False,
pynn_rand_state=seed,
return_nbrs=True,
n_jobs=cores,
**kwargs,
)

median_dist = np.median(distances[:, 1])

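Because `exclude_self=False`, column 0 of `distances` holds each cell's zero distance to itself, so `distances[:, 1]` is the distance to the closest other cell and `median_dist` is a typical nearest-neighbor spacing of the embedding. A minimal sketch of that convention, using plain scikit-learn as a stand-in for the helper's tree backend:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

X = np.random.default_rng(0).normal(size=(200, 2))
nbrs = NearestNeighbors(n_neighbors=30).fit(X)
distances, knn = nbrs.kneighbors(X)   # query the training set against itself

assert np.allclose(distances[:, 0], 0.0)   # column 0: each point matches itself
median_dist = np.median(distances[:, 1])   # typical nearest-neighbor spacing
```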
Expand Down Expand Up @@ -455,8 +449,13 @@ def fate_bias(
main_info("using all steps data")
indices = np.arange(0, n_steps)

if alg == "NNDescent":
if alg == "pynn":
knn, distances = nbrs.query(prediction[:, indices].T, k=30)
elif alg == "hnswlib":
knn, distances = nbrs.knn_query(prediction[:, indices].T, k=30)
if metric == "euclidean":
distances = np.sqrt(distances)
knn, distances = correct_hnsw_neighbors(knn, distances)
else:
distances, knn = nbrs.kneighbors(prediction[:, indices].T)

@@ -470,6 +469,9 @@
# cells with indices are all close to some random progenitor cells.
if hasattr(nbrs, "query"):
knn, _ = nbrs.query(X[knn.flatten(), :], k=30)
elif hasattr(nbrs, "knn_query"):
knn, distances_hn = nbrs.knn_query(X[knn.flatten(), :], k=30)
knn, _ = correct_hnsw_neighbors(knn, distances_hn)
else:
_, knn = nbrs.kneighbors(X[knn.flatten(), :])

@@ -496,6 +498,11 @@

if hasattr(nbrs, "query"):
knn, distances = nbrs.query(prediction[:, indices - 1].T, k=30)
elif hasattr(nbrs, "knn_query"):
knn, distances = nbrs.knn_query(prediction[:, indices - 1].T, k=30)
if metric == "euclidean":
distances = np.sqrt(distances)
knn, distances = correct_hnsw_neighbors(knn, distances)
else:
distances, knn = nbrs.kneighbors(prediction[:, indices - 1].T)

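The same three-way dispatch recurs at every query site in this file: pynndescent indices expose `query`, hnswlib indices expose `knn_query`, and the scikit-learn estimators expose `kneighbors`. hnswlib's "l2" space returns squared Euclidean distances, hence the `np.sqrt` before `correct_hnsw_neighbors` reconciles the result. A condensed sketch of the branching, mirroring the code above with a generic `query_points` placeholder:

```python
import numpy as np
from dynamo.tools.connectivity import correct_hnsw_neighbors

def query_knn(nbrs, query_points, k=30, metric="euclidean"):
    """Dispatch a kNN query to whichever backend built the index."""
    if hasattr(nbrs, "query"):        # pynndescent.NNDescent
        knn, distances = nbrs.query(query_points, k=k)
    elif hasattr(nbrs, "knn_query"):  # hnswlib.Index
        knn, distances = nbrs.knn_query(query_points, k=k)
        if metric == "euclidean":
            distances = np.sqrt(distances)  # hnswlib "l2" reports squared distances
        knn, distances = correct_hnsw_neighbors(knn, distances)
    else:                             # sklearn NearestNeighbors
        distances, knn = nbrs.kneighbors(query_points)
    return knn, distances
```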
@@ -594,22 +601,18 @@ def andecestor(
X = adata.obsm[basis_key].copy()

main_info("build a kNN graph structure so we can query the nearest cells of the predicted states.")
if X.shape[0] > 5000 and X.shape[1] > 2:
alg = "NNDescent"
from pynndescent import NNDescent

nbrs = NNDescent(
X,
metric=metric,
metric_kwds=metric_kwds,
n_neighbors=n_neighbors,
n_jobs=cores,
random_state=seed,
**kwargs,
)
else:
alg = "ball_tree" if X.shape[1] > 10 else "kd_tree"
nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm=alg, n_jobs=cores).fit(X)
_, _, nbrs, alg = k_nearest_neighbors(
X,
k=n_neighbors - 1,
metric=metric,
metric_kwads=metric_kwds,
exclude_self=False,
pynn_rand_state=seed,
n_jobs=cores,
return_nbrs=True,
logger=logger,
**kwargs,
)

if init_states is None:
init_states = adata[init_cells, :].obsm[basis_key]
@@ -639,8 +642,13 @@
last_indices = [0, -1] if direction == "both" else [-1]
queries = pred[j].T[last_indices] if last_point_only else pred[j].T

if alg == "NNDescent":
if alg == "pynn":
knn, distances = nbrs.query(queries, k=n_neighbors)
elif alg == "hnswlib":
knn, distances = nbrs.knn_query(queries, k=n_neighbors)
if metric == "euclidean":
distances = np.sqrt(distances)
knn, distances = correct_hnsw_neighbors(knn, distances)
else:
distances, knn = nbrs.kneighbors(queries)

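In `andecestor`, each predicted trajectory `pred[j]` is stored as dims × steps, so it is transposed before querying, and `last_point_only` restricts the query to the trajectory's endpoints: just the final state, or both the first and last states when `direction == "both"`. A small illustration of that index selection on a fake trajectory:

```python
import numpy as np

traj = np.random.default_rng(1).normal(size=(2, 50))  # dims x time steps, like pred[j]

direction, last_point_only = "both", True
last_indices = [0, -1] if direction == "both" else [-1]
queries = traj.T[last_indices] if last_point_only else traj.T

print(queries.shape)  # (2, 2): the first and the last predicted state
```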
91 changes: 28 additions & 63 deletions dynamo/tools/Markov.py
@@ -8,6 +8,7 @@
from sklearn.neighbors import NearestNeighbors
from tqdm import tqdm

from .connectivity import k_nearest_neighbors
from ..dynamo_logger import LoggerManager, main_warning
from ..simulation.utils import directMethod
from .utils import append_iterative_neighbor_indices, flatten
@@ -164,25 +165,16 @@ def makeTransitionMatrix(Qnn, I_vec, tol=0.0): # Qnn, I, tol=0.0
return M


@jit(nopython=True)

def compute_tau(X, V, k=100, nbr_idx=None):
if nbr_idx is None:
if X.shape[0] > 200000 and X.shape[1] > 2:
from pynndescent import NNDescent

nbrs = NNDescent(
X,
metric="euclidean",
n_neighbors=k,
n_jobs=-1,
random_state=19491001,
)
_, dist = nbrs.query(X, k=k)
else:
alg = "ball_tree" if X.shape[1] > 10 else "kd_tree"
nbrs = NearestNeighbors(n_neighbors=k, algorithm=alg, n_jobs=-1).fit(X)
dists, _ = nbrs.kneighbors(X)

_, dists = k_nearest_neighbors(
X,
k=k - 1,
exclude_self=False,
pynn_rand_state=19491001,
n_jobs=-1,
)
else:
dists = np.zeros(nbr_idx.shape)
for i in range(nbr_idx.shape[0]):
Expand Down Expand Up @@ -225,22 +217,13 @@ def prepare_velocity_grid_data(
if n_neighbors is None:
n_neighbors = np.max([10, int(n_obs / 50)])

if X_emb.shape[0] > 200000 and X_emb.shape[1] > 2:
from pynndescent import NNDescent

nn = NNDescent(
X_emb,
metric="euclidean",
n_neighbors=n_neighbors,
n_jobs=-1,
random_state=19491001,
)
neighs, dists = nn.query(X_grid, k=n_neighbors)
else:
alg = "ball_tree" if X_emb.shape[1] > 10 else "kd_tree"
nn = NearestNeighbors(n_neighbors=n_neighbors, n_jobs=-1, algorithm=alg)
nn.fit(X_emb)
dists, neighs = nn.kneighbors(X_grid)
neighs, dists = k_nearest_neighbors(
X_emb,
query_X=X_grid,
k=n_neighbors - 1,
exclude_self=False,
pynn_rand_state=19491001,
)

weight = norm.pdf(x=dists, scale=scale)
p_mass = weight.sum(1)
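Here the helper is used in its two-point-set form: the index is built on the cell embedding `X_emb`, but the returned neighbors are those of the grid points passed via `query_X`, and each grid point is then weighted by a Gaussian kernel over its neighbor distances (the `weight` / `p_mass` lines above). An equivalent sketch with plain scikit-learn, assuming `query_X` behaves like a separate query set:

```python
import numpy as np
from scipy.stats import norm
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
X_emb = rng.normal(size=(500, 2))            # cell embedding
X_grid = rng.uniform(-2, 2, size=(100, 2))   # regular grid points
n_neighbors, scale = 20, 0.5

nn = NearestNeighbors(n_neighbors=n_neighbors).fit(X_emb)
dists, neighs = nn.kneighbors(X_grid)        # neighbors of each grid point among the cells

weight = norm.pdf(x=dists, scale=scale)      # Gaussian kernel on the distances
p_mass = weight.sum(1)                       # total cell "mass" near each grid point
```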
Expand Down Expand Up @@ -390,21 +373,12 @@ def graphize_velocity(V, X, nbrs_idx=None, k=30, normalize_v=False, E_func=None)

nbrs = None
if nbrs_idx is None:
if n > 200000 and d > 2:
from pynndescent import NNDescent

nbrs = NNDescent(
X,
metric="euclidean",
n_neighbors=k + 1,
n_jobs=-1,
random_state=19491001,
)
nbrs_idx, _ = nbrs.query(X, k=k + 1)
else:
alg = "ball_tree" if d > 10 else "kd_tree"
nbrs = NearestNeighbors(n_neighbors=k + 1, algorithm=alg, n_jobs=-1).fit(X)
_, nbrs_idx = nbrs.kneighbors(X)
nbrs_idx, _ = k_nearest_neighbors(
X,
k=k,
exclude_self=False,
pynn_rand_state=19491001,
)

if type(E_func) is str:
if E_func == "sqrt":
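`graphize_velocity` only needs the neighbor index array, so `return_nbrs` is omitted and the fitted index is discarded. For orientation, an (n, k+1) index array like this converts to a sparse kNN adjacency in a few lines; this is purely illustrative and not what `graphize_velocity` itself computes from `V`:

```python
import numpy as np
import scipy.sparse as sp

n, k = 1000, 30
# Stand-in for the helper's output: each row lists a cell's k+1 neighbor indices
# (self included, since exclude_self=False).
nbrs_idx = np.random.default_rng(2).integers(0, n, size=(n, k + 1))

rows = np.repeat(np.arange(n), nbrs_idx.shape[1])
cols = nbrs_idx.ravel()
A = sp.csr_matrix((np.ones(rows.size), (rows, cols)), shape=(n, n))  # kNN adjacency
```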
Expand Down Expand Up @@ -594,21 +568,12 @@ def fit(
):
# compute connectivity
if neighbor_idx is None:
if X.shape[0] > 200000 and X.shape[1] > 2:
from pynndescent import NNDescent

nbrs = NNDescent(
X,
metric="euclidean",
n_neighbors=k,
n_jobs=-1,
random_state=19491001,
)
neighbor_idx, _ = nbrs.query(X, k=k)
else:
alg = "ball_tree" if X.shape[1] > 10 else "kd_tree"
nbrs = NearestNeighbors(n_neighbors=k, algorithm=alg, n_jobs=-1).fit(X)
_, neighbor_idx = nbrs.kneighbors(X)
neighbor_idx, _ = k_nearest_neighbors(
X,
k=k-1,
exclude_self=False,
pynn_rand_state=19491001,
)

if n_recurse_neighbors is not None:
self.Idx = append_iterative_neighbor_indices(neighbor_idx, n_recurse_neighbors)
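When `n_recurse_neighbors` is given, the neighbor indices are expanded by following neighbor-of-neighbor links before the transition matrix is fit. The exact behavior lives in `append_iterative_neighbor_indices` (imported from `.utils`); the sketch below only illustrates the general idea of recursive neighborhood expansion and is not that function's implementation:

```python
import numpy as np

def expand_neighbors(nbr_idx: np.ndarray, n_recurse: int) -> list:
    """Illustrative only: after n_recurse rounds, cell i's neighborhood holds
    every cell reachable within n_recurse kNN hops of cell i."""
    neighborhoods = [set(row) for row in nbr_idx]
    for _ in range(n_recurse - 1):
        neighborhoods = [
            hood | set().union(*(nbr_idx[j] for j in hood)) for hood in neighborhoods
        ]
    return [np.fromiter(hood, dtype=int) for hood in neighborhoods]
```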
