Skip to content

Commit

Permalink
Merge pull request #351 from mims-harvard/scvi_integration
Browse files Browse the repository at this point in the history
Added scVI loader and model class
  • Loading branch information
amva13 authored Feb 16, 2025
2 parents a97865a + 2cdb0bc commit cd82d60
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 3 deletions.
20 changes: 20 additions & 0 deletions condaenv.r2arcb8f.requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
accelerate==0.33.0
cellxgene-census==1.15.0
datasets<2.20.0
dgl==1.1.3
evaluate==0.4.2
gget==0.28.4
moleculeace==3.0.0
pydantic==2.6.3
gget==0.28.4
pydantic==2.6.3
gget==0.28.4
pydantic==2.6.3
pytest==8.3.2
rdkit==2023.9.5
tiledbsoma==1.11.4
torch==2.1.1
torch_geometric==2.5.3
torchvision==0.16.1
transformers==4.43.4
yapf==0.40.2
5 changes: 3 additions & 2 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: tdc-conda-env
name: tdc-conda-env
channels:
- bioconda
- conda-forge
Expand All @@ -21,7 +21,7 @@ dependencies:
- tqdm=4.65.0
- pip:
- accelerate==0.33.0
- cellxgene-census==1.15.0
- cellxgene-census==1.15.0
- datasets<2.20.0
- dgl==1.1.3
- evaluate==0.4.2
Expand All @@ -34,6 +34,7 @@ dependencies:
- pydantic==2.6.3
- pytest==8.3.2
- rdkit==2023.9.5
- scvi-tools==1.2.0
- tiledbsoma==1.11.4
- torch==2.1.1
- torch_geometric==2.5.3
Expand Down
32 changes: 32 additions & 0 deletions tdc/model_server/model_loaders/scvi_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
class scVILoader():

def __init__(self):
pass

def load(self, census_version):
import requests
import os

scvi_url = f"https://cellxgene-contrib-public.s3.us-west-2.amazonaws.com/models/scvi/{census_version}/homo_sapiens/model.pt"
os.makedirs(os.path.join(os.getcwd(), 'scvi_model'), exist_ok=True)

output_path = os.path.join('scvi_model', 'model.pt')

try:
response = requests.get(scvi_url, verify=False)
if response.status_code == 404:
raise Exception(
'Census version not found, defaulting to version 2024-07-01'
)
except Exception as e:
print(e)
census_version = "2024-07-01"
scvi_url = f"https://cellxgene-contrib-public.s3.us-west-2.amazonaws.com/models/scvi/2024-07-01/homo_sapiens/model.pt"
response = requests.get(scvi_url, verify=False)

with open(output_path, "wb") as file:
file.write(response.content)

print(
f'scVI version {census_version} downloaded to {output_path} in current directory'
)
94 changes: 94 additions & 0 deletions tdc/model_server/models/scvi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import torch.nn as nn


class scVI(nn.Module):

def __init__(self):
import scvi as scvi_package

super().__init__()
self.model = None
self.var_names = None

def forward(self, adata):
try:
self.prepare_data(adata)

except Exception as error:
print("No var names found in SCVI reference vars")
print(f"adata.var.index must include some of {self.var_names}")

# See docstring for more information. Expect input data to have index
# ordered and dimensions per scVI specifications
self.force_data_match(adata)

# getting variational autoencoder
vae_q = self.model.load_query_data(adata, self.model)
vae_q.is_trained = True

return vae_q.get_latent_representation()

def load(self):
import scvi as scvi_package
import os

from tdc.multi_pred.anndata_dataset import DataLoader
from model_server.model_loaders import scvi_loader

if not os.path.isdir("scvi_model"):
loader = scvi_loader.scVILoader()
loader.load("2024-07-01")

# adata object needed for loading model
adata = DataLoader("scvi_test_dataset",
"./data",
dataset_names=["scvi_test_dataset"],
no_convert=True).adata

# See docstring for more information. Expect index
# ordered and dimensions per scVI specifications
self.force_data_match(adata)

#instantiate SCVI model (not callable, just used to get VAE)
self.model = scvi_package.model.SCVI.load('scvi_model', adata)
self.var_names = adata.var.index

print("loaded scVI:")
print(f"{self.model}")

return self.model

def prepare_data(self, adata):
import numpy as np

assert True in np.isin(adata.var.index, self.var_names)

adata.obs["batch"] = "unassigned"
self.model.prepare_query_anndata(adata, self.model)

def force_data_match(self, adata):
'''
Input data is expected to have index ordered and dimensions as per scVI specifications.
For more information visit:
https://huggingface.co/datasets/scvi-tools/DATASET-FOR-UNIT-TESTING-1/tree/main
'''
import torch
import numpy as np

metadata = torch.load("scvi_model/model.pt",
map_location=torch.device('cpu'))

# setting indices that match
adata.var.index = metadata["attr_dict"]["registry_"][
"field_registries"]["X"]["state_registry"]["column_names"]

# Padding X so dimensions match. Need 8000 columns because scVI was trained using adata
# containing 8000 genes. This is the number of var indices extracted from metadata above.
additional_columns = np.zeros(
(adata.X.shape[0], 8000 - adata.X.shape[1]))
adata.X = np.hstack([adata.X, additional_columns])

# getting a batch name that matches
adata.obs['batch'] = metadata["attr_dict"]["registry_"][
"field_registries"]["batch"]["state_registry"][
"categorical_mapping"][0]
7 changes: 6 additions & 1 deletion tdc/model_server/tdc_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
'CYP3A4_Veith-AttentiveFP',
]

model_hub = ["Geneformer", "scGPT"]
model_hub = ["Geneformer", "scGPT", "scVI"]


class tdc_hf_interface:
Expand Down Expand Up @@ -66,6 +66,11 @@ def load(self):
AutoModel.register(ScGPTConfig, ScGPTModel)
model = AutoModel.from_pretrained("tdc/scGPT")
return model
elif self.model_name == "scVI":
from .models.scvi import scVI
model = scVI()
model.load()
return model
raise Exception("Not implemented yet!")

def load_deeppurpose(self, save_path):
Expand Down
14 changes: 14 additions & 0 deletions tdc/test/test_model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,20 @@ def testGeneformerTokenizer(self):
"Geneformer ran sucessfully. Find batch embedding example here:\n {}"
.format(out[0]))

def testscVI(self):
from tdc.multi_pred.anndata_dataset import DataLoader
from tdc import tdc_hf_interface

adata = DataLoader("scvi_test_dataset",
"./data",
dataset_names=["scvi_test_dataset"],
no_convert=True).adata

scvi = tdc_hf_interface("scVI")
model = scvi.load()
output = model(adata)
print(f"scVI ran successfully. here is an ouput: {output}")

def tearDown(self):
try:
print(os.getcwd())
Expand Down

0 comments on commit cd82d60

Please sign in to comment.