Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Geneformer server #326

Merged
merged 27 commits into from
Oct 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ dependencies:
- torchvision==0.16.1
- transformers==4.43.4
- yapf==0.40.2
- git+https://github.com/amva13/geneformer.git@main#egg=geneformer
- git+https://huggingface.co/ctheodoris/Geneformer.git@main#egg=geneformer

variables:
KMP_DUPLICATE_LIB_OK: "TRUE"
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ tiledbsoma>=1.7.2,<2.0.0
yapf>=0.40.2,<1.0.0

# github packages
git+https://github.com/amva13/geneformer.git@main#egg=geneformer
git+https://huggingface.co/ctheodoris/Geneformer.git@main#egg=geneformer
12 changes: 6 additions & 6 deletions tdc/model_server/tokenizers/geneformer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import numpy as np
import scipy.sparse as sp

from geneformer import TranscriptomeTokenizer
from ...utils.load import pd_load, download_wrapper


class GeneformerTokenizer(TranscriptomeTokenizer):
class GeneformerTokenizer:
"""
Uses Geneformer Utils to parse zero-shot model server requests for tokenizing single-cell gene expression data.

Expand Down Expand Up @@ -53,7 +52,8 @@ def tokenize_cell_vectors(self,
cell_vector_adata,
target_sum=10_000,
chunk_size=512,
ensembl_id="ensembl_id"):
ensembl_id="ensembl_id",
ncounts="ncounts"):
"""
Tokenizing single-cell gene expression vectors formatted as anndata types.

Expand Down Expand Up @@ -96,16 +96,16 @@ def tokenize_cell_vectors(self,
for i in range(0, len(filter_pass_loc), chunk_size):
idx = filter_pass_loc[i:i + chunk_size]

n_counts = adata[idx].obs['ncounts'].values[:, None]
n_counts = adata[idx].obs[ncounts].values[:, None]
X_view = adata[idx, coding_miRNA_loc].X
X_norm = (X_view / n_counts * target_sum / norm_factor_vector)
X_norm = sp.csr_matrix(X_norm)

tokenized_cells += [
tokenized_cells.append([
self.rank_genes(X_norm[i].data,
coding_miRNA_tokens[X_norm[i].indices])
for i in range(X_norm.shape[0])
]
])

# add custom attributes for subview to dict
if self.custom_attr_name_dict is not None:
Expand Down
108 changes: 83 additions & 25 deletions tdc/test/test_model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import unittest
import shutil
import pytest
import mygene
import numpy as np

# temporary solution for relative imports in case TDC is not installed
# if TDC is installed, no need to use the following line
Expand All @@ -19,9 +21,17 @@
import requests


def get_ensembl_id(gene_symbols):
    """
    Look up Ensembl gene identifiers for human gene symbols via mygene.

    Args:
        gene_symbols: gene symbol(s) handed unchanged to
            ``MyGeneInfo.querymany``.

    Returns:
        The raw result of ``MyGeneInfo.querymany``; no post-processing is
        applied here, so callers must extract the ``ensembl.gene`` field
        from the returned hits themselves.
    """
    client = mygene.MyGeneInfo()
    # Match on gene symbol, request only the Ensembl gene field, and
    # restrict the search to human entries.
    query_options = {
        "scopes": "symbol",
        "fields": "ensembl.gene",
        "species": "human",
    }
    hits = client.querymany(gene_symbols, **query_options)
    return hits


def get_target_from_chembl(chembl_id):
# Query ChEMBL API for target information
chembl_url = f"https://www.ebi.ac.uk/chembl/api/data/target/{chembl_id}.json"
chembl_url = f"https://www.ebi.ac.uk/chembl/api/data/{chembl_id}.json"
response = requests.get(chembl_url)

if response.status_code == 200:
Expand Down Expand Up @@ -76,31 +86,79 @@ def setUp(self):
self.resource = cellxgene_census.CensusResource()

def testGeneformerTokenizer(self):
import anndata
from tdc.multi_pred.perturboutcome import PerturbOutcome
test_loader = PerturbOutcome(
name="scperturb_drug_AissaBenevolenskaya2021")
adata = test_loader.adata
print("swapping obs and var because scperturb violated convention...")
adata_flipped = anndata.AnnData(adata.X.T)
adata_flipped.obs = adata.var
adata_flipped.var = adata.obs
adata = adata_flipped
print("swap complete")
print("adding ensembl ids...")
adata.var["ensembl_id"] = adata.var["chembl-ID"].apply(
get_ensembl_id_from_chembl_id)
print("added ensembl_id column")

print(type(adata.var))
print(adata.var.columns)
print(type(adata.obs))
print(adata.obs.columns)
print("initializing tokenizer")

adata = self.resource.get_anndata(
var_value_filter=
"feature_id in ['ENSG00000161798', 'ENSG00000188229']",
obs_value_filter=
"sex == 'female' and cell_type in ['microglial cell', 'neuron']",
column_names={
"obs": [
"assay", "cell_type", "tissue", "tissue_general",
"suspension_type", "disease"
]
},
)
tokenizer = GeneformerTokenizer()
print("testing tokenizer")
x = tokenizer.tokenize_cell_vectors(adata)
assert x
x = tokenizer.tokenize_cell_vectors(adata,
ensembl_id="feature_id",
ncounts="n_measured_vars")
assert x[0]

# test Geneformer can serve the request
cells, _ = x
assert cells, "FAILURE: cells false-like. Value is = {}".format(cells)
assert len(cells) > 0, "FAILURE: length of cells <= 0 {}".format(cells)
from tdc import tdc_hf_interface
import torch
geneformer = tdc_hf_interface("Geneformer")
model = geneformer.load()

# using very few genes for these test cases so expecting empties... let's pad...
for idx in range(len(cells)):
x = cells[idx]
for j in range(len(x)):
v = x[j]
if len(v) < 2:
out = None
for _ in range(2 - len(v)):
if out is None:
out = np.append(v, 0) # pad with 0
else:
out = np.append(out, 0)
cells[idx][j] = out
if len(cells[idx]) < 512: # batch size
array = cells[idx]
# Calculate how many rows need to be added
n_rows_to_add = 512 - len(array)

# Create a padding array with [0, 0] for the remaining rows
padding = np.tile([0, 0], (n_rows_to_add, 1))

# Concatenate the original array with the padding array
cells[idx] = np.vstack((array, padding))

input_tensor = torch.tensor(cells)
out = []
try:
ctr = 0 # stop after some passes to avoid failure
for batch in input_tensor:
# build an attention mask
attention_mask = torch.tensor(
[[x[0] != 0, x[1] != 0] for x in batch])
out.append(model(batch, attention_mask=attention_mask))
if ctr == 2:
break
ctr += 1
except Exception as e:
raise Exception(e)

assert out, "FAILURE: Geneformer output is false-like. Value = {}".format(
out)
assert len(
out
) == 3, "length not matching ctr+1: {} vs {}. output was \n {}".format(
len(out), ctr + 1, out)

def tearDown(self):
try:
Expand Down
Loading