From 5b6c772db5ca25fc6c44cc60c6ffe61bbc6469b0 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Wed, 8 Jan 2025 11:37:05 -0500 Subject: [PATCH 01/17] integrate scGPT vocab into the server --- .github/workflows/conda-tests.yml | 1 + tdc/metadata.py | 2 ++ tdc/model_server/tokenizers/scgpt.py | 4 ++++ tdc/test/test_model_server.py | 8 +++++++- 4 files changed, 14 insertions(+), 1 deletion(-) diff --git a/.github/workflows/conda-tests.yml b/.github/workflows/conda-tests.yml index a26d9712..d10c4ac2 100644 --- a/.github/workflows/conda-tests.yml +++ b/.github/workflows/conda-tests.yml @@ -48,6 +48,7 @@ jobs: echo "Creating Conda Environment from environment.yml" conda env create -f environment.yml conda activate tdc-conda-env + python run_tests.py tdc.test.test_model_server.TestModelServer.testscGPT python run_tests.py yapf --style=google -r -d tdc conda deactivate diff --git a/tdc/metadata.py b/tdc/metadata.py index 98368367..ad76117e 100644 --- a/tdc/metadata.py +++ b/tdc/metadata.py @@ -957,6 +957,7 @@ def get_task2category(): "evebio_pharmone_v1_target_doc": "tab", "evebio_pharmone_v1_target_table": "tab", "cellxgene_sample_small": "h5ad", + "scgpt_vocab": "json", } name2id = { @@ -1164,6 +1165,7 @@ def get_task2category(): "evebio_pharmone_v1_target_doc": 10741536, "evebio_pharmone_v1_target_table": 10741537, "cellxgene_sample_small": 10806522, + "scgpt_vocab": 10809431, } oracle2type = { diff --git a/tdc/model_server/tokenizers/scgpt.py b/tdc/model_server/tokenizers/scgpt.py index df472e93..c77f4273 100644 --- a/tdc/model_server/tokenizers/scgpt.py +++ b/tdc/model_server/tokenizers/scgpt.py @@ -1,6 +1,8 @@ import numpy as np from typing import List, Tuple +from ...utils.load import pd_load, download_wrapper + def tokenize_batch( data: np.ndarray, @@ -23,6 +25,7 @@ def tokenize_batch( Returns: list: A list of tuple (gene_names, counts) of non zero gene expressions. """ + vocab_map = download_wrapper("scgpt_vocab") if data.shape[1] != len(gene_ids): raise ValueError( f"Number of features in data ({data.shape[1]}) does not match " @@ -43,6 +46,7 @@ def tokenize_batch( values = np.insert(values, 0, 0) if return_pt: import torch + genes = torch.tensor([vocab_map.get(x,0) for x in genes], dtype=torch.int64) values = torch.from_numpy(values).float().to(torch.int64) tokenized_data.append((genes, values)) return tokenized_data diff --git a/tdc/test/test_model_server.py b/tdc/test/test_model_server.py index 3dcf1330..e5a7bf52 100644 --- a/tdc/test/test_model_server.py +++ b/tdc/test/test_model_server.py @@ -108,7 +108,13 @@ def testscGPT(self): ) # Convert to numpy array tokenized_data = tokenizer.tokenize_cell_vectors( adata.X.toarray(), gene_ids) - first_embed = model(tokenized_data[0][1]).last_hidden_state + mask = [x!=0 for x in tokenized_data[0][1]] + assert sum(mask)!=0, "FAILURE: mask is empty" + first_embed = model( + tokenized_data[0][1], + input_ids=tokenized_data[0][0], + attention_mask=mask).last_hidden_state + print(f"scgpt ran successfully. 
here is an output {first_embed}") self.assertEqual(first_embed.shape[0], len(tokenized_data[0][0])) def testGeneformerTokenizer(self): From feae845dcd134a3630dcc1db6f83df9346d5b4a2 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Wed, 8 Jan 2025 12:18:48 -0500 Subject: [PATCH 02/17] download_wrapper fix and yapf --- tdc/model_server/tokenizers/scgpt.py | 5 +++-- tdc/test/test_model_server.py | 11 +++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tdc/model_server/tokenizers/scgpt.py b/tdc/model_server/tokenizers/scgpt.py index c77f4273..f6806849 100644 --- a/tdc/model_server/tokenizers/scgpt.py +++ b/tdc/model_server/tokenizers/scgpt.py @@ -25,7 +25,7 @@ def tokenize_batch( Returns: list: A list of tuple (gene_names, counts) of non zero gene expressions. """ - vocab_map = download_wrapper("scgpt_vocab") + vocab_map = download_wrapper("scgpt_vocab", "./data", ["scgpt_vocab"]) if data.shape[1] != len(gene_ids): raise ValueError( f"Number of features in data ({data.shape[1]}) does not match " @@ -46,7 +46,8 @@ def tokenize_batch( values = np.insert(values, 0, 0) if return_pt: import torch - genes = torch.tensor([vocab_map.get(x,0) for x in genes], dtype=torch.int64) + genes = torch.tensor([vocab_map.get(x, 0) for x in genes], + dtype=torch.int64) values = torch.from_numpy(values).float().to(torch.int64) tokenized_data.append((genes, values)) return tokenized_data diff --git a/tdc/test/test_model_server.py b/tdc/test/test_model_server.py index e5a7bf52..1cad844a 100644 --- a/tdc/test/test_model_server.py +++ b/tdc/test/test_model_server.py @@ -108,12 +108,11 @@ def testscGPT(self): ) # Convert to numpy array tokenized_data = tokenizer.tokenize_cell_vectors( adata.X.toarray(), gene_ids) - mask = [x!=0 for x in tokenized_data[0][1]] - assert sum(mask)!=0, "FAILURE: mask is empty" - first_embed = model( - tokenized_data[0][1], - input_ids=tokenized_data[0][0], - attention_mask=mask).last_hidden_state + mask = [x != 0 for x in tokenized_data[0][1]] + assert sum(mask) != 0, "FAILURE: mask is empty" + first_embed = model(tokenized_data[0][1], + input_ids=tokenized_data[0][0], + attention_mask=mask).last_hidden_state print(f"scgpt ran successfully. here is an output {first_embed}") self.assertEqual(first_embed.shape[0], len(tokenized_data[0][0])) From 0f6768b274c36492d95db18bb9adda653dd1d41a Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Wed, 8 Jan 2025 16:23:57 -0500 Subject: [PATCH 03/17] mend --- tdc/model_server/tokenizers/scgpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tdc/model_server/tokenizers/scgpt.py b/tdc/model_server/tokenizers/scgpt.py index f6806849..5923e7e8 100644 --- a/tdc/model_server/tokenizers/scgpt.py +++ b/tdc/model_server/tokenizers/scgpt.py @@ -25,7 +25,7 @@ def tokenize_batch( Returns: list: A list of tuple (gene_names, counts) of non zero gene expressions. 
""" - vocab_map = download_wrapper("scgpt_vocab", "./data", ["scgpt_vocab"]) + vocab_map = pd_load("scgpt_vocab", "./data") if data.shape[1] != len(gene_ids): raise ValueError( f"Number of features in data ({data.shape[1]}) does not match " From 3ed83adaa5aaa0f5e67ba8e17cf0ce09c6cea216 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Thu, 9 Jan 2025 03:13:52 -0500 Subject: [PATCH 04/17] mend --- tdc/model_server/tokenizers/scgpt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tdc/model_server/tokenizers/scgpt.py b/tdc/model_server/tokenizers/scgpt.py index 5923e7e8..ea450d3a 100644 --- a/tdc/model_server/tokenizers/scgpt.py +++ b/tdc/model_server/tokenizers/scgpt.py @@ -25,6 +25,7 @@ def tokenize_batch( Returns: list: A list of tuple (gene_names, counts) of non zero gene expressions. """ + download_wrapper("scgpt_vocab", "./data", ["scgpt_vocab"]) vocab_map = pd_load("scgpt_vocab", "./data") if data.shape[1] != len(gene_ids): raise ValueError( From a41cd0e2706522807fcb94fe6a599cbcd23ffc27 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Thu, 9 Jan 2025 03:24:02 -0500 Subject: [PATCH 05/17] mend --- tdc/utils/load.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tdc/utils/load.py b/tdc/utils/load.py index 4b557a3c..7acd58f4 100644 --- a/tdc/utils/load.py +++ b/tdc/utils/load.py @@ -317,7 +317,10 @@ def pd_load(name, path): file_path = os.path.join(path, name + "." + name2type[name]) with open(file_path, 'r') as f: file_content = json.load(f) - maxlen = max(len(x) for x in file_content.values()) + try: + maxlen = max(len(x) for x in file_content.values()) + except: + return file_content for k, v in file_content.items(): r = maxlen - len(v) file_content[k] = v + [None] * r From c1d800c84c3e43b42086f5e8766d92a06e2bdb38 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Thu, 9 Jan 2025 09:40:46 -0500 Subject: [PATCH 06/17] only pass the input ids --- tdc/test/test_model_server.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tdc/test/test_model_server.py b/tdc/test/test_model_server.py index 1cad844a..62889249 100644 --- a/tdc/test/test_model_server.py +++ b/tdc/test/test_model_server.py @@ -110,11 +110,9 @@ def testscGPT(self): adata.X.toarray(), gene_ids) mask = [x != 0 for x in tokenized_data[0][1]] assert sum(mask) != 0, "FAILURE: mask is empty" - first_embed = model(tokenized_data[0][1], - input_ids=tokenized_data[0][0], + first_embed = model(tokenized_data[0][0], attention_mask=mask).last_hidden_state print(f"scgpt ran successfully. here is an output {first_embed}") - self.assertEqual(first_embed.shape[0], len(tokenized_data[0][0])) def testGeneformerTokenizer(self): From 6d6277bb7e84329ed46277b4dfd836482387340f Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Fri, 10 Jan 2025 07:48:07 -0500 Subject: [PATCH 07/17] scGPT new config for transformers implementation. test. --- tdc/test/test_model_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tdc/test/test_model_server.py b/tdc/test/test_model_server.py index 62889249..91cc7ddd 100644 --- a/tdc/test/test_model_server.py +++ b/tdc/test/test_model_server.py @@ -111,6 +111,7 @@ def testscGPT(self): mask = [x != 0 for x in tokenized_data[0][1]] assert sum(mask) != 0, "FAILURE: mask is empty" first_embed = model(tokenized_data[0][0], + tokenized_data[0][1], attention_mask=mask).last_hidden_state print(f"scgpt ran successfully. 
here is an output {first_embed}") From 6d5d5d9516d7a3f1daf482fe30d5fc0850e8184d Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Fri, 10 Jan 2025 08:37:04 -0500 Subject: [PATCH 08/17] add scgpt modeling code and registration code to model hub --- tdc/model_server/models/scgpt.py | 303 +++++++++++++++++++++++++++++++ tdc/model_server/tdc_hf.py | 5 +- 2 files changed, 307 insertions(+), 1 deletion(-) create mode 100644 tdc/model_server/models/scgpt.py diff --git a/tdc/model_server/models/scgpt.py b/tdc/model_server/models/scgpt.py new file mode 100644 index 00000000..91550924 --- /dev/null +++ b/tdc/model_server/models/scgpt.py @@ -0,0 +1,303 @@ +from transformers import PretrainedConfig, PreTrainedModel +import torch.nn as nn +import torch +import torch.nn.functional as F +from typing import Optional, Dict + +class ScGPTConfig(PretrainedConfig): + model_type = "scgpt" + + def __init__( + self, + vocab_size=60697, + embsize=512, + d_hid=512, + nlayers=12, + nhead=8, + max_seq_len=1536, + dropout=0.0, + pad_token_id=0, + use_fast_transformer=True, + input_emb_style="continuous", + cell_emb_style="cls", # output embedding vector with + norm_scheme="post", + explicit_zero_prob=False, + **kwargs + ): + self.vocab_size = vocab_size + self.embsize = embsize + self.d_hid = d_hid + self.nlayers = nlayers + self.nhead = nhead + self.max_seq_len = max_seq_len + self.dropout = dropout + self.pad_token_id = pad_token_id + self.use_fast_transformer = use_fast_transformer + if input_emb_style not in ["continuous"]: + raise ValueError(f"Invalid input_emb_style: {input_emb_style}. Only continuous embeddings currently supported.") + self.input_emb_style = input_emb_style + self.cell_emb_style = cell_emb_style + self.norm_scheme = norm_scheme + self.explicit_zero_prob = explicit_zero_prob + super().__init__(pad_token_id=pad_token_id, **kwargs) + +class ExprDecoder(nn.Module): + def __init__(self, d_model: int, explicit_zero_prob: bool = False): + super().__init__() + self.fc = nn.Sequential( + nn.Linear(d_model, d_model), # we don't use batch labels + nn.LeakyReLU(), + nn.Linear(d_model, d_model), + nn.LeakyReLU(), + nn.Linear(d_model, 1), + ) + self.explicit_zero_prob = explicit_zero_prob + if explicit_zero_prob: + self.zero_logit = nn.Sequential( + nn.Linear(d_model, d_model), + nn.LeakyReLU(), + nn.Linear(d_model, d_model), + nn.LeakyReLU(), + nn.Linear(d_model, 1), + ) + + def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: + pred_value = self.fc(x).squeeze(-1) + if not self.explicit_zero_prob: + return {"pred": pred_value} + zero_logits = self.zero_logit(x).squeeze(-1) + zero_probs = torch.sigmoid(zero_logits) + return {"pred": pred_value, "zero_probs": zero_probs} # TODO: what about inference / bernoulli? 
+ +class FlashTransformerEncoderLayer(nn.Module): + def __init__(self, d_model, nhead, dim_feedforward, dropout, norm_scheme="post"): + super().__init__() + from flash_attn.flash_attention import FlashMHA + + self.self_attn = FlashMHA( + embed_dim=d_model, + num_heads=nhead, + dropout=dropout, + attention_dropout=dropout, + ) + self.feed_forward = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model) + ) + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + self.norm_scheme = norm_scheme + +# Helper class to ensure we have the correct attention structure +class MultiheadAttentionWithBias(nn.Module): + def __init__(self, embed_dim, num_heads, dropout=0.0, batch_first=True): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.batch_first = batch_first + + # Combined input projections for Q, K, V + self.in_proj_weight = nn.Parameter(torch.empty((3 * embed_dim, embed_dim))) + self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim)) + + # Output projection + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + + self._reset_parameters() + + def _reset_parameters(self): + # Initialize parameters following PyTorch's MultiheadAttention initialization + nn.init.xavier_uniform_(self.in_proj_weight) + nn.init.xavier_uniform_(self.out_proj.weight) + nn.init.constant_(self.in_proj_bias, 0.) + nn.init.constant_(self.out_proj.bias, 0.) + + def forward(self, query, key, value, key_padding_mask=None): + return nn.functional.multi_head_attention_forward( + query, key, value, + self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + None, None, None, # No bias_k, bias_v, or add_zero_attn + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=False, + batch_first=self.batch_first + )[0] + +from transformers import PreTrainedModel +import torch.nn as nn + +class ScGPTPreTrainedModel(PreTrainedModel): + config_class = ScGPTConfig + base_model_prefix = "scgpt" + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=0.02) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=0.02) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + +class ScGPTModel(ScGPTPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + # Gene name embeddings remain the same + self.gene_encoder = nn.ModuleDict({ + "embedding": nn.Embedding( + config.vocab_size, + config.embsize, + padding_idx=config.pad_token_id + ), + "enc_norm": nn.LayerNorm(config.embsize) + }) + + # Value encoder remains the same + if config.input_emb_style == "continuous": + self.value_encoder = nn.ModuleDict({ + "linear1": nn.Linear(1, config.embsize), + "linear2": nn.Linear(config.embsize, config.embsize), + "norm": nn.LayerNorm(config.embsize), + "dropout": nn.Dropout(config.dropout) + }) + elif config.input_emb_style == "scaling": + self.value_encoder = nn.Identity() + raise Exception("scaling input embedding style not supported because this model was trained on continuous style") + else: + raise Exception("unsupported embedding style") + + # Modified transformer 
layers to use combined QKV projections + self.transformer = nn.ModuleDict({ + "layers": nn.ModuleList([ + nn.ModuleDict({ + "self_attn": MultiheadAttentionWithBias( + config.embsize, + config.nhead, + dropout=config.dropout, + batch_first=True + ), + "linear1": nn.Linear(config.embsize, config.d_hid), + "linear2": nn.Linear(config.d_hid, config.embsize), + "norm1": nn.LayerNorm(config.embsize), + "norm2": nn.LayerNorm(config.embsize), + }) for _ in range(config.nlayers) + ]) + }) + + # # Rather than combining qkv projections, mimicking gh implementation to match weights + # from torch.nn import TransformerEncoder, TransformerEncoderLayer + # self.transformer = TransformerEncoder( + # TransformerEncoderLayer( + # d_model=config.embsize, + # nhead=config.nhead, + # dim_feedforward=config.d_hid, + # dropout=config.dropout, + # batch_first=True, # just for replication + # ), + # num_layers=config.nlayers + # ) + + # Decoder remains the same + self.expr_decoder = ExprDecoder(config.embsize, config.explicit_zero_prob) + + # we ignore cls_decoder because we do not pursue classification task + # we also ignore mvc and similarity because we ignore generative tasks + + self.init_weights() + + def forward( + self, + input_ids: torch.Tensor, + values: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_cell_emb: bool = True, + ) -> Dict[str, torch.Tensor]: + """ + Args: + input_ids: Tensor of gene indices, shape [batch_size, seq_len] + values: Tensor of expression values, shape [batch_size, seq_len] + attention_mask: Optional mask tensor, shape [batch_size, seq_len] + output_cell_emb: Whether to output cell embeddings + + Returns: + Dictionary containing: + - 'pred': Predicted expression values + - 'cell_emb': Cell embeddings (if output_cell_emb=True) + - 'zero_probs': Zero probabilities (if config.explicit_zero_prob=True) + """ + # Gene embeddings + gene_emb = self.gene_encoder["embedding"](input_ids) + gene_emb = self.gene_encoder["enc_norm"](gene_emb) + + # Value encoding + if hasattr(self, 'value_encoder'): + values = values.unsqueeze(-1) # Add feature dimension + value_emb = self.value_encoder["linear1"](values) + value_emb = self.value_encoder["activation"](value_emb) + value_emb = self.value_encoder["linear2"](value_emb) + value_emb = self.value_encoder["norm"](value_emb) + value_emb = self.value_encoder["dropout"](value_emb) + + if self.config.input_emb_style == "continuous": + hidden_states = gene_emb + value_emb + else: # "scaling", currrently not supported + hidden_states = gene_emb * value_emb + else: + hidden_states = gene_emb + + # Convert attention_mask for transformer + # Flash attention expects mask of 0s for tokens to attend to and 1s for tokens to ignore + if self.use_flash_attention and attention_mask is not None: + if attention_mask.dtype != torch.bool: + attention_mask = attention_mask.bool() + attention_mask = ~attention_mask # we assume user follows huggingface convention for the attention mask + + # Apply transformer layers + if self.use_flash_attention: + for layer in self.transformer: + hidden_states = layer( + hidden_states, + src_key_padding_mask=attention_mask + ) + else: + hidden_states = self.transformer( + hidden_states, + src_key_padding_mask=attention_mask + ) + + # Get cell embeddings if requested + output_dict = {} + if output_cell_emb: + if self.config.cell_emb_style == "cls": + cell_emb = hidden_states[:, 0] + elif self.config.cell_emb_style == "avg-pool": + cell_emb = hidden_states.mean(dim=1) + else: # w-pool + # Weighted pooling using input 
values as weights + weights = F.softmax(values, dim=1).unsqueeze(-1) + cell_emb = (hidden_states * weights).sum(dim=1) + output_dict['cell_emb'] = cell_emb + + # Decode expression values + decoder_output = self.expr_decoder(hidden_states) + output_dict.update(decoder_output) + + return output_dict + + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, nn.Linear): + # Additional initialization for linear layers + if module.bias is not None: + nn.init.constant_(module.bias, 0) \ No newline at end of file diff --git a/tdc/model_server/tdc_hf.py b/tdc/model_server/tdc_hf.py index 57e25473..6571fbcd 100644 --- a/tdc/model_server/tdc_hf.py +++ b/tdc/model_server/tdc_hf.py @@ -60,7 +60,10 @@ def load(self): model = AutoModelForMaskedLM.from_pretrained("tdc/Geneformer") return model elif self.model_name == "scGPT": - from transformers import AutoModel + from transformers import AutoConfig, AutoModel + from .models.scgpt import ScGPTModel, ScGPTConfig + AutoConfig.register("scgpt", ScGPTConfig) + AutoModel.register(ScGPTConfig, ScGPTModel) model = AutoModel.from_pretrained("tdc/scGPT") return model raise Exception("Not implemented yet!") From af87a7816e194e616257855729b6c35d41a3da55 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Fri, 10 Jan 2025 11:55:28 -0500 Subject: [PATCH 09/17] mend --- tdc/model_server/models/scgpt.py | 57 ++++++++++++++-------------- tdc/model_server/tokenizers/scgpt.py | 2 +- 2 files changed, 29 insertions(+), 30 deletions(-) diff --git a/tdc/model_server/models/scgpt.py b/tdc/model_server/models/scgpt.py index 91550924..0c0953cd 100644 --- a/tdc/model_server/models/scgpt.py +++ b/tdc/model_server/models/scgpt.py @@ -178,35 +178,34 @@ def __init__(self, config): raise Exception("unsupported embedding style") # Modified transformer layers to use combined QKV projections - self.transformer = nn.ModuleDict({ - "layers": nn.ModuleList([ - nn.ModuleDict({ - "self_attn": MultiheadAttentionWithBias( - config.embsize, - config.nhead, - dropout=config.dropout, - batch_first=True - ), - "linear1": nn.Linear(config.embsize, config.d_hid), - "linear2": nn.Linear(config.d_hid, config.embsize), - "norm1": nn.LayerNorm(config.embsize), - "norm2": nn.LayerNorm(config.embsize), - }) for _ in range(config.nlayers) - ]) - }) - - # # Rather than combining qkv projections, mimicking gh implementation to match weights - # from torch.nn import TransformerEncoder, TransformerEncoderLayer - # self.transformer = TransformerEncoder( - # TransformerEncoderLayer( - # d_model=config.embsize, - # nhead=config.nhead, - # dim_feedforward=config.d_hid, - # dropout=config.dropout, - # batch_first=True, # just for replication - # ), - # num_layers=config.nlayers - # ) + # self.transformer = nn.ModuleDict({ + # "layers": nn.ModuleList([ + # nn.ModuleDict({ + # "self_attn": MultiheadAttentionWithBias( + # config.embsize, + # config.nhead, + # dropout=config.dropout, + # batch_first=True + # ), + # "linear1": nn.Linear(config.embsize, config.d_hid), + # "linear2": nn.Linear(config.d_hid, config.embsize), + # "norm1": nn.LayerNorm(config.embsize), + # "norm2": nn.LayerNorm(config.embsize), + # }) for _ in range(config.nlayers) + # ]) + # }) + + from torch.nn import TransformerEncoder, TransformerEncoderLayer + self.transformer = TransformerEncoder( + TransformerEncoderLayer( + d_model=config.embsize, + nhead=config.nhead, + dim_feedforward=config.d_hid, + dropout=config.dropout, + batch_first=True, # just for replication + ), + num_layers=config.nlayers + ) # Decoder 
remains the same self.expr_decoder = ExprDecoder(config.embsize, config.explicit_zero_prob) diff --git a/tdc/model_server/tokenizers/scgpt.py b/tdc/model_server/tokenizers/scgpt.py index ea450d3a..eee98a4f 100644 --- a/tdc/model_server/tokenizers/scgpt.py +++ b/tdc/model_server/tokenizers/scgpt.py @@ -49,7 +49,7 @@ def tokenize_batch( import torch genes = torch.tensor([vocab_map.get(x, 0) for x in genes], dtype=torch.int64) - values = torch.from_numpy(values).float().to(torch.int64) + values = torch.from_numpy(values).float() tokenized_data.append((genes, values)) return tokenized_data From d22f03389921081d1f8c9bfbb98f31ce3973ba38 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Fri, 10 Jan 2025 12:47:26 -0500 Subject: [PATCH 10/17] mend --- tdc/model_server/models/scgpt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tdc/model_server/models/scgpt.py b/tdc/model_server/models/scgpt.py index 0c0953cd..a245501c 100644 --- a/tdc/model_server/models/scgpt.py +++ b/tdc/model_server/models/scgpt.py @@ -243,7 +243,8 @@ def forward( if hasattr(self, 'value_encoder'): values = values.unsqueeze(-1) # Add feature dimension value_emb = self.value_encoder["linear1"](values) - value_emb = self.value_encoder["activation"](value_emb) + if "activation" in self.value_encoder: + value_emb = self.value_encoder["activation"](value_emb) value_emb = self.value_encoder["linear2"](value_emb) value_emb = self.value_encoder["norm"](value_emb) value_emb = self.value_encoder["dropout"](value_emb) From 80806fa1b1ddc5486d73b3e1b665ab41bd07e597 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Fri, 10 Jan 2025 12:52:45 -0500 Subject: [PATCH 11/17] mend --- tdc/model_server/models/scgpt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tdc/model_server/models/scgpt.py b/tdc/model_server/models/scgpt.py index a245501c..a99da037 100644 --- a/tdc/model_server/models/scgpt.py +++ b/tdc/model_server/models/scgpt.py @@ -39,6 +39,7 @@ def __init__( self.cell_emb_style = cell_emb_style self.norm_scheme = norm_scheme self.explicit_zero_prob = explicit_zero_prob + self.use_flash_attention = self.use_fast_transformer super().__init__(pad_token_id=pad_token_id, **kwargs) class ExprDecoder(nn.Module): From 92b335684ccbed11c01f7ee9203eebd9286c948e Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Fri, 10 Jan 2025 13:09:59 -0500 Subject: [PATCH 12/17] mend --- tdc/model_server/models/scgpt.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tdc/model_server/models/scgpt.py b/tdc/model_server/models/scgpt.py index a99da037..176d6a80 100644 --- a/tdc/model_server/models/scgpt.py +++ b/tdc/model_server/models/scgpt.py @@ -22,6 +22,7 @@ def __init__( cell_emb_style="cls", # output embedding vector with norm_scheme="post", explicit_zero_prob=False, + use_flash_attention=True, **kwargs ): self.vocab_size = vocab_size @@ -39,7 +40,7 @@ def __init__( self.cell_emb_style = cell_emb_style self.norm_scheme = norm_scheme self.explicit_zero_prob = explicit_zero_prob - self.use_flash_attention = self.use_fast_transformer + self.use_flash_attention = self.use_fast_transformer and torch.cuda.is_available() and use_flash_attention super().__init__(pad_token_id=pad_token_id, **kwargs) class ExprDecoder(nn.Module): From 0e8308bfb7380f71fda7b39283394c5815c7a5f3 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Fri, 10 Jan 2025 13:20:29 -0500 Subject: [PATCH 13/17] mend --- tdc/model_server/models/scgpt.py | 34 ++++++++++++++++---------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff 
--git a/tdc/model_server/models/scgpt.py b/tdc/model_server/models/scgpt.py index 176d6a80..26eae58e 100644 --- a/tdc/model_server/models/scgpt.py +++ b/tdc/model_server/models/scgpt.py @@ -260,23 +260,23 @@ def forward( # Convert attention_mask for transformer # Flash attention expects mask of 0s for tokens to attend to and 1s for tokens to ignore - if self.use_flash_attention and attention_mask is not None: - if attention_mask.dtype != torch.bool: - attention_mask = attention_mask.bool() - attention_mask = ~attention_mask # we assume user follows huggingface convention for the attention mask - - # Apply transformer layers - if self.use_flash_attention: - for layer in self.transformer: - hidden_states = layer( - hidden_states, - src_key_padding_mask=attention_mask - ) - else: - hidden_states = self.transformer( - hidden_states, - src_key_padding_mask=attention_mask - ) + # if self.use_flash_attention and attention_mask is not None: + # if attention_mask.dtype != torch.bool: + # attention_mask = attention_mask.bool() + # attention_mask = ~attention_mask # we assume user follows huggingface convention for the attention mask + + # # Apply transformer layers + # if self.use_flash_attention: + # for layer in self.transformer: + # hidden_states = layer( + # hidden_states, + # src_key_padding_mask=attention_mask + # ) + # else: + hidden_states = self.transformer( + hidden_states, + src_key_padding_mask=attention_mask + ) # Get cell embeddings if requested output_dict = {} From 9e717a9e7f1ab965194262b338b8c9d942c6ee26 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Fri, 10 Jan 2025 13:27:30 -0500 Subject: [PATCH 14/17] mend --- tdc/test/test_model_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tdc/test/test_model_server.py b/tdc/test/test_model_server.py index 91cc7ddd..91c20e72 100644 --- a/tdc/test/test_model_server.py +++ b/tdc/test/test_model_server.py @@ -108,7 +108,7 @@ def testscGPT(self): ) # Convert to numpy array tokenized_data = tokenizer.tokenize_cell_vectors( adata.X.toarray(), gene_ids) - mask = [x != 0 for x in tokenized_data[0][1]] + mask = torch.tensor([x != 0 for x in tokenized_data[0][1]], dtype=torch.bool) assert sum(mask) != 0, "FAILURE: mask is empty" first_embed = model(tokenized_data[0][0], tokenized_data[0][1], From a46a70f9441e3595c4aa305cf14725602b68766c Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Fri, 10 Jan 2025 13:31:49 -0500 Subject: [PATCH 15/17] mend --- tdc/test/test_model_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tdc/test/test_model_server.py b/tdc/test/test_model_server.py index 91c20e72..b18119bc 100644 --- a/tdc/test/test_model_server.py +++ b/tdc/test/test_model_server.py @@ -97,6 +97,7 @@ def testscGPT(self): from tdc.multi_pred.anndata_dataset import DataLoader from tdc import tdc_hf_interface from tdc.model_server.tokenizers.scgpt import scGPTTokenizer + import torch adata = DataLoader("cellxgene_sample_small", "./data", dataset_names=["cellxgene_sample_small"], From 6ceff46bac59b05874bb6a6c17f0d9ffb2438acd Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Fri, 10 Jan 2025 13:41:04 -0500 Subject: [PATCH 16/17] adjust for customized output. 
no BERT last_hidden_state --- tdc/model_server/models/scgpt.py | 132 ++++++++++++++++++------------- tdc/test/test_model_server.py | 5 +- 2 files changed, 82 insertions(+), 55 deletions(-) diff --git a/tdc/model_server/models/scgpt.py b/tdc/model_server/models/scgpt.py index 26eae58e..db6698a1 100644 --- a/tdc/model_server/models/scgpt.py +++ b/tdc/model_server/models/scgpt.py @@ -4,27 +4,27 @@ import torch.nn.functional as F from typing import Optional, Dict + class ScGPTConfig(PretrainedConfig): model_type = "scgpt" def __init__( - self, - vocab_size=60697, - embsize=512, - d_hid=512, - nlayers=12, - nhead=8, - max_seq_len=1536, - dropout=0.0, - pad_token_id=0, - use_fast_transformer=True, - input_emb_style="continuous", - cell_emb_style="cls", # output embedding vector with - norm_scheme="post", - explicit_zero_prob=False, - use_flash_attention=True, - **kwargs - ): + self, + vocab_size=60697, + embsize=512, + d_hid=512, + nlayers=12, + nhead=8, + max_seq_len=1536, + dropout=0.0, + pad_token_id=0, + use_fast_transformer=True, + input_emb_style="continuous", + cell_emb_style="cls", # output embedding vector with + norm_scheme="post", + explicit_zero_prob=False, + use_flash_attention=True, + **kwargs): self.vocab_size = vocab_size self.embsize = embsize self.d_hid = d_hid @@ -35,19 +35,24 @@ def __init__( self.pad_token_id = pad_token_id self.use_fast_transformer = use_fast_transformer if input_emb_style not in ["continuous"]: - raise ValueError(f"Invalid input_emb_style: {input_emb_style}. Only continuous embeddings currently supported.") + raise ValueError( + f"Invalid input_emb_style: {input_emb_style}. Only continuous embeddings currently supported." + ) self.input_emb_style = input_emb_style self.cell_emb_style = cell_emb_style self.norm_scheme = norm_scheme self.explicit_zero_prob = explicit_zero_prob - self.use_flash_attention = self.use_fast_transformer and torch.cuda.is_available() and use_flash_attention + self.use_flash_attention = self.use_fast_transformer and torch.cuda.is_available( + ) and use_flash_attention super().__init__(pad_token_id=pad_token_id, **kwargs) + class ExprDecoder(nn.Module): + def __init__(self, d_model: int, explicit_zero_prob: bool = False): super().__init__() self.fc = nn.Sequential( - nn.Linear(d_model, d_model), # we don't use batch labels + nn.Linear(d_model, d_model), # we don't use batch labels nn.LeakyReLU(), nn.Linear(d_model, d_model), nn.LeakyReLU(), @@ -69,10 +74,20 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: return {"pred": pred_value} zero_logits = self.zero_logit(x).squeeze(-1) zero_probs = torch.sigmoid(zero_logits) - return {"pred": pred_value, "zero_probs": zero_probs} # TODO: what about inference / bernoulli? + return { + "pred": pred_value, + "zero_probs": zero_probs + } # TODO: what about inference / bernoulli? 
+ class FlashTransformerEncoderLayer(nn.Module): - def __init__(self, d_model, nhead, dim_feedforward, dropout, norm_scheme="post"): + + def __init__(self, + d_model, + nhead, + dim_feedforward, + dropout, + norm_scheme="post"): super().__init__() from flash_attn.flash_attention import FlashMHA @@ -82,19 +97,18 @@ def __init__(self, d_model, nhead, dim_feedforward, dropout, norm_scheme="post") dropout=dropout, attention_dropout=dropout, ) - self.feed_forward = nn.Sequential( - nn.Linear(d_model, dim_feedforward), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model) - ) + self.feed_forward = nn.Sequential(nn.Linear(d_model, dim_feedforward), + nn.GELU(), nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model)) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) self.norm_scheme = norm_scheme + # Helper class to ensure we have the correct attention structure class MultiheadAttentionWithBias(nn.Module): + def __init__(self, embed_dim, num_heads, dropout=0.0, batch_first=True): super().__init__() self.embed_dim = embed_dim @@ -103,7 +117,8 @@ def __init__(self, embed_dim, num_heads, dropout=0.0, batch_first=True): self.batch_first = batch_first # Combined input projections for Q, K, V - self.in_proj_weight = nn.Parameter(torch.empty((3 * embed_dim, embed_dim))) + self.in_proj_weight = nn.Parameter( + torch.empty((3 * embed_dim, embed_dim))) self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim)) # Output projection @@ -120,20 +135,29 @@ def _reset_parameters(self): def forward(self, query, key, value, key_padding_mask=None): return nn.functional.multi_head_attention_forward( - query, key, value, - self.embed_dim, self.num_heads, - self.in_proj_weight, self.in_proj_bias, - None, None, None, # No bias_k, bias_v, or add_zero_attn - self.dropout, self.out_proj.weight, self.out_proj.bias, + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + None, + None, + None, # No bias_k, bias_v, or add_zero_attn + self.dropout, + self.out_proj.weight, + self.out_proj.bias, training=self.training, key_padding_mask=key_padding_mask, need_weights=False, - batch_first=self.batch_first - )[0] - + batch_first=self.batch_first)[0] + + from transformers import PreTrainedModel import torch.nn as nn + class ScGPTPreTrainedModel(PreTrainedModel): config_class = ScGPTConfig base_model_prefix = "scgpt" @@ -151,18 +175,20 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + class ScGPTModel(ScGPTPreTrainedModel): + def __init__(self, config): super().__init__(config) # Gene name embeddings remain the same self.gene_encoder = nn.ModuleDict({ - "embedding": nn.Embedding( - config.vocab_size, - config.embsize, - padding_idx=config.pad_token_id - ), - "enc_norm": nn.LayerNorm(config.embsize) + "embedding": + nn.Embedding(config.vocab_size, + config.embsize, + padding_idx=config.pad_token_id), + "enc_norm": + nn.LayerNorm(config.embsize) }) # Value encoder remains the same @@ -175,7 +201,9 @@ def __init__(self, config): }) elif config.input_emb_style == "scaling": self.value_encoder = nn.Identity() - raise Exception("scaling input embedding style not supported because this model was trained on continuous style") + raise Exception( + "scaling input embedding style not supported because this model was trained on continuous style" + ) else: raise Exception("unsupported embedding style") @@ -204,13 +232,13 @@ def __init__(self, config): nhead=config.nhead, 
dim_feedforward=config.d_hid, dropout=config.dropout, - batch_first=True, # just for replication + batch_first=True, # just for replication ), - num_layers=config.nlayers - ) + num_layers=config.nlayers) # Decoder remains the same - self.expr_decoder = ExprDecoder(config.embsize, config.explicit_zero_prob) + self.expr_decoder = ExprDecoder(config.embsize, + config.explicit_zero_prob) # we ignore cls_decoder because we do not pursue classification task # we also ignore mvc and similarity because we ignore generative tasks @@ -273,10 +301,8 @@ def forward( # src_key_padding_mask=attention_mask # ) # else: - hidden_states = self.transformer( - hidden_states, - src_key_padding_mask=attention_mask - ) + hidden_states = self.transformer(hidden_states, + src_key_padding_mask=attention_mask) # Get cell embeddings if requested output_dict = {} @@ -302,4 +328,4 @@ def _init_weights(self, module): if isinstance(module, nn.Linear): # Additional initialization for linear layers if module.bias is not None: - nn.init.constant_(module.bias, 0) \ No newline at end of file + nn.init.constant_(module.bias, 0) diff --git a/tdc/test/test_model_server.py b/tdc/test/test_model_server.py index b18119bc..d17ac64b 100644 --- a/tdc/test/test_model_server.py +++ b/tdc/test/test_model_server.py @@ -109,11 +109,12 @@ def testscGPT(self): ) # Convert to numpy array tokenized_data = tokenizer.tokenize_cell_vectors( adata.X.toarray(), gene_ids) - mask = torch.tensor([x != 0 for x in tokenized_data[0][1]], dtype=torch.bool) + mask = torch.tensor([x != 0 for x in tokenized_data[0][1]], + dtype=torch.bool) assert sum(mask) != 0, "FAILURE: mask is empty" first_embed = model(tokenized_data[0][0], tokenized_data[0][1], - attention_mask=mask).last_hidden_state + attention_mask=mask) print(f"scgpt ran successfully. here is an output {first_embed}") def testGeneformerTokenizer(self): From 9efb8466e06f9444b1cfe681b9f5ab3c25439847 Mon Sep 17 00:00:00 2001 From: Alejandro Velez Date: Fri, 10 Jan 2025 13:44:55 -0500 Subject: [PATCH 17/17] mend --- .github/workflows/conda-tests.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/conda-tests.yml b/.github/workflows/conda-tests.yml index d10c4ac2..a26d9712 100644 --- a/.github/workflows/conda-tests.yml +++ b/.github/workflows/conda-tests.yml @@ -48,7 +48,6 @@ jobs: echo "Creating Conda Environment from environment.yml" conda env create -f environment.yml conda activate tdc-conda-env - python run_tests.py tdc.test.test_model_server.TestModelServer.testscGPT python run_tests.py yapf --style=google -r -d tdc conda deactivate
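
For reviewers who want to exercise this branch end to end, the flow covered by TestModelServer.testscGPT reduces to the short sketch below. It is a minimal sketch rather than a verbatim copy of the test: the DataLoader keyword arguments after dataset_names, the .adata attribute, the scGPTTokenizer() constructor, the tdc_hf_interface("scGPT") call, and the derivation of gene_ids from adata.var are assumptions filled in from context, so defer to tdc/test/test_model_server.py wherever they differ.

# Sketch of the scGPT model-server flow added in this series.
# Assumptions (not shown verbatim in the hunks above): no_convert/.adata on the
# DataLoader, the scGPTTokenizer() constructor, tdc_hf_interface("scGPT"),
# and gene_ids being taken from adata.var["feature_name"].
import torch

from tdc import tdc_hf_interface
from tdc.model_server.tokenizers.scgpt import scGPTTokenizer
from tdc.multi_pred.anndata_dataset import DataLoader

# Small cellxgene sample used by the test suite.
adata = DataLoader("cellxgene_sample_small",
                   "./data",
                   dataset_names=["cellxgene_sample_small"],
                   no_convert=True).adata

# load() registers ScGPTConfig/ScGPTModel with AutoConfig/AutoModel and pulls tdc/scGPT.
model = tdc_hf_interface("scGPT").load()
tokenizer = scGPTTokenizer()

# One gene symbol per column of adata.X; tokenize_cell_vectors maps symbols to
# vocab ids via the scgpt_vocab JSON fetched through download_wrapper/pd_load.
gene_ids = adata.var["feature_name"].to_numpy()
tokenized = tokenizer.tokenize_cell_vectors(adata.X.toarray(), gene_ids)
input_ids, values = tokenized[0]  # (gene token ids, expression values) for one cell

# Mask built from nonzero expression values, as in the test.
mask = torch.tensor([v != 0 for v in values], dtype=torch.bool)

# The forward pass returns a dict: "pred" (decoded expression), "cell_emb"
# (cell embedding), and "zero_probs" when explicit_zero_prob is enabled.
out = model(input_ids, values, attention_mask=mask)
print(out["cell_emb"])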