
Commit af87a78

mend

amva13 committed Jan 10, 2025
1 parent 6d5d5d9
Showing 2 changed files with 29 additions and 30 deletions.
57 changes: 28 additions & 29 deletions tdc/model_server/models/scgpt.py
@@ -178,35 +178,34 @@ def __init__(self, config):
             raise Exception("unsupported embedding style")
 
         # Modified transformer layers to use combined QKV projections
-        self.transformer = nn.ModuleDict({
-            "layers": nn.ModuleList([
-                nn.ModuleDict({
-                    "self_attn": MultiheadAttentionWithBias(
-                        config.embsize,
-                        config.nhead,
-                        dropout=config.dropout,
-                        batch_first=True
-                    ),
-                    "linear1": nn.Linear(config.embsize, config.d_hid),
-                    "linear2": nn.Linear(config.d_hid, config.embsize),
-                    "norm1": nn.LayerNorm(config.embsize),
-                    "norm2": nn.LayerNorm(config.embsize),
-                }) for _ in range(config.nlayers)
-            ])
-        })
-
-        # # Rather than combining qkv projections, mimicking gh implementation to match weights
-        # from torch.nn import TransformerEncoder, TransformerEncoderLayer
-        # self.transformer = TransformerEncoder(
-        #     TransformerEncoderLayer(
-        #         d_model=config.embsize,
-        #         nhead=config.nhead,
-        #         dim_feedforward=config.d_hid,
-        #         dropout=config.dropout,
-        #         batch_first=True, # just for replication
-        #     ),
-        #     num_layers=config.nlayers
-        # )
+        # self.transformer = nn.ModuleDict({
+        #     "layers": nn.ModuleList([
+        #         nn.ModuleDict({
+        #             "self_attn": MultiheadAttentionWithBias(
+        #                 config.embsize,
+        #                 config.nhead,
+        #                 dropout=config.dropout,
+        #                 batch_first=True
+        #             ),
+        #             "linear1": nn.Linear(config.embsize, config.d_hid),
+        #             "linear2": nn.Linear(config.d_hid, config.embsize),
+        #             "norm1": nn.LayerNorm(config.embsize),
+        #             "norm2": nn.LayerNorm(config.embsize),
+        #         }) for _ in range(config.nlayers)
+        #     ])
+        # })
+
+        from torch.nn import TransformerEncoder, TransformerEncoderLayer
+        self.transformer = TransformerEncoder(
+            TransformerEncoderLayer(
+                d_model=config.embsize,
+                nhead=config.nhead,
+                dim_feedforward=config.d_hid,
+                dropout=config.dropout,
+                batch_first=True, # just for replication
+            ),
+            num_layers=config.nlayers
+        )
 
         # Decoder remains the same
         self.expr_decoder = ExprDecoder(config.embsize, config.explicit_zero_prob)
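
The new construction uses the stock torch.nn.TransformerEncoder so that the module layout and parameter names follow the upstream scGPT implementation ("mimicking gh implementation to match weights"), letting the pretrained checkpoint load by key. Below is a minimal sketch of what that module exposes, assuming example hyperparameters (embsize=512, nhead=8, d_hid=512, nlayers=12, dropout=0.2) rather than the repository's actual config values:

# Minimal sketch (assumed example hyperparameters, not the repository's config)
# of the encoder built above from stock PyTorch modules.
import torch
from torch.nn import TransformerEncoder, TransformerEncoderLayer

embsize, nhead, d_hid, nlayers, dropout = 512, 8, 512, 12, 0.2  # assumed values

encoder = TransformerEncoder(
    TransformerEncoderLayer(
        d_model=embsize,
        nhead=nhead,
        dim_feedforward=d_hid,
        dropout=dropout,
        batch_first=True,
    ),
    num_layers=nlayers,
)

# The stock attention module keeps q/k/v fused in one (3*embsize, embsize) projection.
print(encoder.layers[0].self_attn.in_proj_weight.shape)  # torch.Size([1536, 512])

# Parameter names follow the layers.N.* convention, so a checkpoint saved from
# the same module layout loads directly by name.
print(next(iter(encoder.state_dict())))  # 'layers.0.self_attn.in_proj_weight'

# Forward pass over a toy batch of token embeddings (batch_first=True).
x = torch.randn(2, 16, embsize)  # (batch, seq_len, embsize)
out = encoder(x)
print(out.shape)  # torch.Size([2, 16, 512])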
2 changes: 1 addition & 1 deletion tdc/model_server/tokenizers/scgpt.py
@@ -49,7 +49,7 @@ def tokenize_batch(
             import torch
             genes = torch.tensor([vocab_map.get(x, 0) for x in genes],
                                  dtype=torch.int64)
-            values = torch.from_numpy(values).float().to(torch.int64)
+            values = torch.from_numpy(values).float()
             tokenized_data.append((genes, values))
         return tokenized_data
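
With the .to(torch.int64) cast dropped, expression values now stay floating point rather than being truncated to integers, while gene IDs remain int64 indices into the vocabulary (unknown genes map to index 0). A minimal sketch of that step, using a hypothetical toy vocabulary and expression vector rather than real data:

# Minimal sketch of the tokenization step above (hypothetical toy inputs).
import numpy as np
import torch

vocab_map = {"TP53": 7, "BRCA1": 42}       # assumed toy gene vocabulary
gene_names = ["TP53", "BRCA1", "UNKNOWN"]  # unknown genes map to index 0
values_np = np.array([0.0, 2.7, 1.3])      # continuous expression values

genes = torch.tensor([vocab_map.get(x, 0) for x in gene_names], dtype=torch.int64)
values = torch.from_numpy(values_np).float()  # stays float; 2.7 is not truncated to 2

print(genes)   # tensor([ 7, 42,  0])
print(values)  # tensor([0.0000, 2.7000, 1.3000])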

