diff --git a/tdc/test/test_model_server.py b/tdc/test/test_model_server.py index 75b3625e..e49b5ce3 100644 --- a/tdc/test/test_model_server.py +++ b/tdc/test/test_model_server.py @@ -84,7 +84,7 @@ def testGeneformerPerturb(self): assert input_tensor.shape[0] == 512, "unexpected batch size" assert input_tensor.shape[1] == mdim, f"unexpected gene length {mdim}" attention_mask = torch.tensor([[t != 0 for t in cell] for cell in batch - ]) + ]) assert input_tensor.shape[0] == attention_mask.shape[0] assert input_tensor.shape[1] == attention_mask.shape[1] try: