working draft version of doublet detection
hi-ilkin committed Jan 5, 2024
1 parent 4fa7ce1 commit 20e0c96
Showing 2 changed files with 154 additions and 2 deletions.
45 changes: 43 additions & 2 deletions scarf/datastore/datastore.py
@@ -1,13 +1,16 @@
from typing import Iterable, Optional, Union, List, Literal, Tuple

import numpy as np
import pandas as pd
from dask import array as daskarr
from loguru import logger

from .mapping_datastore import MappingDatastore
from .. import doublet_scoring
from ..assay import RNAassay, ATACassay
from ..feat_utils import hto_demux
from ..utils import tqdmbar, controlled_compute, ZARRLOC
from ..writers import create_zarr_obj_array, create_zarr_dataset

__all__ = ["DataStore"]

@@ -1208,6 +1211,44 @@ def smart_label(
else:
self.cells.insert(new_col_name, ret_val, overwrite=True)

    def predict_doublet_score(
        self, sim_doublet_ratio: float = 2, smoothen_iteration: int = 2
    ):
        """Assigns a doublet score to each valid cell (working draft).

        Doublets are simulated by summing the raw counts of randomly paired
        cells. The simulated doublets are saved as a separate DataStore and
        mapped back onto this dataset; the resulting mapping scores are then
        smoothed over the cell-neighbourhood graph to yield the per-cell
        doublet scores.

        Args:
            sim_doublet_ratio: Number of simulated doublets, expressed as a
                multiple of the number of valid cells. Default: 2.
            smoothen_iteration: Number of smoothing iterations applied to the
                mapping scores. Default: 2.
        """
        idx = doublet_scoring.get_simulated_pair_idx(
            self.cells.fetch_all("I").sum(), sim_doublet_ratio=sim_doublet_ratio
        )
        simulated_doublets = doublet_scoring.get_simulated_doublets(self, idx)

        # TODO: what should the sim_dset_path be?
        doublet_scoring.save_sim_doublets(
            simulated_doublets, assay=self.RNA, idx=idx, sim_dset_path="workspace"
        )
        sim_ds = DataStore("workspace")
        doublet_scoring.process_sim_ds(sim_ds)
        # Map the simulated doublets onto this dataset's cells.
        self.run_mapping(
            target_assay=sim_ds.RNA,
            target_name="sim_ds",
            target_feat_key="hvgs_ctrl",
            save_k=11,
            run_coral=False,
        )

        # TODO: is this the correct way of fetching?
        _, ms = next(
            self.get_mapping_score(
                target_name="sim_ds", log_transform=False, multiplier=1e4
            )
        )

        # Smooth the mapping scores over the neighbourhood graph.
        # NOTE: the hardcoded 11 assumes the loaded graph stores 11
        # neighbours per cell.
        graph = self.load_graph()
        doublet_scores = doublet_scoring.average_signal_by_neighbour(
            graph.indices.reshape(graph.shape[0], 11).astype("int64"),
            graph.data.reshape(graph.shape[0], 11).astype("float64"),
            ms.astype("float64"),
            t=smoothen_iteration,
        )
        self.cells.insert("doublet_scores", doublet_scores, overwrite=True)

def plot_cells_dists(
self,
from_assay: Optional[str] = None,
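
A minimal usage sketch for the new method, assuming `ds` is a DataStore whose RNA assay has already gone through the usual Scarf preprocessing (cell filtering, normalisation, feature selection and graph construction); "my_dataset.zarr" is a hypothetical path:

import scarf

ds = scarf.DataStore("my_dataset.zarr")  # hypothetical, pre-processed dataset
ds.predict_doublet_score(sim_doublet_ratio=2, smoothen_iteration=2)
# Scores land in the cell metadata table under "doublet_scores"
scores = ds.cells.fetch_all("doublet_scores")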
111 changes: 111 additions & 0 deletions scarf/doublet_scoring.py
@@ -0,0 +1,111 @@
import numpy as np
import pandas as pd
import zarr
from numba import jit

from .utils import tqdmbar
from .writers import create_zarr_count_assay, create_cell_data


def get_simulated_pair_idx(
    n_obs: int,
    random_seed: int = 1,
    sim_doublet_ratio: float = 2.0,
    cell_range: int = 100,
):
    """Samples pairs of cell indices to be merged into simulated doublets.

    Args:
        n_obs: Number of cells in the source dataset.
        random_seed: Seed for the random number generator. Default: 1.
        sim_doublet_ratio: Number of simulated doublets, expressed as a
            multiple of n_obs. Default: 2.0.
        cell_range: The second cell of each pair is drawn from within this
            many positions of the first. Default: 100.

    Returns:
        A (n_obs * sim_doublet_ratio, 2) integer array of cell-index pairs,
        sorted by the first and then the second column.
    """
    n_sim = int(n_obs * sim_doublet_ratio)

    rng_state = np.random.RandomState(random_seed)
    pair_1 = rng_state.randint(0, n_obs, size=n_sim)
    # If we assume that the order of cells (indices) in the source dataset is
    # truly random, i.e. i.i.d., then it does not matter what the search
    # space for the second cell of each pair is.
    pair_2 = np.array(
        [
            rng_state.randint(
                max(0, x - cell_range), min(n_obs - 1, x + cell_range), size=1
            )[0]
            for x in pair_1
        ]
    )

    idx = np.array([pair_1, pair_2]).T
    idx = pd.DataFrame(idx).sort_values(by=[0, 1]).values

    return idx
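
# A quick sanity check of the pair sampler (a sketch, using the defaults
# above): with n_obs=1000 and sim_doublet_ratio=2.0 the function returns a
# (2000, 2) array whose first column is non-decreasing.
#
#     idx = get_simulated_pair_idx(1000)
#     assert idx.shape == (2000, 2)
#     assert (idx[:-1, 0] <= idx[1:, 0]).all()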


def get_simulated_doublets(ds, indexes: np.ndarray):
    """Creates simulated doublets by summing the raw counts of each cell pair."""
    return ds.RNA.rawData[indexes[:, 0]] + ds.RNA.rawData[indexes[:, 1]]


def save_sim_doublets(
    simulated_doublets, assay, idx, sim_dset_path: str, rechunk: bool = False
) -> None:
    """Writes the simulated doublets as a new count assay under sim_dset_path.

    The Zarr hierarchy mirrors the source assay's feature table so that the
    simulated dataset can later be loaded as a regular DataStore.
    """
    zarr_path = zarr.open(sim_dset_path, "w")

    g = create_zarr_count_assay(
        z=zarr_path,
        assay_name=assay.name,
        workspace=None,
        chunk_size=assay.rawData.chunksize,
        n_cells=simulated_doublets.shape[0],
        feat_ids=assay.feats.fetch_all("ids"),
        feat_names=assay.feats.fetch_all("names"),
    )
    # Cell IDs encode the indices of the two source cells of each doublet.
    sim_cell_ids = np.array([f"Sim_{x[0]}-{x[1]}" for x in idx]).astype(object)
    create_cell_data(z=zarr_path, workspace=None, ids=sim_cell_ids, names=sim_cell_ids)
    if rechunk:
        simulated_doublets = simulated_doublets.rechunk(
            1000, simulated_doublets.shape[1]
        )

    compute_write(simulated_doublets, g)


# TODO: is there a built-in scarf function to replace this?
def compute_write(simulated, zarr_array):
    """Computes the dask array block by block and writes it to the Zarr array.

    Blocks are accumulated into batches of more than 1000 rows before each
    write, to reduce the number of small writes.
    """
    s, e = 0, 0
    batch = None

    # TODO: do we need tqdm here?
    for i in tqdmbar(simulated.blocks, total=simulated.numblocks[0]):
        if batch is None:
            batch = i.compute()
        else:
            batch = np.vstack([batch, i.compute()])
        if len(batch) > 1000:
            # Flush the accumulated batch to the Zarr array.
            e += batch.shape[0]
            zarr_array[s:e] = batch
            batch = None
            s = e

    # Flush any remaining rows.
    if batch is not None:
        e += batch.shape[0]
        zarr_array[s:e] = batch

    assert e == simulated.shape[0]


@jit(cache=True, nopython=True)
def average_signal_by_neighbour(inds, data, signal, t: int):
    """Smooths a per-cell signal over the KNN graph for t iterations.

    In each iteration, a cell's value is replaced by the mean of its current
    value and the edge-weighted mean of its neighbours' values.
    """
    out = signal.copy()
    n = out.shape[0]
    for _ in range(t):
        temp = np.zeros(n)
        for i in range(n):
            neighbour_w_mean = (out[inds[i]] * data[i]).mean()
            temp[i] = (out[i] + neighbour_w_mean) / 2
        out = temp.copy()

    return out
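
# A worked toy example of the smoothing step (a sketch with made-up values):
# three cells, each with two neighbours and unit edge weights.
#
#     inds = np.array([[1, 2], [0, 2], [0, 1]], dtype="int64")
#     data = np.ones((3, 2))
#     signal = np.array([1.0, 0.0, 0.0])
#     average_signal_by_neighbour(inds, data, signal, t=1)
#     # cell 0: (1.0 + mean(0.0, 0.0)) / 2 = 0.5
#     # cells 1 and 2: (0.0 + mean(1.0, 0.0)) / 2 = 0.25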


def process_sim_ds(sim_ds):
    """Runs feature selection and graph construction on the simulated dataset."""
    sim_ds.mark_hvgs(min_cells=20, top_n=500, min_mean=-3, max_mean=2, max_var=6)

    sim_ds.make_graph(
        feat_key="hvgs", k=11, dims=15, n_centroids=100, local_connectivity=0.9
    )
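
The functions above can also be driven by hand. A minimal end-to-end sketch, assuming `ds` is an existing, pre-processed Scarf DataStore; "my_dataset.zarr" and "sim_doublets" are hypothetical paths. The new `predict_doublet_score` method wraps these same steps:

from scarf import DataStore, doublet_scoring

ds = DataStore("my_dataset.zarr")  # hypothetical source dataset
n_cells = ds.cells.fetch_all("I").sum()
idx = doublet_scoring.get_simulated_pair_idx(n_cells, sim_doublet_ratio=2.0)
sim = doublet_scoring.get_simulated_doublets(ds, idx)
doublet_scoring.save_sim_doublets(
    sim, assay=ds.RNA, idx=idx, sim_dset_path="sim_doublets"
)

sim_ds = DataStore("sim_doublets")
doublet_scoring.process_sim_ds(sim_ds)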
