diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py
index 10d57a36c3..6b2428e461 100644
--- a/src/huggingface_hub/inference/_client.py
+++ b/src/huggingface_hub/inference/_client.py
@@ -2646,6 +2646,7 @@ def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, Any]:
             url = model.rstrip("/") + "/info"
         else:
             url = f"{INFERENCE_ENDPOINT}/models/{model}/info"
+
         response = get_session().get(url, headers=self.headers)
         hf_raise_for_status(response)
         return response.json()
@@ -2680,6 +2681,7 @@ def health_check(self, model: Optional[str] = None) -> bool:
                 "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`."
             )
         url = model.rstrip("/") + "/health"
+
         response = get_session().get(url, headers=self.headers)
         return response.status_code == 200
 
@@ -2719,6 +2721,7 @@ def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
         if model.startswith("https://"):
             raise NotImplementedError("Model status is only available for Inference API endpoints.")
         url = f"{INFERENCE_ENDPOINT}/status/{model}"
+
         response = get_session().get(url, headers=self.headers)
         hf_raise_for_status(response)
         response_data = response.json()
diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py
index 24981bcb8a..ec003fdfe8 100644
--- a/src/huggingface_hub/inference/_generated/_async_client.py
+++ b/src/huggingface_hub/inference/_generated/_async_client.py
@@ -2702,6 +2702,7 @@ async def get_endpoint_info(self, *, model: Optional[str] = None) -> Dict[str, A
             url = model.rstrip("/") + "/info"
         else:
             url = f"{INFERENCE_ENDPOINT}/models/{model}/info"
+
         async with self._get_client_session() as client:
             response = await client.get(url, proxy=self.proxies)
             response.raise_for_status()
@@ -2738,6 +2739,7 @@ async def health_check(self, model: Optional[str] = None) -> bool:
                 "Model must be an Inference Endpoint URL. For serverless Inference API, please use `InferenceClient.get_model_status`."
             )
         url = model.rstrip("/") + "/health"
+
         async with self._get_client_session() as client:
             response = await client.get(url, proxy=self.proxies)
             return response.status == 200
@@ -2779,6 +2781,7 @@ async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
         if model.startswith("https://"):
             raise NotImplementedError("Model status is only available for Inference API endpoints.")
         url = f"{INFERENCE_ENDPOINT}/status/{model}"
+
        async with self._get_client_session() as client:
             response = await client.get(url, proxy=self.proxies)
             response.raise_for_status()
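Note for reviewers: each hunk above inserts only a blank line before the HTTP request in `get_endpoint_info`, `health_check`, and `get_model_status`, in both the sync client and its generated async counterpart, so behavior is unchanged. As a minimal sketch of how the three touched methods are exercised (the model id and endpoint URL below are placeholders, not part of this patch):

```python
from huggingface_hub import InferenceClient

# Serverless Inference API: "gpt2" is a placeholder model id.
client = InferenceClient(model="gpt2")
status = client.get_model_status()  # GET {INFERENCE_ENDPOINT}/status/gpt2
print(status.state, status.loaded)

# Dedicated Inference Endpoints expose /info and /health instead;
# the URL below is hypothetical.
endpoint = InferenceClient(model="https://my-endpoint.endpoints.huggingface.cloud")
print(endpoint.get_endpoint_info())  # GET <url>/info
print(endpoint.health_check())       # GET <url>/health, True iff HTTP 200
```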