Skip to content

Commit

Permalink
Merge branch 'main' into ori-2907-custom-dataloader-registry
Browse files Browse the repository at this point in the history
  • Loading branch information
ori-kron-wis authored Sep 26, 2024
2 parents e3831cb + 4199a47 commit e1837bd
Show file tree
Hide file tree
Showing 10 changed files with 130 additions and 15 deletions.
7 changes: 1 addition & 6 deletions .github/workflows/test_linux_resolution.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,7 @@ jobs:
matrix:
os: [ubuntu-latest]
python: ["3.10", "3.11", "3.12"]
install-flags:
[
"--prerelease if-necessary-or-explicit",
"--resolution lowest-direct",
"--resolution lowest",
]
install-flags: ["--prerelease if-necessary-or-explicit"]

name: integration

Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ repos:
)$
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.5
rev: v0.6.7
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
FROM nvidia/cuda:12.5.0-runtime-ubuntu22.04
FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04
FROM python:3.12 AS base

RUN pip install --no-cache-dir uv
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks
60 changes: 58 additions & 2 deletions docs/user_guide/background/counterfactual_prediction.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,61 @@
# Counterfactual prediction

:::{note}
This page is under construction.
Once we have trained a model to predict a variable of interest or a generative model to learn the data distribution, we are often interested in making predictions for new samples. However, predictions over test samples may not reveal exactly what the model has learned about how the input features relate to the target variable of interest. For example, we may want to answer the question: How would the model predict the expression levels of gene Y in cell Z if gene X is knocked out? Even if we do not have an data point corresponding to this scenario, we can instead perturb the input to see what the model reports.

:::{warning}
We are using the term "counterfactual prediction" here rather loosely. In particular, we are not following the rigorous definition of counterfactual prediction in the causality literature[^ref1]. While closely related in spirit, we are making counterfactual queries with statistical models to gain some insight into what the model has learned about the data distribution.
:::

:::{figure} figures/counterfactual_cartoon.svg
:align: center
:alt: Cartoon of the counterfactual prediction task across two conditions.
:class: img-fluid

Cartoon of the counterfactual prediction task across two conditions. This counterfactual prediction can be thought of as an interpolation of nearby points in the feature space originating from condition B.
:::

## Preliminaries

Suppose we have a trained model $f_\theta$ that takes in a data point $x$ (e.g., gene expression counts) and a condition $c$ (e.g., treatment group) and returns a prediction $\hat{y}$.
Each data point takes the form of a tuple $(x,c) \in \mathcal{D}$.
We can define a *counterfactual query* as a pair $(x,c')$ where $c' \neq c$,
and the respective model output as the *counterfactual prediction*, $\hat{y}' = f_\theta(x,c')$.

We separate $c$ here out from $x$ to make the counterfactual portion of the query explicit, but it can be thought of as another dimension of $x$.

## In-distribution vs. out-of-distribution

Since we are working with statistical models rather than causal models, we have to be careful when we can rely on counterfactual predictions. At a high level, if we assume the true function relating the features to the target is smooth, we can trust counterfactual predictions for queries that are similar to points in the training data.

Say we have a counterfactual query $(x,c')$, and we have data points in the training set $(x',c')$ (i.e., $\|x - x'\|$ is small).
If our model predicts the $y$ for $(x', c')$ well,
we can reasonably trust the counterfactual prediction for $(x,c')$.
Otherwise, if $(x,c')$ is very different from any point in the training data
with condition $c'$, we cannot make any guarantees about the accuracy of the counterfactual prediction.
Dimensionality reduction techniques or harmonization methods may help create more overlap between the features $x$ across the conditions, setting the stage for more reliable counterfactual predictions.

## Applications

The most direct application of counterfactual prediction in scvi-tools can be found in the `transform_batch` kwarg of the {func}`~scvi.model.SCVI.get_normalized_expression` function. In this case, we can pass in a counterfactual batch label to get a prediction of what the normalized expression would be for a cell if it were a member of that batch. This can be useful if one wants to compare cells across different batches in the gene space.

The described approach to counterfactual prediction has also been used in a variety of applications, including:
- characterizing cell-type-specific sample-level effects [^ref2]
- predicting chemical perturbation responses in different cell types [^ref2][^ref3]
- predicting infection/perturbation responses across species [^ref4]

For more details on how counterfactual prediction is used in another method implemented in scvi-tools, see the {doc}`/user_guide/models/mrvi`.

[^ref1]:
Judea Pearl. Causality. Cambridge university press, 2009.
[^ref2]:
Pierre Boyeau, Justin Hong, Adam Gayoso, Martin Kim, Jose L McFaline-Figueroa, Michael Jordan, Elham Azizi, Can Ergen, Nir Yosef (2024),
_Deep generative modeling of sample-level heterogeneity in single-cell genomics_,
[bioRxiv](https://doi.org/10.1101/2022.10.04.510898).
[^ref3]:
Mohammad Lotfollahi, Anna Klimovskaia Susmelj, Carlo De Donno, Leon Hetzel, Yuge Ji, Ignacio L Ibarra, Sanjay R Srivatsan, Mohsen Naghipourfar, Riza M Daza, Beth Martin, Jay Shendure, Jose L McFaline‐Figueroa, Pierre Boyeau, F Alexander Wolf, Nafissa Yakubova, Stephan Günnemann, Cole Trapnell, David Lopez‐Paz, Fabian J Theis (2023),
_Predicting cellular responses to complex perturbations in high‐throughput screens_,
[Molecular Systems Biology](https://doi.org/10.15252/msb.202211517).
[^ref4]:
Mohammad Lotfollahi, F Alexander Wolf, Fabian J Theis (2019),
_scGen predicts single-cell perturbation responses_,
[Nature Methods](https://doi.org/10.1038/s41592-019-0494-8).
60 changes: 60 additions & 0 deletions docs/user_guide/background/figures/counterfactual_cartoon.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ tutorials = [
"igraph",
"scikit-misc",
"scrublet",
"scib-metrics",
"scvi-tools[optional]",
"squidpy",
]
Expand Down Expand Up @@ -137,7 +138,7 @@ markers = [
src = ["src"]
line-length = 99
indent-width = 4
target-version = "py310"
target-version = "py312"

# Exclude a variety of commonly ignored directories.
exclude = [
Expand Down
2 changes: 1 addition & 1 deletion src/scvi/utils/_track.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from rich.console import Console
from rich.progress import track as track_base
from tqdm import tqdm as tqdm_base
from tqdm.auto import tqdm as tqdm_base

from scvi import settings

Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import shutil
from distutils.dir_util import copy_tree

import pytest
from distutils.dir_util import copy_tree

import scvi
from tests.data.utils import generic_setup_adata_manager
Expand Down
5 changes: 4 additions & 1 deletion tests/hub/test_hub_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,10 @@ def test_hub_model_large_training_adata(request, save_path):

@pytest.mark.private
def test_hub_model_create_repo_hf(save_path: str):
from huggingface_hub import delete_repo
from huggingface_hub import delete_repo, repo_exists

if repo_exists("scvi-tools/test-scvi-create"):
delete_repo("scvi-tools/test-scvi-create", token=os.environ["HF_API_TOKEN"])

hub_model = prep_scvi_hub_model(save_path)
hub_model.push_to_huggingface_hub(
Expand Down

0 comments on commit e1837bd

Please sign in to comment.