diff --git a/server/lorax_server/models/flash_distilbert.py b/server/lorax_server/models/flash_distilbert.py index c2b9ff7b..d6cda8a7 100644 --- a/server/lorax_server/models/flash_distilbert.py +++ b/server/lorax_server/models/flash_distilbert.py @@ -129,7 +129,8 @@ def supports_classification(self) -> bool: def warmup(self, batch: FlashEmbeddingClassificationBatch, max_new_tokens: int) -> int | None: # Note: This is meant to 1) preallocate the memory by doing a forward pass # and then just returning the max seqlen since for embeddings we are never generating - _ = self.embed(batch) + # TODO: (magdy) add the forward pass and debug this + # _ = self.embed(batch) return batch.max_s def generate_token(self, batch: FlashEmbeddingClassificationBatch) -> None: