Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
jesusCaraball0 committed Feb 15, 2025
1 parent 94f0a43 commit 261adcd
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tdc/test/test_model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def testGeneformerPerturb(self):
input_tensor = torch.tensor(batch)
assert input_tensor.shape[0] == 512, "unexpected batch size"
assert input_tensor.shape[1] == mdim, f"unexpected gene length {mdim}"
attention_mask = torch.tensor([[t != 0 for t in cell]
for cell in batch])
attention_mask = torch.tensor([[t != 0 for t in cell] for cell in batch
])
assert input_tensor.shape[0] == attention_mask.shape[0]
assert input_tensor.shape[1] == attention_mask.shape[1]
try:
Expand Down Expand Up @@ -160,8 +160,8 @@ def testGeneformerTokenizer(self):
ctr = 0 # stop after some passes to avoid failure
for batch in input_tensor:
# build an attention mask
attention_mask = torch.tensor([[x[0] != 0, x[1] != 0]
for x in batch])
attention_mask = torch.tensor(
[[x[0] != 0, x[1] != 0] for x in batch])
outputs = model(batch,
attention_mask=attention_mask,
output_hidden_states=True)
Expand Down

0 comments on commit 261adcd

Please sign in to comment.