From 5d9aa754ace5d53eb90c1055dd6b1ca8e7deee4f Mon Sep 17 00:00:00 2001 From: Stainless Bot <107565488+stainless-bot@users.noreply.github.com> Date: Wed, 29 Nov 2023 19:02:30 -0500 Subject: [PATCH] fix(client): ensure retried requests are closed (#261) --- src/anthropic/_base_client.py | 100 +++++++++++++---- src/anthropic/_constants.py | 1 + tests/test_client.py | 198 +++++++++++++++++++++++++++++++++- 3 files changed, 278 insertions(+), 21 deletions(-) diff --git a/src/anthropic/_base_client.py b/src/anthropic/_base_client.py index a168301f..89d9ce48 100644 --- a/src/anthropic/_base_client.py +++ b/src/anthropic/_base_client.py @@ -72,6 +72,7 @@ DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, RAW_RESPONSE_HEADER, + STREAMED_RAW_RESPONSE_HEADER, ) from ._streaming import Stream, AsyncStream from ._exceptions import ( @@ -363,14 +364,21 @@ def _make_status_error_from_response( self, response: httpx.Response, ) -> APIStatusError: - err_text = response.text.strip() - body = err_text + if response.is_closed and not response.is_stream_consumed: + # We can't read the response body as it has been closed + # before it was read. This can happen if an event hook + # raises a status error. + body = None + err_msg = f"Error code: {response.status_code}" + else: + err_text = response.text.strip() + body = err_text - try: - body = json.loads(err_text) - err_msg = f"Error code: {response.status_code} - {body}" - except Exception: - err_msg = err_text or f"Error code: {response.status_code}" + try: + body = json.loads(err_text) + err_msg = f"Error code: {response.status_code} - {body}" + except Exception: + err_msg = err_text or f"Error code: {response.status_code}" return self._make_status_error(err_msg, body=body, response=response) @@ -534,6 +542,12 @@ def _process_response_data( except pydantic.ValidationError as err: raise APIResponseValidationError(response=response, body=data) from err + def _should_stream_response_body(self, *, request: httpx.Request) -> bool: + if request.headers.get(STREAMED_RAW_RESPONSE_HEADER) == "true": + return True + + return False + @property def qs(self) -> Querystring: return Querystring() @@ -606,7 +620,7 @@ def _calculate_retry_timeout( if response_headers is not None: retry_header = response_headers.get("retry-after") try: - retry_after = int(retry_header) + retry_after = float(retry_header) except Exception: retry_date_tuple = email.utils.parsedate_tz(retry_header) if retry_date_tuple is None: @@ -862,14 +876,21 @@ def _request( request = self._build_request(options) self._prepare_request(request) + response = None + try: - response = self._client.send(request, auth=self.custom_auth, stream=stream) + response = self._client.send( + request, + auth=self.custom_auth, + stream=stream or self._should_stream_response_body(request=request), + ) log.debug( 'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase ) response.raise_for_status() except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code if retries > 0 and self._should_retry(err.response): + err.response.close() return self._retry_request( options, cast_to, @@ -881,9 +902,14 @@ def _request( # If the response is streamed then we need to explicitly read the response # to completion before attempting to access the response text. - err.response.read() + if not err.response.is_closed: + err.response.read() + raise self._make_status_error_from_response(err.response) from None except httpx.TimeoutException as err: + if response is not None: + response.close() + if retries > 0: return self._retry_request( options, @@ -891,9 +917,14 @@ def _request( retries, stream=stream, stream_cls=stream_cls, + response_headers=response.headers if response is not None else None, ) + raise APITimeoutError(request=request) from err except Exception as err: + if response is not None: + response.close() + if retries > 0: return self._retry_request( options, @@ -901,7 +932,9 @@ def _request( retries, stream=stream, stream_cls=stream_cls, + response_headers=response.headers if response is not None else None, ) + raise APIConnectionError(request=request) from err return self._process_response( @@ -917,7 +950,7 @@ def _retry_request( options: FinalRequestOptions, cast_to: Type[ResponseT], remaining_retries: int, - response_headers: Optional[httpx.Headers] = None, + response_headers: httpx.Headers | None, *, stream: bool, stream_cls: type[_StreamT] | None, @@ -1303,14 +1336,21 @@ async def _request( request = self._build_request(options) await self._prepare_request(request) + response = None + try: - response = await self._client.send(request, auth=self.custom_auth, stream=stream) + response = await self._client.send( + request, + auth=self.custom_auth, + stream=stream or self._should_stream_response_body(request=request), + ) log.debug( 'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase ) response.raise_for_status() except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code if retries > 0 and self._should_retry(err.response): + await err.response.aclose() return await self._retry_request( options, cast_to, @@ -1322,19 +1362,39 @@ async def _request( # If the response is streamed then we need to explicitly read the response # to completion before attempting to access the response text. - await err.response.aread() + if not err.response.is_closed: + await err.response.aread() + raise self._make_status_error_from_response(err.response) from None - except httpx.ConnectTimeout as err: - if retries > 0: - return await self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls) - raise APITimeoutError(request=request) from err except httpx.TimeoutException as err: + if response is not None: + await response.aclose() + if retries > 0: - return await self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls) + return await self._retry_request( + options, + cast_to, + retries, + stream=stream, + stream_cls=stream_cls, + response_headers=response.headers if response is not None else None, + ) + raise APITimeoutError(request=request) from err except Exception as err: + if response is not None: + await response.aclose() + if retries > 0: - return await self._retry_request(options, cast_to, retries, stream=stream, stream_cls=stream_cls) + return await self._retry_request( + options, + cast_to, + retries, + stream=stream, + stream_cls=stream_cls, + response_headers=response.headers if response is not None else None, + ) + raise APIConnectionError(request=request) from err return self._process_response( @@ -1350,7 +1410,7 @@ async def _retry_request( options: FinalRequestOptions, cast_to: Type[ResponseT], remaining_retries: int, - response_headers: Optional[httpx.Headers] = None, + response_headers: httpx.Headers | None, *, stream: bool, stream_cls: type[_AsyncStreamT] | None, diff --git a/src/anthropic/_constants.py b/src/anthropic/_constants.py index 7343a7a4..4e59e854 100644 --- a/src/anthropic/_constants.py +++ b/src/anthropic/_constants.py @@ -3,6 +3,7 @@ import httpx RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response" +STREAMED_RAW_RESPONSE_HEADER = "X-Stainless-Streamed-Raw-Response" # default timeout is 10 minutes DEFAULT_TIMEOUT = httpx.Timeout(timeout=600.0, connect=5.0) diff --git a/tests/test_client.py b/tests/test_client.py index 3beacc06..6fa89caa 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -19,7 +19,12 @@ from anthropic._client import Anthropic, AsyncAnthropic from anthropic._models import BaseModel, FinalRequestOptions from anthropic._streaming import Stream, AsyncStream -from anthropic._exceptions import APIResponseValidationError +from anthropic._exceptions import ( + APIStatusError, + APITimeoutError, + APIConnectionError, + APIResponseValidationError, +) from anthropic._base_client import ( DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, @@ -39,6 +44,24 @@ def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]: return dict(url.params) +_original_response_init = cast(Any, httpx.Response.__init__) # type: ignore + + +def _low_retry_response_init(*args: Any, **kwargs: Any) -> Any: + headers = cast("list[tuple[bytes, bytes]]", kwargs["headers"]) + headers.append((b"retry-after", b"0.1")) + + return _original_response_init(*args, **kwargs) + + +def _get_open_connections(client: Anthropic | AsyncAnthropic) -> int: + transport = client._client._transport + assert isinstance(transport, httpx.HTTPTransport) or isinstance(transport, httpx.AsyncHTTPTransport) + + pool = transport._pool + return len(pool._requests) + + class TestAnthropic: client = Anthropic(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -690,6 +713,92 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str calculated = client._calculate_retry_timeout(remaining_retries, options, headers) assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType] + @mock.patch("httpx.Response.__init__", _low_retry_response_init) + def test_retrying_timeout_errors_doesnt_leak(self) -> None: + def raise_for_status(response: httpx.Response) -> None: + raise httpx.TimeoutException("Test timeout error", request=response.request) + + with mock.patch("httpx.Response.raise_for_status", raise_for_status): + with pytest.raises(APITimeoutError): + self.client.post( + "/v1/complete", + body=dict( + max_tokens_to_sample=300, + model="claude-2", + prompt="\n\nHuman:Where can I get a good coffee in my neighbourhood?\n\nAssistant:", + ), + cast_to=httpx.Response, + options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}}, + ) + + assert _get_open_connections(self.client) == 0 + + @mock.patch("httpx.Response.__init__", _low_retry_response_init) + def test_retrying_runtime_errors_doesnt_leak(self) -> None: + def raise_for_status(_response: httpx.Response) -> None: + raise RuntimeError("Test error") + + with mock.patch("httpx.Response.raise_for_status", raise_for_status): + with pytest.raises(APIConnectionError): + self.client.post( + "/v1/complete", + body=dict( + max_tokens_to_sample=300, + model="claude-2", + prompt="\n\nHuman:Where can I get a good coffee in my neighbourhood?\n\nAssistant:", + ), + cast_to=httpx.Response, + options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}}, + ) + + assert _get_open_connections(self.client) == 0 + + @mock.patch("httpx.Response.__init__", _low_retry_response_init) + def test_retrying_status_errors_doesnt_leak(self) -> None: + def raise_for_status(response: httpx.Response) -> None: + response.status_code = 500 + raise httpx.HTTPStatusError("Test 500 error", response=response, request=response.request) + + with mock.patch("httpx.Response.raise_for_status", raise_for_status): + with pytest.raises(APIStatusError): + self.client.post( + "/v1/complete", + body=dict( + max_tokens_to_sample=300, + model="claude-2", + prompt="\n\nHuman:Where can I get a good coffee in my neighbourhood?\n\nAssistant:", + ), + cast_to=httpx.Response, + options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}}, + ) + + assert _get_open_connections(self.client) == 0 + + @pytest.mark.respx(base_url=base_url) + def test_status_error_within_httpx(self, respx_mock: MockRouter) -> None: + respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + def on_response(response: httpx.Response) -> None: + raise httpx.HTTPStatusError( + "Simulating an error inside httpx", + response=response, + request=response.request, + ) + + client = Anthropic( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + http_client=httpx.Client( + event_hooks={ + "response": [on_response], + } + ), + max_retries=0, + ) + with pytest.raises(APIStatusError): + client.post("/foo", cast_to=httpx.Response) + class TestAsyncAnthropic: client = AsyncAnthropic(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -1357,3 +1466,90 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte options = FinalRequestOptions(method="get", url="/foo", max_retries=3) calculated = client._calculate_retry_timeout(remaining_retries, options, headers) assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType] + + @mock.patch("httpx.Response.__init__", _low_retry_response_init) + async def test_retrying_timeout_errors_doesnt_leak(self) -> None: + def raise_for_status(response: httpx.Response) -> None: + raise httpx.TimeoutException("Test timeout error", request=response.request) + + with mock.patch("httpx.Response.raise_for_status", raise_for_status): + with pytest.raises(APITimeoutError): + await self.client.post( + "/v1/complete", + body=dict( + max_tokens_to_sample=300, + model="claude-2", + prompt="\n\nHuman:Where can I get a good coffee in my neighbourhood?\n\nAssistant:", + ), + cast_to=httpx.Response, + options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}}, + ) + + assert _get_open_connections(self.client) == 0 + + @mock.patch("httpx.Response.__init__", _low_retry_response_init) + async def test_retrying_runtime_errors_doesnt_leak(self) -> None: + def raise_for_status(_response: httpx.Response) -> None: + raise RuntimeError("Test error") + + with mock.patch("httpx.Response.raise_for_status", raise_for_status): + with pytest.raises(APIConnectionError): + await self.client.post( + "/v1/complete", + body=dict( + max_tokens_to_sample=300, + model="claude-2", + prompt="\n\nHuman:Where can I get a good coffee in my neighbourhood?\n\nAssistant:", + ), + cast_to=httpx.Response, + options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}}, + ) + + assert _get_open_connections(self.client) == 0 + + @mock.patch("httpx.Response.__init__", _low_retry_response_init) + async def test_retrying_status_errors_doesnt_leak(self) -> None: + def raise_for_status(response: httpx.Response) -> None: + response.status_code = 500 + raise httpx.HTTPStatusError("Test 500 error", response=response, request=response.request) + + with mock.patch("httpx.Response.raise_for_status", raise_for_status): + with pytest.raises(APIStatusError): + await self.client.post( + "/v1/complete", + body=dict( + max_tokens_to_sample=300, + model="claude-2", + prompt="\n\nHuman:Where can I get a good coffee in my neighbourhood?\n\nAssistant:", + ), + cast_to=httpx.Response, + options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}}, + ) + + assert _get_open_connections(self.client) == 0 + + @pytest.mark.respx(base_url=base_url) + @pytest.mark.asyncio + async def test_status_error_within_httpx(self, respx_mock: MockRouter) -> None: + respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + def on_response(response: httpx.Response) -> None: + raise httpx.HTTPStatusError( + "Simulating an error inside httpx", + response=response, + request=response.request, + ) + + client = AsyncAnthropic( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + http_client=httpx.AsyncClient( + event_hooks={ + "response": [on_response], + } + ), + max_retries=0, + ) + with pytest.raises(APIStatusError): + await client.post("/foo", cast_to=httpx.Response)