Skip to content

Commit

Permalink
Add RunWOT function
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanghao-njmu committed Sep 11, 2023
1 parent cb0b522 commit 709f116
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 22 deletions.
4 changes: 3 additions & 1 deletion .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@
^LICENSE\.md$
^SCExplorer$
^README$
^README.md$
^README.md$
^tmaps$
^src/.*[^.cpp]$
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
SCExplorer
docs
renv
renv.lock
renv.lock
tmaps
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ export(RunSlingshot)
export(RunSymphonyMap)
export(RunTriMap)
export(RunUMAP2)
export(RunWOT)
export(SankeyPlot)
export(Scanorama_integrate)
export(Seurat_integrate)
Expand Down
25 changes: 15 additions & 10 deletions R/SCP-analysis.R
Original file line number Diff line number Diff line change
Expand Up @@ -5340,7 +5340,7 @@ maxDepth <- function(x, depth = 0) {
#' @importFrom reticulate py_to_r
#' @export
check_python_element <- function(x, depth = maxDepth(x)) {
if (depth == 0 || !is.list(x)) {
if (depth == 0 || !is.list(x) || !inherits(x, "python.builtin.object")) {
if (inherits(x, "python.builtin.object")) {
x_r <- tryCatch(py_to_r(x), error = identity)
if (inherits(x_r, "error")) {
Expand Down Expand Up @@ -5702,27 +5702,32 @@ RunPalantir <- function(srt = NULL, assay_X = "RNA", slot_X = "counts", assay_la
#' @inheritParams RunSCVELO
#'
#' @examples
#' @return A \code{anndata} object.
#' @export
RunWOT <- function(srt = NULL, assay_X = "RNA", slot_X = "counts", assay_layers = c("spliced", "unspliced"), slot_layers = "counts",
adata = NULL, group_by = NULL,
time_field = "Time", growth_iters = 3L, tmap_out = "tmaps/tmap_out",
time_from = 1, time_to = NULL, get_coupling = FALSE, recalculate = FALSE,
time_from = NULL, time_to = NULL, get_coupling = FALSE, recalculate = FALSE,
palette = "Paired", palcolor = NULL,
show_plot = TRUE, dpi = 300, save = FALSE, dirpath = "./", fileprefix = "",
return_seurat = !is.null(srt)) {
check_Python("wot")
if (all(is.null(srt), is.null(adata))) {
stop("One of 'srt', 'adata' must be provided.")
}
if (is.null(group_by) && any(!is.null(early_group), !is.null(terminal_groups))) {
stop("'group_by' must be provided when early_group or terminal_groups provided.")
if (is.null(group_by)) {
stop("'group_by' must be provided.")
}
if (is.null(linear_reduction) && is.null(nonlinear_reduction)) {
stop("'linear_reduction' or 'nonlinear_reduction' must be provided at least one.")
if (is.null(time_field)) {
stop("'time_field' must be provided.")
}
if (is.null(early_cell) && is.null(early_group)) {
stop("'early_cell' or 'early_group' must be provided.")
if (is.null(time_from)) {
stop("'time_from' must be provided.")
}
if (isTRUE(get_coupling) && is.null(time_to)) {

}
warning("The 'get_coupling' paramter is only valid when 'time_to' is specified.")

args <- mget(names(formals()))
args <- lapply(args, function(x) {
if (is.numeric(x)) {
Expand Down Expand Up @@ -5755,7 +5760,7 @@ RunWOT <- function(srt = NULL, assay_X = "RNA", slot_X = "counts", assay_layers
args[["palette"]] <- palette_scp(levels(groups) %||% unique(groups), palette = palette, palcolor = palcolor)

SCP_analysis <- reticulate::import_from_path("SCP_analysis", path = system.file("python", package = "SCP", mustWork = TRUE), convert = TRUE)
adata <- do.call(SCP_analysis$WOT, args)
adata <- do.call(WOT, args)

if (isTRUE(return_seurat)) {
srt_out <- adata_to_srt(adata)
Expand Down
2 changes: 1 addition & 1 deletion R/SCP-workflow.R
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ check_srtList <- function(srtList, batch = NULL, assay = NULL,
}
}
if (status == "unknown") {
warning("Can not determine whether data ", i, " is log-normalized...\n", immediate. = TRUE)
warning("Can not determine whether data ", i, " is log-normalized...", immediate. = TRUE)
}
}
if (is.null(HVF)) {
Expand Down
27 changes: 22 additions & 5 deletions inst/python/SCP_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,14 +770,15 @@ def Palantir(adata=None, h5ad=None,group_by=None,palette=None,

def WOT(adata=None, h5ad=None, group_by=None,palette=None,
time_field = "Time", growth_iters = 3, tmap_out = "tmaps/tmap_out",
time_from = 1, time_to = None, get_coupling = False,recalculate=False,
time_from = None, time_to = None, get_coupling = False,recalculate=False,
show_plot=True, dpi=300, save=False, dirpath="./", fileprefix=""):
import matplotlib.pyplot as plt
import scanpy as sc
import numpy as np
import statistics
import pandas as pd
from math import hypot
import wot

import warnings
warnings.simplefilter("ignore", category=UserWarning)
Expand Down Expand Up @@ -810,6 +811,10 @@ def WOT(adata=None, h5ad=None, group_by=None,palette=None,
if group_by is None:
print("group_by must be provided.")
exit()

if time_field is None:
print("time_field must be provided.")
exit()

adata.obs[group_by] = adata.obs[group_by].astype(dtype = "category")
if pd.api.types.is_categorical_dtype(adata.obs[time_field]):
Expand All @@ -823,15 +828,19 @@ def WOT(adata=None, h5ad=None, group_by=None,palette=None,
adata.obs["time_field"] = adata.obs[time_field]

time_dict = dict(zip(adata.obs[time_field],adata.obs["time_field"]))
ot_model <- wot.ot.OTModel(adata, growth_iters = growth_iters, day_field = "time_field")
if time_from not in time_dict.keys():
print("'time_from' is incorrect")
exit()

ot_model = wot.ot.OTModel(adata, growth_iters = growth_iters, day_field = "time_field")

if recalculate is True:
ot_model.compute_all_transport_maps(tmap_out = tmap_out)
tmap_model = wot.tmap.TransportMapModel.from_directory(tmap_out)
else:
try:
tmap_model = wot.tmap.TransportMapModel.from_directory(tmap_out)
except ValueError:
except (FileNotFoundError, ValueError):
ot_model.compute_all_transport_maps(tmap_out = tmap_out)
tmap_model = wot.tmap.TransportMapModel.from_directory(tmap_out)

Expand All @@ -845,15 +854,23 @@ def WOT(adata=None, h5ad=None, group_by=None,palette=None,

trajectory_ds = tmap_model.trajectories(from_populations)
trajectory_df = pd.DataFrame(trajectory_ds.X, index=trajectory_ds.obs_names, columns=trajectory_ds.var_names)
adata.uns["trajectory_"+str(time_from)]= trajectory_df
adata.uns["trajectory_"+str(time_from)]= trajectory_df.reindex(adata.obs_names)

fates_ds = tmap_model.fates(from_populations)
fates_df = pd.DataFrame(fates_ds.X, index=fates_ds.obs_names, columns=fates_ds.var_names)
adata.uns["fates_"+str(time_from)]= fates_df
existing_rows = fates_df.index.tolist()
new_rows = list(set(adata.obs_names) - set(existing_rows))
new_df = pd.DataFrame(0, index=new_rows, columns=fates_df.columns)
fates_df = pd.concat([fates_df, new_df])
adata.uns["fates_"+str(time_from)]= fates_df.reindex(adata.obs_names)

# obs_list = wot.tmap.trajectory_trends_from_trajectory(trajectory_ds = trajectory_ds, expression_ds = adata)

if time_to is not None:
if time_to not in time_dict.keys():
print("'time_to' is incorrect")
exit()

to_populations = tmap_model.population_from_cell_sets(cell_sets, at_time = time_dict[time_to])
transition_table = tmap_model.transition_table(from_populations, to_populations)
transition_df = pd.DataFrame(transition_table.X, index=transition_table.obs_names, columns=transition_table.var_names)
Expand Down
5 changes: 1 addition & 4 deletions man/RunWOT.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 709f116

Please sign in to comment.