Skip to content

Commit

Permalink
added DensMAP code from UMAP
Browse files Browse the repository at this point in the history
  • Loading branch information
parashardhapola committed Jul 30, 2021
1 parent bc902c6 commit 177480f
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 1 deletion.
45 changes: 44 additions & 1 deletion scarf/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,6 +1425,7 @@ def load_graph(
symmetric: bool,
upper_only: bool,
use_k: int = None,
graph_loc: str = None,
) -> csr_matrix:
"""
Load the cell neighbourhood as a scipy sparse matrix
Expand All @@ -1438,6 +1439,8 @@ def load_graph(
used when symmetric is True.
use_k: Number of top k-nearest neighbours to keep in the graph. This value must be greater than 0 and less than
the parameter k used. By default all neighbours are used. (Default value: None)
graph_loc: Zarr hierarchy where the graph is stored. If no value is provided then graph location is
obtained from `_get_latest_graph_loc` method.
Returns:
A scipy sparse matrix representing cell neighbourhood graph.
Expand All @@ -1451,7 +1454,8 @@ def symmetrize(g):

from scipy.sparse import triu

graph_loc = self._get_latest_graph_loc(from_assay, cell_key, feat_key)
if graph_loc is None:
graph_loc = self._get_latest_graph_loc(from_assay, cell_key, feat_key)
if graph_loc not in self.z:
raise ValueError(
f"{graph_loc} not found in zarr location {self._fn}. "
Expand Down Expand Up @@ -1621,6 +1625,10 @@ def run_umap(
repulsion_strength: float = 1.0,
initial_alpha: float = 1.0,
negative_sample_rate: float = 5,
use_density_map: bool = False,
dens_lambda: float = 2.0,
dens_frac: float = 0.3,
dens_var_shift: float = 0.1,
random_seed: int = 4444,
label="UMAP",
parallel: bool = False,
Expand Down Expand Up @@ -1659,6 +1667,10 @@ def run_umap(
select per positive sample in the optimization process. Increasing this value will
result in greater repulsive force being applied, greater optimization cost, but
slightly more accuracy. (Default value: 5)
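use_density_map: Whether to run the densMAP variant of UMAP, which additionally tries to preserve the local density of the data in the embedding. (Default value: False)
dens_lambda: Weight of the density preservation term in the densMAP objective. Higher values favour density preservation over the UMAP objective. Only used when `use_density_map` is True. (Default value: 2.0)
dens_frac: Fraction of epochs (between 0 and 1) in which the density term is included in the objective. Only used when `use_density_map` is True. (Default value: 0.3)
dens_var_shift: Small constant added to the variance of the local radii in the embedding to keep the density term numerically stable. Only used when `use_density_map` is True. (Default value: 0.1)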
random_seed: (Default value: 4444)
label: base label for UMAP dimensions in the cell metadata column (Default value: 'UMAP')
parallel: Whether to run UMAP in parallel mode. Setting value to True will use `nthreads` threads.
Expand Down Expand Up @@ -1691,6 +1703,35 @@ def run_umap(
verbose = False
if get_log_level() <= 20:
verbose = True

if use_density_map:
graph_loc = self._get_latest_graph_loc(from_assay, cell_key, feat_key)
knn_loc = graph_loc.rsplit("/", 1)[0]
logger.trace(f"Loading KNN dists and indices from {knn_loc}")
dists = self.z[knn_loc].distances[:]
indices = self.z[knn_loc].indices[:]
dmat = csr_matrix(
(
dists.flatten(),
(
np.repeat(range(indices.shape[0]), indices.shape[1]),
indices.flatten(),
),
),
shape=(indices.shape[0], indices.shape[0]),
)
# dmat = dmat.maximum(dmat.transpose()).todok()
logger.trace(f"Created sparse KNN dists and indices")
densmap_kwds = {
"lambda": dens_lambda,
"frac": dens_frac,
"var_shift": dens_var_shift,
"n_neighbors": dists.shape[1],
"knn_dists": dmat,
}
else:
densmap_kwds = {}

t, a, b = fit_transform(
graph=graph.tocoo(),
ini_embed=ini_embed,
Expand All @@ -1701,6 +1742,7 @@ def run_umap(
repulsion_strength=repulsion_strength,
initial_alpha=initial_alpha,
negative_sample_rate=negative_sample_rate,
densmap_kwds=densmap_kwds,
parallel=parallel,
nthreads=nthreads,
verbose=verbose,
Expand Down Expand Up @@ -2735,6 +2777,7 @@ def run_unified_umap(
repulsion_strength=repulsion_strength,
initial_alpha=initial_alpha,
negative_sample_rate=negative_sample_rate,
densmap_kwds={},
parallel=parallel,
nthreads=nthreads,
verbose=verbose,
Expand Down
43 changes: 43 additions & 0 deletions scarf/umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,32 @@
__all__ = ["fit_transform"]


def calc_dens_map_params(graph, dists):
    """
    Calculate the per-vertex input-space statistics required by densMAP.

    For every vertex, the edge weights of all incident edges are summed up and
    the weighted mean of the squared edge distances (the 'local radius') is
    calculated. Each edge contributes to both of its end points. The local
    radii are then log transformed and standardized across all vertices.

    Args:
        graph: Neighbourhood graph in COO format. The `shape`, `row`, `col`
               and `data` attributes are used; `data` holds the edge weights.
        dists: Square matrix (dense array or scipy sparse matrix) that supports
               `dists[rows, cols]` lookups and provides the distance of every
               edge in `graph`. Entries missing from a sparse matrix are read
               as 0.

    Returns:
        A tuple (mu_sum, R) of two float32 arrays with one value per vertex.
        `mu_sum` is the summed weight of the incident edges and `R` the
        standardized log local radius. Vertices without any edge weight get a
        mean squared distance of 0 instead of NaN, and `R` is all zeros when
        the log radii have no variance.
    """
    import numpy as np

    n_vertices = graph.shape[0]
    head = np.asarray(graph.row)
    tail = np.asarray(graph.col)
    mu = np.asarray(graph.data, dtype=np.float64)

    if head.size > 0:
        # A single fancy-indexed lookup replaces one scalar lookup per edge,
        # which is very slow on scipy sparse matrices.
        edge_dists = dists[head, tail]
        if hasattr(edge_dists, "toarray"):
            edge_dists = edge_dists.toarray()
        edge_dists = np.asarray(edge_dists, dtype=np.float64).ravel()
    else:
        edge_dists = np.zeros(0, dtype=np.float64)
    # Squared distances to match the sq-Euclidean metric used for embedding
    weighted_sq = mu * edge_dists * edge_dists

    def per_vertex_total(values):
        # Every edge is accounted for at both of its end points
        total = np.bincount(head, weights=values, minlength=n_vertices)
        total += np.bincount(tail, weights=values, minlength=n_vertices)
        return total.astype(np.float32)

    mu_sum = per_vertex_total(mu)
    ro = per_vertex_total(weighted_sq)

    epsilon = 1e-8
    # Vertices without edge weight would yield 0/0; a single NaN here would
    # propagate through the mean and turn every value of R into NaN.
    mean_sq_dist = np.divide(ro, mu_sum, out=np.zeros_like(ro), where=mu_sum != 0)
    ro = np.log(epsilon + mean_sq_dist)

    spread = np.std(ro) if ro.size else 0.0
    if spread > 0:
        R = (ro - np.mean(ro)) / spread
    else:
        # No variance in the radii: standardization is undefined, use zeros
        R = np.zeros_like(ro)
    return mu_sum, R


def simplicial_set_embedding(
g,
embedding,
Expand All @@ -21,6 +47,7 @@ def simplicial_set_embedding(
gamma,
initial_alpha,
negative_sample_rate,
densmap_kwds,
parallel,
nthreads,
verbose,
Expand All @@ -47,6 +74,17 @@ def simplicial_set_embedding(
if numba.config.NUMBA_NUM_THREADS > nthreads:
numba.set_num_threads(nthreads)

if densmap_kwds != {}:
with threadpool_limits(limits=nthreads):
mu_sum, R = calc_dens_map_params(g, densmap_kwds["knn_dists"])
densmap_kwds["mu_sum"] = mu_sum
densmap_kwds["R"] = R
densmap_kwds["mu"] = g.data
densmap = True
logger.trace("calculated densmap params")
else:
densmap = False

# tqdm will be activated if https://github.com/lmcinnes/umap/pull/739
# is merged and when it is released
tqdm_params = dict(tqdm_params)
Expand All @@ -71,6 +109,8 @@ def simplicial_set_embedding(
negative_sample_rate,
parallel=parallel,
verbose=verbose,
densmap=densmap,
densmap_kwds=densmap_kwds,
# tqdm_kwds=tqdm_params,
)
return embedding
Expand All @@ -95,6 +135,7 @@ def fit_transform(
repulsion_strength,
initial_alpha,
negative_sample_rate,
densmap_kwds,
parallel,
nthreads,
verbose,
Expand All @@ -104,6 +145,7 @@ def fit_transform(
a, b = find_ab_params(spread, min_dist)
logger.trace("Found ab params")
# sym_graph = fuzzy_simplicial_set(graph, set_op_mix_ratio)

embedding = simplicial_set_embedding(
graph,
ini_embed,
Expand All @@ -114,6 +156,7 @@ def fit_transform(
repulsion_strength,
initial_alpha,
negative_sample_rate,
densmap_kwds,
parallel,
nthreads,
verbose,
Expand Down

0 comments on commit 177480f

Please sign in to comment.