diff --git a/optimum_benchmark/backends/tensorrt_llm/backend.py b/optimum_benchmark/backends/tensorrt_llm/backend.py index dbf003f8..322972e8 100644 --- a/optimum_benchmark/backends/tensorrt_llm/backend.py +++ b/optimum_benchmark/backends/tensorrt_llm/backend.py @@ -14,8 +14,8 @@ class TRTLLMBackend(Backend): NAME = "tensorrt-llm" - def __init__(self, model: str, task: str, device: str, hub_kwargs: Dict[str, Any]) -> None: - super().__init__(model, task, device, hub_kwargs) + def __init__(self, model: str, task: str, library: str, device: str, hub_kwargs: Dict[str, Any]): + super().__init__(model, task, library, device, hub_kwargs) self.validate_device() self.validate_model_type()