diff --git a/scarf/datastore.py b/scarf/datastore.py index 572b4bd..85ed6a4 100644 --- a/scarf/datastore.py +++ b/scarf/datastore.py @@ -1425,6 +1425,7 @@ def load_graph( symmetric: bool, upper_only: bool, use_k: int = None, + graph_loc: str = None, ) -> csr_matrix: """ Load the cell neighbourhood as a scipy sparse matrix @@ -1438,6 +1439,8 @@ def load_graph( used when symmetric is True. use_k: Number of top k-nearest neighbours to keep in the graph. This value must be greater than 0 and less the parameter k used. By default all neighbours are used. (Default value: None) + graph_loc: Zarr hierarchy where the graph is stored. If no value is provided then graph location is + obtained from `_get_latest_graph_loc` method. Returns: A scipy sparse matrix representing cell neighbourhood graph. @@ -1451,7 +1454,8 @@ def symmetrize(g): from scipy.sparse import triu - graph_loc = self._get_latest_graph_loc(from_assay, cell_key, feat_key) + if graph_loc is None: + graph_loc = self._get_latest_graph_loc(from_assay, cell_key, feat_key) if graph_loc not in self.z: raise ValueError( f"{graph_loc} not found in zarr location {self._fn}. " @@ -1621,6 +1625,10 @@ def run_umap( repulsion_strength: float = 1.0, initial_alpha: float = 1.0, negative_sample_rate: float = 5, + use_density_map: bool = False, + dens_lambda: float = 2.0, + dens_frac: float = 0.3, + dens_var_shift: float = 0.1, random_seed: int = 4444, label="UMAP", parallel: bool = False, @@ -1659,6 +1667,10 @@ def run_umap( select per positive sample in the optimization process. Increasing this value will result in greater repulsive force being applied, greater optimization cost, but slightly more accuracy. (Default value: 5) + use_density_map: + dens_lambda: + dens_frac: + dens_var_shift: random_seed: (Default value: 4444) label: base label for UMAP dimensions in the cell metadata column (Default value: 'UMAP') parallel: Whether to run UMAP in parallel mode. Setting value to True will use `nthreads` threads. 
def calc_dens_map_params(graph, dists):
    """
    Compute the per-vertex input-space terms used by the densmap objective.

    Every edge ``(j, k)`` of `graph` with weight ``mu`` adds ``mu`` to the weight
    total of both end points and ``mu * dists[j, k] ** 2`` to the radius total of
    both end points. The log of the weighted mean squared distance of each vertex
    is then standardised (zero mean, unit standard deviation) across vertices.

    Args:
        graph: Weighted graph in COO layout; its `shape`, `row`, `col` and `data`
            attributes are read.
        dists: Distance lookup that supports element-wise ``dists[rows, cols]``
            indexing with index arrays (a scipy sparse matrix or a dense array).
            Pairs not stored in a sparse `dists` are read as distance 0.

    Returns:
        A tuple ``(mu_sum, R)`` of float32 arrays of length ``graph.shape[0]``:
        the summed edge weight of every vertex and its standardised log radius.
        Vertices without any edge weight are treated as having radius 0, and `R`
        is all zeros when every vertex has the same radius, so the result never
        contains NaN or infinity for finite input.
    """
    import numpy as np

    n_vertices = graph.shape[0]
    head = np.asarray(graph.row)
    tail = np.asarray(graph.col)
    mu = np.asarray(graph.data, dtype=np.float64)

    if head.size:
        # One element-wise lookup for all edges instead of a scalar lookup per edge.
        pair_dists = dists[head, tail]
        if hasattr(pair_dists, "toarray"):
            pair_dists = pair_dists.toarray()
        pair_dists = np.asarray(pair_dists, dtype=np.float64).ravel()
    else:
        pair_dists = np.zeros(0, dtype=np.float64)

    # Squared distance to match the sq-Euclidean used for the embedding.
    weighted_sq = mu * pair_dists * pair_dists

    # Each edge contributes to both of its end points.
    mu_sum = np.bincount(head, weights=mu, minlength=n_vertices) + np.bincount(
        tail, weights=mu, minlength=n_vertices
    )
    ro = np.bincount(head, weights=weighted_sq, minlength=n_vertices) + np.bincount(
        tail, weights=weighted_sq, minlength=n_vertices
    )

    # A vertex with no edge weight would give 0/0 and poison the mean/std below
    # with NaN; leave its ratio at 0 so it falls back to log(epsilon).
    radius = np.divide(ro, mu_sum, out=np.zeros(n_vertices), where=mu_sum > 0)

    epsilon = 1e-8
    log_radius = np.log(epsilon + radius)
    centred = log_radius - np.mean(log_radius)
    spread = np.std(log_radius)
    # Identical radii leave nothing to scale; avoid dividing by a zero deviation.
    R = centred / spread if spread > 0 else np.zeros(n_vertices)

    return mu_sum.astype(np.float32), R.astype(np.float32)