From 709f11652442ab8c88fdfe19b28da22302a4e6c2 Mon Sep 17 00:00:00 2001 From: zhanghao-njmu <542370159@qq.com> Date: Mon, 11 Sep 2023 18:25:29 +0800 Subject: [PATCH] Add RunWOT function --- .Rbuildignore | 4 +++- .gitignore | 3 ++- NAMESPACE | 1 + R/SCP-analysis.R | 25 +++++++++++++++---------- R/SCP-workflow.R | 2 +- inst/python/SCP_analysis.py | 27 ++++++++++++++++++++++----- man/RunWOT.Rd | 5 +---- 7 files changed, 45 insertions(+), 22 deletions(-) diff --git a/.Rbuildignore b/.Rbuildignore index fb89475e..74c3bdb2 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -12,4 +12,6 @@ ^LICENSE\.md$ ^SCExplorer$ ^README$ -^README.md$ \ No newline at end of file +^README.md$ +^tmaps$ +^src/.*[^.cpp]$ \ No newline at end of file diff --git a/.gitignore b/.gitignore index 8a1a591d..f150a4e3 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,5 @@ SCExplorer docs renv -renv.lock \ No newline at end of file +renv.lock +tmaps \ No newline at end of file diff --git a/NAMESPACE b/NAMESPACE index 828eab44..62b6fd8b 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -107,6 +107,7 @@ export(RunSlingshot) export(RunSymphonyMap) export(RunTriMap) export(RunUMAP2) +export(RunWOT) export(SankeyPlot) export(Scanorama_integrate) export(Seurat_integrate) diff --git a/R/SCP-analysis.R b/R/SCP-analysis.R index e1bddc13..3ff16c6c 100644 --- a/R/SCP-analysis.R +++ b/R/SCP-analysis.R @@ -5340,7 +5340,7 @@ maxDepth <- function(x, depth = 0) { #' @importFrom reticulate py_to_r #' @export check_python_element <- function(x, depth = maxDepth(x)) { - if (depth == 0 || !is.list(x)) { + if (depth == 0 || !is.list(x) || !inherits(x, "python.builtin.object")) { if (inherits(x, "python.builtin.object")) { x_r <- tryCatch(py_to_r(x), error = identity) if (inherits(x_r, "error")) { @@ -5702,11 +5702,11 @@ RunPalantir <- function(srt = NULL, assay_X = "RNA", slot_X = "counts", assay_la #' @inheritParams RunSCVELO #' #' @examples -#' @return A \code{anndata} object. +#' @export RunWOT <- function(srt = NULL, assay_X = "RNA", slot_X = "counts", assay_layers = c("spliced", "unspliced"), slot_layers = "counts", adata = NULL, group_by = NULL, time_field = "Time", growth_iters = 3L, tmap_out = "tmaps/tmap_out", - time_from = 1, time_to = NULL, get_coupling = FALSE, recalculate = FALSE, + time_from = NULL, time_to = NULL, get_coupling = FALSE, recalculate = FALSE, palette = "Paired", palcolor = NULL, show_plot = TRUE, dpi = 300, save = FALSE, dirpath = "./", fileprefix = "", return_seurat = !is.null(srt)) { @@ -5714,15 +5714,20 @@ RunWOT <- function(srt = NULL, assay_X = "RNA", slot_X = "counts", assay_layers if (all(is.null(srt), is.null(adata))) { stop("One of 'srt', 'adata' must be provided.") } - if (is.null(group_by) && any(!is.null(early_group), !is.null(terminal_groups))) { - stop("'group_by' must be provided when early_group or terminal_groups provided.") + if (is.null(group_by)) { + stop("'group_by' must be provided.") } - if (is.null(linear_reduction) && is.null(nonlinear_reduction)) { - stop("'linear_reduction' or 'nonlinear_reduction' must be provided at least one.") + if (is.null(time_field)) { + stop("'time_field' must be provided.") } - if (is.null(early_cell) && is.null(early_group)) { - stop("'early_cell' or 'early_group' must be provided.") + if (is.null(time_from)) { + stop("'time_from' must be provided.") } + if (isTRUE(get_coupling) && is.null(time_to)) { + + } + warning("The 'get_coupling' paramter is only valid when 'time_to' is specified.") + args <- mget(names(formals())) args <- lapply(args, function(x) { if (is.numeric(x)) { @@ -5755,7 +5760,7 @@ RunWOT <- function(srt = NULL, assay_X = "RNA", slot_X = "counts", assay_layers args[["palette"]] <- palette_scp(levels(groups) %||% unique(groups), palette = palette, palcolor = palcolor) SCP_analysis <- reticulate::import_from_path("SCP_analysis", path = system.file("python", package = "SCP", mustWork = TRUE), convert = TRUE) - adata <- do.call(SCP_analysis$WOT, args) + adata <- do.call(WOT, args) if (isTRUE(return_seurat)) { srt_out <- adata_to_srt(adata) diff --git a/R/SCP-workflow.R b/R/SCP-workflow.R index e8829e5b..59e53879 100644 --- a/R/SCP-workflow.R +++ b/R/SCP-workflow.R @@ -189,7 +189,7 @@ check_srtList <- function(srtList, batch = NULL, assay = NULL, } } if (status == "unknown") { - warning("Can not determine whether data ", i, " is log-normalized...\n", immediate. = TRUE) + warning("Can not determine whether data ", i, " is log-normalized...", immediate. = TRUE) } } if (is.null(HVF)) { diff --git a/inst/python/SCP_analysis.py b/inst/python/SCP_analysis.py index 3bc30c3b..01cd6de3 100644 --- a/inst/python/SCP_analysis.py +++ b/inst/python/SCP_analysis.py @@ -770,7 +770,7 @@ def Palantir(adata=None, h5ad=None,group_by=None,palette=None, def WOT(adata=None, h5ad=None, group_by=None,palette=None, time_field = "Time", growth_iters = 3, tmap_out = "tmaps/tmap_out", - time_from = 1, time_to = None, get_coupling = False,recalculate=False, + time_from = None, time_to = None, get_coupling = False,recalculate=False, show_plot=True, dpi=300, save=False, dirpath="./", fileprefix=""): import matplotlib.pyplot as plt import scanpy as sc @@ -778,6 +778,7 @@ def WOT(adata=None, h5ad=None, group_by=None,palette=None, import statistics import pandas as pd from math import hypot + import wot import warnings warnings.simplefilter("ignore", category=UserWarning) @@ -810,6 +811,10 @@ def WOT(adata=None, h5ad=None, group_by=None,palette=None, if group_by is None: print("group_by must be provided.") exit() + + if time_field is None: + print("time_field must be provided.") + exit() adata.obs[group_by] = adata.obs[group_by].astype(dtype = "category") if pd.api.types.is_categorical_dtype(adata.obs[time_field]): @@ -823,7 +828,11 @@ def WOT(adata=None, h5ad=None, group_by=None,palette=None, adata.obs["time_field"] = adata.obs[time_field] time_dict = dict(zip(adata.obs[time_field],adata.obs["time_field"])) - ot_model <- wot.ot.OTModel(adata, growth_iters = growth_iters, day_field = "time_field") + if time_from not in time_dict.keys(): + print("'time_from' is incorrect") + exit() + + ot_model = wot.ot.OTModel(adata, growth_iters = growth_iters, day_field = "time_field") if recalculate is True: ot_model.compute_all_transport_maps(tmap_out = tmap_out) @@ -831,7 +840,7 @@ def WOT(adata=None, h5ad=None, group_by=None,palette=None, else: try: tmap_model = wot.tmap.TransportMapModel.from_directory(tmap_out) - except ValueError: + except (FileNotFoundError, ValueError): ot_model.compute_all_transport_maps(tmap_out = tmap_out) tmap_model = wot.tmap.TransportMapModel.from_directory(tmap_out) @@ -845,15 +854,23 @@ def WOT(adata=None, h5ad=None, group_by=None,palette=None, trajectory_ds = tmap_model.trajectories(from_populations) trajectory_df = pd.DataFrame(trajectory_ds.X, index=trajectory_ds.obs_names, columns=trajectory_ds.var_names) - adata.uns["trajectory_"+str(time_from)]= trajectory_df + adata.uns["trajectory_"+str(time_from)]= trajectory_df.reindex(adata.obs_names) fates_ds = tmap_model.fates(from_populations) fates_df = pd.DataFrame(fates_ds.X, index=fates_ds.obs_names, columns=fates_ds.var_names) - adata.uns["fates_"+str(time_from)]= fates_df + existing_rows = fates_df.index.tolist() + new_rows = list(set(adata.obs_names) - set(existing_rows)) + new_df = pd.DataFrame(0, index=new_rows, columns=fates_df.columns) + fates_df = pd.concat([fates_df, new_df]) + adata.uns["fates_"+str(time_from)]= fates_df.reindex(adata.obs_names) # obs_list = wot.tmap.trajectory_trends_from_trajectory(trajectory_ds = trajectory_ds, expression_ds = adata) if time_to is not None: + if time_to not in time_dict.keys(): + print("'time_to' is incorrect") + exit() + to_populations = tmap_model.population_from_cell_sets(cell_sets, at_time = time_dict[time_to]) transition_table = tmap_model.transition_table(from_populations, to_populations) transition_df = pd.DataFrame(transition_table.X, index=transition_table.obs_names, columns=transition_table.var_names) diff --git a/man/RunWOT.Rd b/man/RunWOT.Rd index c711be59..e9f3396a 100644 --- a/man/RunWOT.Rd +++ b/man/RunWOT.Rd @@ -15,7 +15,7 @@ RunWOT( time_field = "Time", growth_iters = 3L, tmap_out = "tmaps/tmap_out", - time_from = 1, + time_from = NULL, time_to = NULL, get_coupling = FALSE, recalculate = FALSE, @@ -46,9 +46,6 @@ RunWOT( \item{return_seurat}{} } -\value{ -A \code{anndata} object. -} \description{ Run WOT analysis }