diff --git a/VERSION b/VERSION
index a4b9c5a..beb3a3f 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.28.8
+0.29.0
\ No newline at end of file
diff --git a/scarf/datastore/datastore.py b/scarf/datastore/datastore.py
index 0cc65dc..9e80b4a 100644
--- a/scarf/datastore/datastore.py
+++ b/scarf/datastore/datastore.py
@@ -1,13 +1,16 @@
 from typing import Iterable, Optional, Union, List, Literal, Tuple
+
 import numpy as np
 import pandas as pd
 from dask import array as daskarr
 from loguru import logger
+
 from .mapping_datastore import MappingDatastore
-from ..writers import create_zarr_obj_array, create_zarr_dataset
-from ..utils import tqdmbar, controlled_compute, ZARRLOC
+from .. import doublet_scoring
 from ..assay import RNAassay, ATACassay
 from ..feat_utils import hto_demux
+from ..utils import tqdmbar, controlled_compute, ZARRLOC
+from ..writers import create_zarr_obj_array, create_zarr_dataset
 
 __all__ = ["DataStore"]
 
@@ -1208,6 +1211,44 @@ def smart_label(
         else:
             self.cells.insert(new_col_name, ret_val, overwrite=True)
 
+    def predict_doublet_score(
+        self, sim_doublet_ratio: float = 2, smoothen_iteration: int = 2
+    ):
+        idx = doublet_scoring.get_simulated_pair_idx(
+            self.cells.fetch_all("I").sum(), sim_doublet_ratio=sim_doublet_ratio
+        )
+        simulated_doublets = doublet_scoring.get_simulated_doublets(self, idx)
+
+        # TODO: what should the sim_dset_path be?
+        doublet_scoring.save_sim_doublets(
+            simulated_doublets, assay=self.RNA, idx=idx, sim_dset_path="workspace"
+        )
+        sim_ds = DataStore("workspace")
+        doublet_scoring.process_sim_ds(sim_ds)
+        self.run_mapping(
+            target_assay=sim_ds.RNA,
+            target_name="sim_ds",
+            target_feat_key="hvgs_ctrl",
+            save_k=11,
+            run_coral=False,
+        )
+
+        # TODO: is this the correct way of fetching?
+        _, ms = next(
+            self.get_mapping_score(
+                target_name="sim_ds", log_transform=False, multiplier=1e4
+            )
+        )
+
+        graph = self.load_graph()
+        doublet_scores = doublet_scoring.average_signal_by_neighbour(
+            graph.indices.reshape(graph.shape[0], 11).astype("int64"),
+            graph.data.reshape(graph.shape[0], 11).astype("float64"),
+            ms.astype("float64"),
+            t=smoothen_iteration,
+        )
+        self.cells.insert("doublet_scores", doublet_scores, overwrite=True)
+
     def plot_cells_dists(
         self,
         from_assay: Optional[str] = None,
diff --git a/scarf/doublet_scoring.py b/scarf/doublet_scoring.py
new file mode 100644
index 0000000..c047a06
--- /dev/null
+++ b/scarf/doublet_scoring.py
@@ -0,0 +1,111 @@
+import numpy as np
+import pandas as pd
+import zarr
+from numba import jit
+
+from .utils import tqdmbar
+from .writers import create_zarr_count_assay, create_cell_data
+
+
+def get_simulated_pair_idx(
+    n_obs: int,
+    random_seed: int = 1,
+    sim_doublet_ratio: float = 2.0,
+    cell_range: int = 100,
+):
+    """
+    sim_doublet_ratio: Ratio of simulated doublets to observed cells. Default: 2.0
+    """
+    n_sim = int(n_obs * sim_doublet_ratio)
+
+    rng_state = np.random.RandomState(random_seed)
+    pair_1 = rng_state.randint(0, n_obs, size=n_sim)
+    # If we assume that the order of cells (indices) in the source dataset is truly random, i.e. iid,
+    # then it doesn't matter what the search space for the partner cell is.
+    pair_2 = np.array(
+        [
+            rng_state.randint(
+                max(0, x - cell_range), min(n_obs - 1, x + cell_range), size=1
+            )[0]
+            for x in pair_1
+        ]
+    )
+
+    idx = np.array([pair_1, pair_2]).T
+    idx = pd.DataFrame(idx).sort_values(by=[0, 1]).values
+
+    return idx
+
+
+def get_simulated_doublets(ds, indexes: np.ndarray):
+    return ds.RNA.rawData[indexes[:, 0]] + ds.RNA.rawData[indexes[:, 1]]
+
+
+def save_sim_doublets(
+    simulated_doublets, assay, idx, sim_dset_path: str, rechunk=False
+) -> None:
+    zarr_path = zarr.open(sim_dset_path, "w")
+
+    g = create_zarr_count_assay(
+        z=zarr_path,
+        assay_name=assay.name,
+        workspace=None,
+        chunk_size=assay.rawData.chunksize,
+        n_cells=simulated_doublets.shape[0],
+        feat_ids=assay.feats.fetch_all("ids"),
+        feat_names=assay.feats.fetch_all("names"),
+    )
+    sim_cell_ids = np.array([f"Sim_{x[0]}-{x[1]}" for x in idx]).astype(object)
+    create_cell_data(z=zarr_path, workspace=None, ids=sim_cell_ids, names=sim_cell_ids)
+    if rechunk:
+        simulated_doublets = simulated_doublets.rechunk(
+            1000, simulated_doublets.shape[1]
+        )
+
+    compute_write(simulated_doublets, g)
+
+
+# TODO: is there a built-in Scarf function to replace this?
+def compute_write(simulated, zarr_array):
+    s, e = 0, 0
+    batch = None
+
+    # TODO: do we need tqdm here?
+    for i in tqdmbar(simulated.blocks, total=simulated.numblocks[0]):
+        if batch is None:
+            batch = i.compute()
+        else:
+            batch = np.vstack([batch, i.compute()])
+        if len(batch) > 1000:
+            e += batch.shape[0]
+            zarr_array[s:e] = batch
+            batch = None
+            s = e
+
+    if batch is not None:
+        e += batch.shape[0]
+        zarr_array[s:e] = batch
+
+    assert e == simulated.shape[0]
+
+
+@jit(cache=True, nopython=True)
+def average_signal_by_neighbour(inds, data, signal, t: int):
+    out = signal.copy()
+    n = out.shape[0]
+    for _ in range(t):
+        temp = np.zeros(n)
+        for i in range(n):
+            neighbour_w_mean = (out[inds[i]] * data[i]).mean()
+            temp[i] = (out[i] + neighbour_w_mean) / 2
+        out = temp.copy()
+
+    return out
+
+
+def process_sim_ds(sim_ds):
+    sim_ds.mark_hvgs(min_cells=20, top_n=500, min_mean=-3, max_mean=2, max_var=6)
+
+    sim_ds.make_graph(
+        feat_key="hvgs", k=11, dims=15, n_centroids=100, local_connectivity=0.9
+    )
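
For reviewers, a minimal usage sketch of the new API. The dataset path "data.zarr" and the preprocessing parameter values below are illustrative assumptions, not part of this PR. Note that predict_doublet_score() currently hard-codes k=11 both in run_mapping(save_k=11) and when reshaping the neighbourhood graph, so the source graph must also be built with k=11:

import scarf

# Hypothetical dataset; any Zarr-backed DataStore with an RNA assay should work.
ds = scarf.DataStore("data.zarr")

# Standard Scarf preprocessing: HVG selection and KNN graph construction.
# k must be 11 to match the reshape inside predict_doublet_score.
ds.mark_hvgs(min_cells=20, top_n=500)
ds.make_graph(feat_key="hvgs", k=11, dims=15)

# Simulates doublets from random cell pairs, maps them onto the dataset via
# run_mapping, and smoothens the resulting mapping score over the KNN graph.
ds.predict_doublet_score(sim_doublet_ratio=2, smoothen_iteration=2)
print(ds.cells.fetch_all("doublet_scores"))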
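
The smoothing step in average_signal_by_neighbour iterates out[i] <- (out[i] + mean(out[neighbours(i)] * weights[i])) / 2 for t rounds. A tiny self-contained NumPy sketch of the same recurrence on a toy three-cell graph (all values made up for illustration):

import numpy as np

# Toy graph: each cell has 2 neighbours, uniform edge weights.
inds = np.array([[1, 2], [0, 2], [0, 1]])  # neighbour indices per cell
data = np.ones((3, 2))                     # edge weights per cell
signal = np.array([1.0, 0.0, 0.0])         # raw per-cell signal

out = signal.copy()
for _ in range(2):  # t = 2 smoothing iterations
    temp = np.zeros(out.shape[0])
    for i in range(out.shape[0]):
        neighbour_w_mean = (out[inds[i]] * data[i]).mean()
        temp[i] = (out[i] + neighbour_w_mean) / 2
    out = temp.copy()

print(out)  # the initial spike on cell 0 has diffused to its neighbours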