From f2959827fe2cd555db38a62c1b3df1a12e6dee40 Mon Sep 17 00:00:00 2001 From: Stainless Bot <107565488+stainless-bot@users.noreply.github.com> Date: Mon, 15 Jan 2024 09:09:18 -0500 Subject: [PATCH] feat(client): add support for streaming raw responses (#307) As an alternative to `with_raw_response` we now provide `with_streaming_response` as well. When using these methods you will have to use a context manager to ensure that the response is always cleaned up. --- README.md | 33 +- src/anthropic/__init__.py | 1 + src/anthropic/_base_client.py | 279 ++++++----- src/anthropic/_client.py | 16 + src/anthropic/_constants.py | 2 +- src/anthropic/_legacy_response.py | 385 +++++++++++++++ src/anthropic/_response.py | 570 ++++++++++++++++++++-- src/anthropic/_types.py | 166 +------ src/anthropic/resources/__init__.py | 22 +- src/anthropic/resources/beta/__init__.py | 22 +- src/anthropic/resources/beta/beta.py | 27 +- src/anthropic/resources/beta/messages.py | 29 +- src/anthropic/resources/completions.py | 29 +- tests/api_resources/beta/test_messages.py | 107 +++- tests/api_resources/test_completions.py | 87 +++- tests/test_client.py | 68 ++- tests/test_response.py | 50 ++ tests/utils.py | 5 + 18 files changed, 1526 insertions(+), 372 deletions(-) create mode 100644 src/anthropic/_legacy_response.py create mode 100644 tests/test_response.py diff --git a/README.md b/README.md index f424a958..cc50bf78 100644 --- a/README.md +++ b/README.md @@ -301,7 +301,7 @@ if response.my_field is None: ### Accessing raw response data (e.g. headers) -The "raw" Response object can be accessed by prefixing `.with_raw_response.` to any HTTP method call. 
+The "raw" Response object can be accessed by prefixing `.with_raw_response.` to any HTTP method call, e.g., ```py from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT @@ -319,7 +319,36 @@ completion = response.parse() # get the object that `completions.create()` woul print(completion.completion) ``` -These methods return an [`APIResponse`](https://github.com/anthropics/anthropic-sdk-python/tree/main/src/anthropic/_response.py) object. +These methods return an [`LegacyAPIResponse`](https://github.com/anthropics/anthropic-sdk-python/tree/main/src/anthropic/_legacy_response.py) object. This is a legacy class as we're changing it slightly in the next major version. + +For the sync client this will mostly be the same with the exception +of `content` & `text` will be methods instead of properties. In the +async client, all methods will be async. + +A migration script will be provided & the migration in general should +be smooth. + +#### `.with_streaming_response` + +The above interface eagerly reads the full response body when you make the request, which may not always be what you want. + +To stream the response body, use `.with_streaming_response` instead, which requires a context manager and only reads the response body once you call `.read()`, `.text()`, `.json()`, `.iter_bytes()`, `.iter_text()`, `.iter_lines()` or `.parse()`. In the async client, these are async methods. + +As such, `.with_streaming_response` methods return a different [`APIResponse`](https://github.com/anthropics/anthropic-sdk-python/tree/main/src/anthropic/_response.py) object, and the async client returns an [`AsyncAPIResponse`](https://github.com/anthropics/anthropic-sdk-python/tree/main/src/anthropic/_response.py) object. 
+ +```python +with client.completions.with_streaming_response.create( + max_tokens_to_sample=300, + model="claude-2.1", + prompt=f"{HUMAN_PROMPT} Where can I get a good coffee in my neighbourhood?{AI_PROMPT}", +) as response: + print(response.headers.get("X-My-Header")) + + for line in response.iter_lines(): + print(line) +``` + +The context manager is required so that the response will reliably be closed. ### Configuring the HTTP client diff --git a/src/anthropic/__init__.py b/src/anthropic/__init__.py index c1e6bd4c..09afc2a1 100644 --- a/src/anthropic/__init__.py +++ b/src/anthropic/__init__.py @@ -15,6 +15,7 @@ RequestOptions, ) from ._version import __title__, __version__ +from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse from ._constants import AI_PROMPT as AI_PROMPT, HUMAN_PROMPT as HUMAN_PROMPT from ._exceptions import ( APIError, diff --git a/src/anthropic/_base_client.py b/src/anthropic/_base_client.py index c2c2db5f..1dfbd7df 100644 --- a/src/anthropic/_base_client.py +++ b/src/anthropic/_base_client.py @@ -1,6 +1,5 @@ from __future__ import annotations -import os import json import time import uuid @@ -31,7 +30,7 @@ overload, ) from functools import lru_cache -from typing_extensions import Literal, override +from typing_extensions import Literal, override, get_origin import anyio import httpx @@ -61,18 +60,22 @@ AsyncTransport, RequestOptions, ModelBuilderProtocol, - BinaryResponseContent, ) from ._utils import is_dict, is_given, is_mapping from ._compat import model_copy, model_dump from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type -from ._response import APIResponse +from ._response import ( + APIResponse, + BaseAPIResponse, + AsyncAPIResponse, + extract_response_type, +) from ._constants import ( DEFAULT_LIMITS, DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, RAW_RESPONSE_HEADER, - STREAMED_RAW_RESPONSE_HEADER, + OVERRIDE_CAST_TO_HEADER, ) from ._streaming import Stream, AsyncStream from 
._exceptions import ( @@ -81,6 +84,7 @@ APIConnectionError, APIResponseValidationError, ) +from ._legacy_response import LegacyAPIResponse log: logging.Logger = logging.getLogger(__name__) @@ -493,28 +497,25 @@ def _serialize_multipartform(self, data: Mapping[object, object]) -> dict[str, o serialized[key] = value return serialized - def _process_response( - self, - *, - cast_to: Type[ResponseT], - options: FinalRequestOptions, - response: httpx.Response, - stream: bool, - stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None, - ) -> ResponseT: - api_response = APIResponse( - raw=response, - client=self, - cast_to=cast_to, - stream=stream, - stream_cls=stream_cls, - options=options, - ) + def _maybe_override_cast_to(self, cast_to: type[ResponseT], options: FinalRequestOptions) -> type[ResponseT]: + if not is_given(options.headers): + return cast_to - if response.request.headers.get(RAW_RESPONSE_HEADER) == "true": - return cast(ResponseT, api_response) + # make a copy of the headers so we don't mutate user-input + headers = dict(options.headers) - return api_response.parse() + # we internally support defining a temporary header to override the + # default `cast_to` type for use with `.with_raw_response` and `.with_streaming_response` + # see _response.py for implementation details + override_cast_to = headers.pop(OVERRIDE_CAST_TO_HEADER, NOT_GIVEN) + if is_given(override_cast_to): + options.headers = headers + return cast(Type[ResponseT], override_cast_to) + + return cast_to + + def _should_stream_response_body(self, request: httpx.Request) -> bool: + return request.headers.get(RAW_RESPONSE_HEADER) == "stream" # type: ignore[no-any-return] def _process_response_data( self, @@ -540,12 +541,6 @@ def _process_response_data( except pydantic.ValidationError as err: raise APIResponseValidationError(response=response, body=data) from err - def _should_stream_response_body(self, *, request: httpx.Request) -> bool: - if 
request.headers.get(STREAMED_RAW_RESPONSE_HEADER) == "true": - return True - - return False - @property def qs(self) -> Querystring: return Querystring() @@ -610,6 +605,8 @@ def _calculate_retry_timeout( if response_headers is not None: retry_header = response_headers.get("retry-after") try: + # note: the spec indicates that this should only ever be an integer + # but if someone sends a float there's no reason for us to not respect it retry_after = float(retry_header) except Exception: retry_date_tuple = email.utils.parsedate_tz(retry_header) @@ -873,6 +870,7 @@ def _request( stream: bool, stream_cls: type[_StreamT] | None, ) -> ResponseT | _StreamT: + cast_to = self._maybe_override_cast_to(cast_to, options) self._prepare_options(options) retries = self._remaining_retries(remaining_retries, options) @@ -987,6 +985,63 @@ def _retry_request( stream_cls=stream_cls, ) + def _process_response( + self, + *, + cast_to: Type[ResponseT], + options: FinalRequestOptions, + response: httpx.Response, + stream: bool, + stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None, + ) -> ResponseT: + if response.request.headers.get(RAW_RESPONSE_HEADER) == "true": + return cast( + ResponseT, + LegacyAPIResponse( + raw=response, + client=self, + cast_to=cast_to, + stream=stream, + stream_cls=stream_cls, + options=options, + ), + ) + + origin = get_origin(cast_to) or cast_to + + if inspect.isclass(origin) and issubclass(origin, BaseAPIResponse): + if not issubclass(origin, APIResponse): + raise TypeError(f"API Response types must subclass {APIResponse}; Received {origin}") + + response_cls = cast("type[BaseAPIResponse[Any]]", cast_to) + return cast( + ResponseT, + response_cls( + raw=response, + client=self, + cast_to=extract_response_type(response_cls), + stream=stream, + stream_cls=stream_cls, + options=options, + ), + ) + + if cast_to == httpx.Response: + return cast(ResponseT, response) + + api_response = APIResponse( + raw=response, + client=self, + 
cast_to=cast("type[ResponseT]", cast_to), # pyright: ignore[reportUnnecessaryCast] + stream=stream, + stream_cls=stream_cls, + options=options, + ) + if bool(response.request.headers.get(RAW_RESPONSE_HEADER)): + return cast(ResponseT, api_response) + + return api_response.parse() + def _request_api_list( self, model: Type[object], @@ -1353,6 +1408,7 @@ async def _request( stream_cls: type[_AsyncStreamT] | None, remaining_retries: int | None, ) -> ResponseT | _AsyncStreamT: + cast_to = self._maybe_override_cast_to(cast_to, options) await self._prepare_options(options) retries = self._remaining_retries(remaining_retries, options) @@ -1428,7 +1484,7 @@ async def _request( log.debug("Re-raising status error") raise self._make_status_error_from_response(err.response) from None - return self._process_response( + return await self._process_response( cast_to=cast_to, options=options, response=response, @@ -1465,6 +1521,63 @@ async def _retry_request( stream_cls=stream_cls, ) + async def _process_response( + self, + *, + cast_to: Type[ResponseT], + options: FinalRequestOptions, + response: httpx.Response, + stream: bool, + stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None, + ) -> ResponseT: + if response.request.headers.get(RAW_RESPONSE_HEADER) == "true": + return cast( + ResponseT, + LegacyAPIResponse( + raw=response, + client=self, + cast_to=cast_to, + stream=stream, + stream_cls=stream_cls, + options=options, + ), + ) + + origin = get_origin(cast_to) or cast_to + + if inspect.isclass(origin) and issubclass(origin, BaseAPIResponse): + if not issubclass(origin, AsyncAPIResponse): + raise TypeError(f"API Response types must subclass {AsyncAPIResponse}; Received {origin}") + + response_cls = cast("type[BaseAPIResponse[Any]]", cast_to) + return cast( + "ResponseT", + response_cls( + raw=response, + client=self, + cast_to=extract_response_type(response_cls), + stream=stream, + stream_cls=stream_cls, + options=options, + ), + ) + + if cast_to == httpx.Response: + 
return cast(ResponseT, response) + + api_response = AsyncAPIResponse( + raw=response, + client=self, + cast_to=cast("type[ResponseT]", cast_to), # pyright: ignore[reportUnnecessaryCast] + stream=stream, + stream_cls=stream_cls, + options=options, + ) + if bool(response.request.headers.get(RAW_RESPONSE_HEADER)): + return cast(ResponseT, api_response) + + return await api_response.parse() + def _request_api_list( self, model: Type[_T], @@ -1783,105 +1896,3 @@ def _merge_mappings( """ merged = {**obj1, **obj2} return {key: value for key, value in merged.items() if not isinstance(value, Omit)} - - -class HttpxBinaryResponseContent(BinaryResponseContent): - response: httpx.Response - - def __init__(self, response: httpx.Response) -> None: - self.response = response - - @property - @override - def content(self) -> bytes: - return self.response.content - - @property - @override - def text(self) -> str: - return self.response.text - - @property - @override - def encoding(self) -> Optional[str]: - return self.response.encoding - - @property - @override - def charset_encoding(self) -> Optional[str]: - return self.response.charset_encoding - - @override - def json(self, **kwargs: Any) -> Any: - return self.response.json(**kwargs) - - @override - def read(self) -> bytes: - return self.response.read() - - @override - def iter_bytes(self, chunk_size: Optional[int] = None) -> Iterator[bytes]: - return self.response.iter_bytes(chunk_size) - - @override - def iter_text(self, chunk_size: Optional[int] = None) -> Iterator[str]: - return self.response.iter_text(chunk_size) - - @override - def iter_lines(self) -> Iterator[str]: - return self.response.iter_lines() - - @override - def iter_raw(self, chunk_size: Optional[int] = None) -> Iterator[bytes]: - return self.response.iter_raw(chunk_size) - - @override - def stream_to_file( - self, - file: str | os.PathLike[str], - *, - chunk_size: int | None = None, - ) -> None: - with open(file, mode="wb") as f: - for data in 
self.response.iter_bytes(chunk_size): - f.write(data) - - @override - def close(self) -> None: - return self.response.close() - - @override - async def aread(self) -> bytes: - return await self.response.aread() - - @override - async def aiter_bytes(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]: - return self.response.aiter_bytes(chunk_size) - - @override - async def aiter_text(self, chunk_size: Optional[int] = None) -> AsyncIterator[str]: - return self.response.aiter_text(chunk_size) - - @override - async def aiter_lines(self) -> AsyncIterator[str]: - return self.response.aiter_lines() - - @override - async def aiter_raw(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]: - return self.response.aiter_raw(chunk_size) - - @override - async def astream_to_file( - self, - file: str | os.PathLike[str], - *, - chunk_size: int | None = None, - ) -> None: - path = anyio.Path(file) - async with await path.open(mode="wb") as f: - async for data in self.response.aiter_bytes(chunk_size): - await f.write(data) - - @override - async def aclose(self) -> None: - return await self.response.aclose() diff --git a/src/anthropic/_client.py b/src/anthropic/_client.py index fea68224..2c4df083 100644 --- a/src/anthropic/_client.py +++ b/src/anthropic/_client.py @@ -59,6 +59,7 @@ class Anthropic(SyncAPIClient): completions: resources.Completions beta: resources.Beta with_raw_response: AnthropicWithRawResponse + with_streaming_response: AnthropicWithStreamedResponse # client options api_key: str | None @@ -134,6 +135,7 @@ def __init__( self.completions = resources.Completions(self) self.beta = resources.Beta(self) self.with_raw_response = AnthropicWithRawResponse(self) + self.with_streaming_response = AnthropicWithStreamedResponse(self) @property @override @@ -312,6 +314,7 @@ class AsyncAnthropic(AsyncAPIClient): completions: resources.AsyncCompletions beta: resources.AsyncBeta with_raw_response: AsyncAnthropicWithRawResponse + with_streaming_response: 
AsyncAnthropicWithStreamedResponse # client options api_key: str | None @@ -387,6 +390,7 @@ def __init__( self.completions = resources.AsyncCompletions(self) self.beta = resources.AsyncBeta(self) self.with_raw_response = AsyncAnthropicWithRawResponse(self) + self.with_streaming_response = AsyncAnthropicWithStreamedResponse(self) @property @override @@ -573,6 +577,18 @@ def __init__(self, client: AsyncAnthropic) -> None: self.beta = resources.AsyncBetaWithRawResponse(client.beta) +class AnthropicWithStreamedResponse: + def __init__(self, client: Anthropic) -> None: + self.completions = resources.CompletionsWithStreamingResponse(client.completions) + self.beta = resources.BetaWithStreamingResponse(client.beta) + + +class AsyncAnthropicWithStreamedResponse: + def __init__(self, client: AsyncAnthropic) -> None: + self.completions = resources.AsyncCompletionsWithStreamingResponse(client.completions) + self.beta = resources.AsyncBetaWithStreamingResponse(client.beta) + + Client = Anthropic AsyncClient = AsyncAnthropic diff --git a/src/anthropic/_constants.py b/src/anthropic/_constants.py index 4e59e854..8f9d40d2 100644 --- a/src/anthropic/_constants.py +++ b/src/anthropic/_constants.py @@ -3,7 +3,7 @@ import httpx RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response" -STREAMED_RAW_RESPONSE_HEADER = "X-Stainless-Streamed-Raw-Response" +OVERRIDE_CAST_TO_HEADER = "____stainless_override_cast_to" # default timeout is 10 minutes DEFAULT_TIMEOUT = httpx.Timeout(timeout=600.0, connect=5.0) diff --git a/src/anthropic/_legacy_response.py b/src/anthropic/_legacy_response.py new file mode 100644 index 00000000..b2831394 --- /dev/null +++ b/src/anthropic/_legacy_response.py @@ -0,0 +1,385 @@ +from __future__ import annotations + +import os +import inspect +import logging +import datetime +import functools +from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, Iterator, AsyncIterator, cast +from typing_extensions import Awaitable, ParamSpec, get_args, override, 
deprecated, get_origin + +import anyio +import httpx + +from ._types import NoneType +from ._utils import is_given +from ._models import BaseModel, is_basemodel +from ._constants import RAW_RESPONSE_HEADER +from ._exceptions import APIResponseValidationError + +if TYPE_CHECKING: + from ._models import FinalRequestOptions + from ._base_client import Stream, BaseClient, AsyncStream + + +P = ParamSpec("P") +R = TypeVar("R") + +log: logging.Logger = logging.getLogger(__name__) + + +class LegacyAPIResponse(Generic[R]): + """This is a legacy class as it will be replaced by `APIResponse` + and `AsyncAPIResponse` in the `_response.py` file in the next major + release. + + For the sync client this will mostly be the same with the exception + of `content` & `text` will be methods instead of properties. In the + async client, all methods will be async. + + A migration script will be provided & the migration in general should + be smooth. + """ + + _cast_to: type[R] + _client: BaseClient[Any, Any] + _parsed: R | None + _stream: bool + _stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None + _options: FinalRequestOptions + + http_response: httpx.Response + + def __init__( + self, + *, + raw: httpx.Response, + cast_to: type[R], + client: BaseClient[Any, Any], + stream: bool, + stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None, + options: FinalRequestOptions, + ) -> None: + self._cast_to = cast_to + self._client = client + self._parsed = None + self._stream = stream + self._stream_cls = stream_cls + self._options = options + self.http_response = raw + + def parse(self) -> R: + """Returns the rich python representation of this response's data. + + For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`. + + NOTE: For the async client: this will become a coroutine in the next major version. 
+ """ + if self._parsed is not None: + return self._parsed + + parsed = self._parse() + if is_given(self._options.post_parser): + parsed = self._options.post_parser(parsed) + + self._parsed = parsed + return parsed + + @property + def headers(self) -> httpx.Headers: + return self.http_response.headers + + @property + def http_request(self) -> httpx.Request: + return self.http_response.request + + @property + def status_code(self) -> int: + return self.http_response.status_code + + @property + def url(self) -> httpx.URL: + return self.http_response.url + + @property + def method(self) -> str: + return self.http_request.method + + @property + def content(self) -> bytes: + """Return the binary response content. + + NOTE: this will be removed in favour of `.read()` in the + next major version. + """ + return self.http_response.content + + @property + def text(self) -> str: + """Return the decoded response content. + + NOTE: this will be turned into a method in the next major version. + """ + return self.http_response.text + + @property + def http_version(self) -> str: + return self.http_response.http_version + + @property + def is_closed(self) -> bool: + return self.http_response.is_closed + + @property + def elapsed(self) -> datetime.timedelta: + """The time taken for the complete request/response cycle to complete.""" + return self.http_response.elapsed + + def _parse(self) -> R: + if self._stream: + if self._stream_cls: + return cast( + R, + self._stream_cls( + cast_to=_extract_stream_chunk_type(self._stream_cls), + response=self.http_response, + client=cast(Any, self._client), + ), + ) + + stream_cls = cast("type[Stream[Any]] | type[AsyncStream[Any]] | None", self._client._default_stream_cls) + if stream_cls is None: + raise MissingStreamClassError() + + return cast( + R, + stream_cls( + cast_to=self._cast_to, + response=self.http_response, + client=cast(Any, self._client), + ), + ) + + cast_to = self._cast_to + if cast_to is NoneType: + return cast(R, None) + + 
response = self.http_response + if cast_to == str: + return cast(R, response.text) + + origin = get_origin(cast_to) or cast_to + + if inspect.isclass(origin) and issubclass(origin, HttpxBinaryResponseContent): + return cast(R, cast_to(response)) # type: ignore + + if origin == LegacyAPIResponse: + raise RuntimeError("Unexpected state - cast_to is `APIResponse`") + + if inspect.isclass(origin) and issubclass(origin, httpx.Response): + # Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response + # and pass that class to our request functions. We cannot change the variance to be either + # covariant or contravariant as that makes our usage of ResponseT illegal. We could construct + # the response class ourselves but that is something that should be supported directly in httpx + # as it would be easy to incorrectly construct the Response object due to the multitude of arguments. + if cast_to != httpx.Response: + raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`") + return cast(R, response) + + # The check here is necessary as we are subverting the the type system + # with casts as the relationship between TypeVars and Types are very strict + # which means we must return *exactly* what was input or transform it in a + # way that retains the TypeVar state. As we cannot do that in this function + # then we have to resort to using `cast`. At the time of writing, we know this + # to be safe as we have handled all the types that could be bound to the + # `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then + # this function would become unsafe but a type checker would not report an error. + if ( + cast_to is not object + and not origin is list + and not origin is dict + and not origin is Union + and not issubclass(origin, BaseModel) + ): + raise RuntimeError( + f"Invalid state, expected {cast_to} to be a subclass type of {BaseModel}, {dict}, {list} or {Union}." 
+ ) + + # split is required to handle cases where additional information is included + # in the response, e.g. application/json; charset=utf-8 + content_type, *_ = response.headers.get("content-type").split(";") + if content_type != "application/json": + if is_basemodel(cast_to): + try: + data = response.json() + except Exception as exc: + log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc) + else: + return self._client._process_response_data( + data=data, + cast_to=cast_to, # type: ignore + response=response, + ) + + if self._client._strict_response_validation: + raise APIResponseValidationError( + response=response, + message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.", + body=response.text, + ) + + # If the API responds with content that isn't JSON then we just return + # the (decoded) text without performing any parsing so that you can still + # handle the response however you need to. + return response.text # type: ignore + + data = response.json() + + return self._client._process_response_data( + data=data, + cast_to=cast_to, # type: ignore + response=response, + ) + + @override + def __repr__(self) -> str: + return f"" + + +class MissingStreamClassError(TypeError): + def __init__(self) -> None: + super().__init__( + "The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `anthropic._streaming` for reference", + ) + + +def _extract_stream_chunk_type(stream_cls: type) -> type: + args = get_args(stream_cls) + if not args: + raise TypeError( + f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}", + ) + return cast(type, args[0]) + + +def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, LegacyAPIResponse[R]]: + """Higher order function that takes one of our bound API methods and wraps it + to support returning the raw `APIResponse` object directly. 
+ """ + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> LegacyAPIResponse[R]: + extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "true" + + kwargs["extra_headers"] = extra_headers + + return cast(LegacyAPIResponse[R], func(*args, **kwargs)) + + return wrapped + + +def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[LegacyAPIResponse[R]]]: + """Higher order function that takes one of our bound API methods and wraps it + to support returning the raw `APIResponse` object directly. + """ + + @functools.wraps(func) + async def wrapped(*args: P.args, **kwargs: P.kwargs) -> LegacyAPIResponse[R]: + extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "true" + + kwargs["extra_headers"] = extra_headers + + return cast(LegacyAPIResponse[R], await func(*args, **kwargs)) + + return wrapped + + +class HttpxBinaryResponseContent: + response: httpx.Response + + def __init__(self, response: httpx.Response) -> None: + self.response = response + + @property + def content(self) -> bytes: + return self.response.content + + @property + def text(self) -> str: + return self.response.text + + @property + def encoding(self) -> str | None: + return self.response.encoding + + @property + def charset_encoding(self) -> str | None: + return self.response.charset_encoding + + def json(self, **kwargs: Any) -> Any: + return self.response.json(**kwargs) + + def read(self) -> bytes: + return self.response.read() + + def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]: + return self.response.iter_bytes(chunk_size) + + def iter_text(self, chunk_size: int | None = None) -> Iterator[str]: + return self.response.iter_text(chunk_size) + + def iter_lines(self) -> Iterator[str]: + return self.response.iter_lines() + + def iter_raw(self, chunk_size: int | None = None) -> Iterator[bytes]: + return 
self.response.iter_raw(chunk_size) + + @deprecated( + "Due to a bug, this method doesn't actually stream the response content, `.with_streaming_response.method()` should be used instead" + ) + def stream_to_file( + self, + file: str | os.PathLike[str], + *, + chunk_size: int | None = None, + ) -> None: + with open(file, mode="wb") as f: + for data in self.response.iter_bytes(chunk_size): + f.write(data) + + def close(self) -> None: + return self.response.close() + + async def aread(self) -> bytes: + return await self.response.aread() + + async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]: + return self.response.aiter_bytes(chunk_size) + + async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]: + return self.response.aiter_text(chunk_size) + + async def aiter_lines(self) -> AsyncIterator[str]: + return self.response.aiter_lines() + + async def aiter_raw(self, chunk_size: int | None = None) -> AsyncIterator[bytes]: + return self.response.aiter_raw(chunk_size) + + @deprecated( + "Due to a bug, this method doesn't actually stream the response content, `.with_streaming_response.method()` should be used instead" + ) + async def astream_to_file( + self, + file: str | os.PathLike[str], + *, + chunk_size: int | None = None, + ) -> None: + path = anyio.Path(file) + async with await path.open(mode="wb") as f: + async for data in self.response.aiter_bytes(chunk_size): + await f.write(data) + + async def aclose(self) -> None: + return await self.response.aclose() diff --git a/src/anthropic/_response.py b/src/anthropic/_response.py index bfe3895b..6c3a2825 100644 --- a/src/anthropic/_response.py +++ b/src/anthropic/_response.py @@ -1,19 +1,32 @@ from __future__ import annotations +import os import inspect import logging import datetime import functools -from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Union, + 
Generic, + TypeVar, + Callable, + Iterator, + AsyncIterator, + cast, +) from typing_extensions import Awaitable, ParamSpec, override, get_origin +import anyio import httpx -from ._types import NoneType, BinaryResponseContent +from ._types import NoneType from ._utils import is_given, extract_type_var_from_base from ._models import BaseModel, is_basemodel -from ._constants import RAW_RESPONSE_HEADER -from ._exceptions import APIResponseValidationError +from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER +from ._exceptions import AnthropicError, APIResponseValidationError if TYPE_CHECKING: from ._models import FinalRequestOptions @@ -22,15 +35,17 @@ P = ParamSpec("P") R = TypeVar("R") +_APIResponseT = TypeVar("_APIResponseT", bound="APIResponse[Any]") +_AsyncAPIResponseT = TypeVar("_AsyncAPIResponseT", bound="AsyncAPIResponse[Any]") log: logging.Logger = logging.getLogger(__name__) -class APIResponse(Generic[R]): +class BaseAPIResponse(Generic[R]): _cast_to: type[R] _client: BaseClient[Any, Any] _parsed: R | None - _stream: bool + _is_sse_stream: bool _stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None _options: FinalRequestOptions @@ -49,28 +64,18 @@ def __init__( self._cast_to = cast_to self._client = client self._parsed = None - self._stream = stream + self._is_sse_stream = stream self._stream_cls = stream_cls self._options = options self.http_response = raw - def parse(self) -> R: - if self._parsed is not None: - return self._parsed - - parsed = self._parse() - if is_given(self._options.post_parser): - parsed = self._options.post_parser(parsed) - - self._parsed = parsed - return parsed - @property def headers(self) -> httpx.Headers: return self.http_response.headers @property def http_request(self) -> httpx.Request: + """Returns the httpx Request instance associated with the current response.""" return self.http_response.request @property @@ -79,20 +84,13 @@ def status_code(self) -> int: @property def url(self) -> httpx.URL: + 
"""Returns the URL for which the request was made.""" return self.http_response.url @property def method(self) -> str: return self.http_request.method - @property - def content(self) -> bytes: - return self.http_response.content - - @property - def text(self) -> str: - return self.http_response.text - @property def http_version(self) -> str: return self.http_response.http_version @@ -102,13 +100,29 @@ def elapsed(self) -> datetime.timedelta: """The time taken for the complete request/response cycle to complete.""" return self.http_response.elapsed + @property + def is_closed(self) -> bool: + """Whether or not the response body has been closed. + + If this is False then there is response data that has not been read yet. + You must either fully consume the response body or call `.close()` + before discarding the response to prevent resource leaks. + """ + return self.http_response.is_closed + + @override + def __repr__(self) -> str: + return ( + f"<{self.__class__.__name__} [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_to}>" + ) + def _parse(self) -> R: - if self._stream: + if self._is_sse_stream: if self._stream_cls: return cast( R, self._stream_cls( - cast_to=_extract_stream_chunk_type(self._stream_cls), + cast_to=extract_stream_chunk_type(self._stream_cls), response=self.http_response, client=cast(Any, self._client), ), @@ -135,9 +149,13 @@ def _parse(self) -> R: if cast_to == str: return cast(R, response.text) + if cast_to == bytes: + return cast(R, response.content) + origin = get_origin(cast_to) or cast_to - if inspect.isclass(origin) and issubclass(origin, BinaryResponseContent): + # handle the legacy binary response case + if inspect.isclass(cast_to) and cast_to.__name__ == "HttpxBinaryResponseContent": return cast(R, cast_to(response)) # type: ignore if origin == APIResponse: @@ -208,9 +226,227 @@ def _parse(self) -> R: response=response, ) - @override - def __repr__(self) -> str: - return f"" + +class 
APIResponse(BaseAPIResponse[R]): + def parse(self) -> R: + """Returns the rich python representation of this response's data. + + For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`. + """ + if self._parsed is not None: + return self._parsed + + if not self._is_sse_stream: + self.read() + + parsed = self._parse() + if is_given(self._options.post_parser): + parsed = self._options.post_parser(parsed) + + self._parsed = parsed + return parsed + + def read(self) -> bytes: + """Read and return the binary response content.""" + try: + return self.http_response.read() + except httpx.StreamConsumed as exc: + # The default error raised by httpx isn't very + # helpful in our case so we re-raise it with + # a different error message. + raise StreamAlreadyConsumed() from exc + + def text(self) -> str: + """Read and decode the response content into a string.""" + self.read() + return self.http_response.text + + def json(self) -> object: + """Read and decode the JSON response content.""" + self.read() + return self.http_response.json() + + def close(self) -> None: + """Close the response and release the connection. + + Automatically called if the response body is read to completion. + """ + self.http_response.close() + + def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]: + """ + A byte-iterator over the decoded response content. + + This automatically handles gzip, deflate and brotli encoded responses. + """ + for chunk in self.http_response.iter_bytes(chunk_size): + yield chunk + + def iter_text(self, chunk_size: int | None = None) -> Iterator[str]: + """A str-iterator over the decoded response content + that handles both gzip, deflate, etc but also detects the content's + string encoding. 
+ """ + for chunk in self.http_response.iter_text(chunk_size): + yield chunk + + def iter_lines(self) -> Iterator[str]: + """Like `iter_text()` but will only yield chunks for each line""" + for chunk in self.http_response.iter_lines(): + yield chunk + + +class AsyncAPIResponse(BaseAPIResponse[R]): + async def parse(self) -> R: + """Returns the rich python representation of this response's data. + + For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`. + """ + if self._parsed is not None: + return self._parsed + + if not self._is_sse_stream: + await self.read() + + parsed = self._parse() + if is_given(self._options.post_parser): + parsed = self._options.post_parser(parsed) + + self._parsed = parsed + return parsed + + async def read(self) -> bytes: + """Read and return the binary response content.""" + try: + return await self.http_response.aread() + except httpx.StreamConsumed as exc: + # the default error raised by httpx isn't very + # helpful in our case so we re-raise it with + # a different error message + raise StreamAlreadyConsumed() from exc + + async def text(self) -> str: + """Read and decode the response content into a string.""" + await self.read() + return self.http_response.text + + async def json(self) -> object: + """Read and decode the JSON response content.""" + await self.read() + return self.http_response.json() + + async def close(self) -> None: + """Close the response and release the connection. + + Automatically called if the response body is read to completion. + """ + await self.http_response.aclose() + + async def iter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]: + """ + A byte-iterator over the decoded response content. + + This automatically handles gzip, deflate and brotli encoded responses. 
+ """ + async for chunk in self.http_response.aiter_bytes(chunk_size): + yield chunk + + async def iter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]: + """A str-iterator over the decoded response content + that handles both gzip, deflate, etc but also detects the content's + string encoding. + """ + async for chunk in self.http_response.aiter_text(chunk_size): + yield chunk + + async def iter_lines(self) -> AsyncIterator[str]: + """Like `iter_text()` but will only yield chunks for each line""" + async for chunk in self.http_response.aiter_lines(): + yield chunk + + +class BinaryAPIResponse(APIResponse[bytes]): + """Subclass of APIResponse providing helpers for dealing with binary data. + + Note: If you want to stream the response data instead of eagerly reading it + all at once then you should use `.with_streaming_response` when making + the API request, e.g. `.with_streaming_response.get_binary_response()` + """ + + def write_to_file( + self, + file: str | os.PathLike[str], + ) -> None: + """Write the output to the given file. + + Accepts a filename or any path-like object, e.g. pathlib.Path + + Note: if you want to stream the data to the file instead of writing + all at once then you should use `.with_streaming_response` when making + the API request, e.g. `.with_streaming_response.get_binary_response()` + """ + with open(file, mode="wb") as f: + for data in self.iter_bytes(): + f.write(data) + + +class AsyncBinaryAPIResponse(AsyncAPIResponse[bytes]): + """Subclass of APIResponse providing helpers for dealing with binary data. + + Note: If you want to stream the response data instead of eagerly reading it + all at once then you should use `.with_streaming_response` when making + the API request, e.g. `.with_streaming_response.get_binary_response()` + """ + + async def write_to_file( + self, + file: str | os.PathLike[str], + ) -> None: + """Write the output to the given file. + + Accepts a filename or any path-like object, e.g. 
pathlib.Path
+
+        Note: if you want to stream the data to the file instead of writing
+        all at once then you should use `.with_streaming_response` when making
+        the API request, e.g. `.with_streaming_response.get_binary_response()`
+        """
+        path = anyio.Path(file)
+        async with await path.open(mode="wb") as f:
+            async for data in self.iter_bytes():
+                await f.write(data)
+
+
+class StreamedBinaryAPIResponse(APIResponse[bytes]):
+    def stream_to_file(
+        self,
+        file: str | os.PathLike[str],
+        *,
+        chunk_size: int | None = None,
+    ) -> None:
+        """Streams the output to the given file.
+
+        Accepts a filename or any path-like object, e.g. pathlib.Path
+        """
+        with open(file, mode="wb") as f:
+            for data in self.iter_bytes(chunk_size):
+                f.write(data)
+
+
+class AsyncStreamedBinaryAPIResponse(AsyncAPIResponse[bytes]):
+    async def stream_to_file(
+        self,
+        file: str | os.PathLike[str],
+        *,
+        chunk_size: int | None = None,
+    ) -> None:
+        """Streams the output to the given file.
+
+        Accepts a filename or any path-like object, e.g. pathlib.Path
+        """
+        path = anyio.Path(file)
+        async with await path.open(mode="wb") as f:
+            async for data in self.iter_bytes(chunk_size):
+                await f.write(data)
 
 
 class MissingStreamClassError(TypeError):
@@ -220,14 +456,176 @@ def __init__(self) -> None:
         )
 
 
-def _extract_stream_chunk_type(stream_cls: type) -> type:
-    from ._base_client import Stream, AsyncStream
+class StreamAlreadyConsumed(AnthropicError):
+    """
+    Attempted to read or stream content, but the content has already
+    been streamed.
 
-    return extract_type_var_from_base(
-        stream_cls,
-        index=0,
-        generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)),
-    )
+
+    This can happen if you use a method like `.iter_lines()` and then attempt
+    to read the entire response body afterwards, e.g.
+
+    ```py
+    response = await client.post(...)
+    async for line in response.iter_lines():
+        ...
# do something with `line` + + content = await response.read() + # ^ error + ``` + + If you want this behaviour you'll need to either manually accumulate the response + content or call `await response.read()` before iterating over the stream. + """ + + def __init__(self) -> None: + message = ( + "Attempted to read or stream some content, but the content has " + "already been streamed. " + "This could be due to attempting to stream the response " + "content more than once." + "\n\n" + "You can fix this by manually accumulating the response content while streaming " + "or by calling `.read()` before starting to stream." + ) + super().__init__(message) + + +class ResponseContextManager(Generic[_APIResponseT]): + """Context manager for ensuring that a request is not made + until it is entered and that the response will always be closed + when the context manager exits + """ + + def __init__(self, request_func: Callable[[], _APIResponseT]) -> None: + self._request_func = request_func + self.__response: _APIResponseT | None = None + + def __enter__(self) -> _APIResponseT: + self.__response = self._request_func() + return self.__response + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + if self.__response is not None: + self.__response.close() + + +class AsyncResponseContextManager(Generic[_AsyncAPIResponseT]): + """Context manager for ensuring that a request is not made + until it is entered and that the response will always be closed + when the context manager exits + """ + + def __init__(self, api_request: Awaitable[_AsyncAPIResponseT]) -> None: + self._api_request = api_request + self.__response: _AsyncAPIResponseT | None = None + + async def __aenter__(self) -> _AsyncAPIResponseT: + self.__response = await self._api_request + return self.__response + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | 
None, + ) -> None: + if self.__response is not None: + await self.__response.close() + + +def to_streamed_response_wrapper(func: Callable[P, R]) -> Callable[P, ResponseContextManager[APIResponse[R]]]: + """Higher order function that takes one of our bound API methods and wraps it + to support streaming and returning the raw `APIResponse` object directly. + """ + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[APIResponse[R]]: + extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "stream" + + kwargs["extra_headers"] = extra_headers + + make_request = functools.partial(func, *args, **kwargs) + + return ResponseContextManager(cast(Callable[[], APIResponse[R]], make_request)) + + return wrapped + + +def async_to_streamed_response_wrapper( + func: Callable[P, Awaitable[R]], +) -> Callable[P, AsyncResponseContextManager[AsyncAPIResponse[R]]]: + """Higher order function that takes one of our bound API methods and wraps it + to support streaming and returning the raw `APIResponse` object directly. + """ + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[AsyncAPIResponse[R]]: + extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "stream" + + kwargs["extra_headers"] = extra_headers + + make_request = func(*args, **kwargs) + + return AsyncResponseContextManager(cast(Awaitable[AsyncAPIResponse[R]], make_request)) + + return wrapped + + +def to_custom_streamed_response_wrapper( + func: Callable[P, object], + response_cls: type[_APIResponseT], +) -> Callable[P, ResponseContextManager[_APIResponseT]]: + """Higher order function that takes one of our bound API methods and an `APIResponse` class + and wraps the method to support streaming and returning the given response class directly. + + Note: the given `response_cls` *must* be concrete, e.g. 
`class BinaryAPIResponse(APIResponse[bytes])` + """ + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[_APIResponseT]: + extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "stream" + extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls + + kwargs["extra_headers"] = extra_headers + + make_request = functools.partial(func, *args, **kwargs) + + return ResponseContextManager(cast(Callable[[], _APIResponseT], make_request)) + + return wrapped + + +def async_to_custom_streamed_response_wrapper( + func: Callable[P, Awaitable[object]], + response_cls: type[_AsyncAPIResponseT], +) -> Callable[P, AsyncResponseContextManager[_AsyncAPIResponseT]]: + """Higher order function that takes one of our bound API methods and an `APIResponse` class + and wraps the method to support streaming and returning the given response class directly. + + Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])` + """ + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[_AsyncAPIResponseT]: + extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "stream" + extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls + + kwargs["extra_headers"] = extra_headers + + make_request = func(*args, **kwargs) + + return AsyncResponseContextManager(cast(Awaitable[_AsyncAPIResponseT], make_request)) + + return wrapped def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]]: @@ -238,7 +636,7 @@ def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]] @functools.wraps(func) def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]: extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} - extra_headers[RAW_RESPONSE_HEADER] = "true" + extra_headers[RAW_RESPONSE_HEADER] = "raw" kwargs["extra_headers"] 
= extra_headers @@ -247,18 +645,102 @@ def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]: return wrapped -def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[APIResponse[R]]]: +def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[AsyncAPIResponse[R]]]: """Higher order function that takes one of our bound API methods and wraps it to support returning the raw `APIResponse` object directly. """ @functools.wraps(func) - async def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]: + async def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncAPIResponse[R]: extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} - extra_headers[RAW_RESPONSE_HEADER] = "true" + extra_headers[RAW_RESPONSE_HEADER] = "raw" kwargs["extra_headers"] = extra_headers - return cast(APIResponse[R], await func(*args, **kwargs)) + return cast(AsyncAPIResponse[R], await func(*args, **kwargs)) return wrapped + + +def to_custom_raw_response_wrapper( + func: Callable[P, object], + response_cls: type[_APIResponseT], +) -> Callable[P, _APIResponseT]: + """Higher order function that takes one of our bound API methods and an `APIResponse` class + and wraps the method to support returning the given response class directly. + + Note: the given `response_cls` *must* be concrete, e.g. 
`class BinaryAPIResponse(APIResponse[bytes])` + """ + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> _APIResponseT: + extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "raw" + extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls + + kwargs["extra_headers"] = extra_headers + + return cast(_APIResponseT, func(*args, **kwargs)) + + return wrapped + + +def async_to_custom_raw_response_wrapper( + func: Callable[P, Awaitable[object]], + response_cls: type[_AsyncAPIResponseT], +) -> Callable[P, Awaitable[_AsyncAPIResponseT]]: + """Higher order function that takes one of our bound API methods and an `APIResponse` class + and wraps the method to support returning the given response class directly. + + Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])` + """ + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> Awaitable[_AsyncAPIResponseT]: + extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "raw" + extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls + + kwargs["extra_headers"] = extra_headers + + return cast(Awaitable[_AsyncAPIResponseT], func(*args, **kwargs)) + + return wrapped + + +def extract_stream_chunk_type(stream_cls: type) -> type: + """Given a type like `Stream[T]`, returns the generic type variable `T`. + + This also handles the case where a concrete subclass is given, e.g. + ```py + class MyStream(Stream[bytes]): + ... + + extract_stream_chunk_type(MyStream) -> bytes + ``` + """ + from ._base_client import Stream, AsyncStream + + return extract_type_var_from_base( + stream_cls, + index=0, + generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)), + ) + + +def extract_response_type(typ: type[BaseAPIResponse[Any]]) -> type: + """Given a type like `APIResponse[T]`, returns the generic type variable `T`. 
+ + This also handles the case where a concrete subclass is given, e.g. + ```py + class MyResponse(APIResponse[bytes]): + ... + + extract_response_type(MyResponse) -> bytes + ``` + """ + return extract_type_var_from_base( + typ, + generic_bases=cast("tuple[type, ...]", (BaseAPIResponse, APIResponse, AsyncAPIResponse)), + index=0, + ) diff --git a/src/anthropic/_types.py b/src/anthropic/_types.py index 35cc62e1..a67d613b 100644 --- a/src/anthropic/_types.py +++ b/src/anthropic/_types.py @@ -1,7 +1,6 @@ from __future__ import annotations from os import PathLike -from abc import ABC, abstractmethod from typing import ( IO, TYPE_CHECKING, @@ -14,10 +13,8 @@ Mapping, TypeVar, Callable, - Iterator, Optional, Sequence, - AsyncIterator, ) from typing_extensions import Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable @@ -27,6 +24,8 @@ if TYPE_CHECKING: from ._models import BaseModel + from ._response import APIResponse, AsyncAPIResponse + from ._legacy_response import HttpxBinaryResponseContent Transport = BaseTransport AsyncTransport = AsyncBaseTransport @@ -37,162 +36,6 @@ _T = TypeVar("_T") -class BinaryResponseContent(ABC): - @abstractmethod - def __init__( - self, - response: Any, - ) -> None: - ... - - @property - @abstractmethod - def content(self) -> bytes: - pass - - @property - @abstractmethod - def text(self) -> str: - pass - - @property - @abstractmethod - def encoding(self) -> Optional[str]: - """ - Return an encoding to use for decoding the byte content into text. - The priority for determining this is given by... - - * `.encoding = <>` has been set explicitly. - * The encoding as specified by the charset parameter in the Content-Type header. - * The encoding as determined by `default_encoding`, which may either be - a string like "utf-8" indicating the encoding to use, or may be a callable - which enables charset autodetection. 
- """ - pass - - @property - @abstractmethod - def charset_encoding(self) -> Optional[str]: - """ - Return the encoding, as specified by the Content-Type header. - """ - pass - - @abstractmethod - def json(self, **kwargs: Any) -> Any: - pass - - @abstractmethod - def read(self) -> bytes: - """ - Read and return the response content. - """ - pass - - @abstractmethod - def iter_bytes(self, chunk_size: Optional[int] = None) -> Iterator[bytes]: - """ - A byte-iterator over the decoded response content. - This allows us to handle gzip, deflate, and brotli encoded responses. - """ - pass - - @abstractmethod - def iter_text(self, chunk_size: Optional[int] = None) -> Iterator[str]: - """ - A str-iterator over the decoded response content - that handles both gzip, deflate, etc but also detects the content's - string encoding. - """ - pass - - @abstractmethod - def iter_lines(self) -> Iterator[str]: - pass - - @abstractmethod - def iter_raw(self, chunk_size: Optional[int] = None) -> Iterator[bytes]: - """ - A byte-iterator over the raw response content. - """ - pass - - @abstractmethod - def stream_to_file( - self, - file: str | PathLike[str], - *, - chunk_size: int | None = None, - ) -> None: - """ - Stream the output to the given file. - """ - pass - - @abstractmethod - def close(self) -> None: - """ - Close the response and release the connection. - Automatically called if the response body is read to completion. - """ - pass - - @abstractmethod - async def aread(self) -> bytes: - """ - Read and return the response content. - """ - pass - - @abstractmethod - async def aiter_bytes(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]: - """ - A byte-iterator over the decoded response content. - This allows us to handle gzip, deflate, and brotli encoded responses. 
- """ - pass - - @abstractmethod - async def aiter_text(self, chunk_size: Optional[int] = None) -> AsyncIterator[str]: - """ - A str-iterator over the decoded response content - that handles both gzip, deflate, etc but also detects the content's - string encoding. - """ - pass - - @abstractmethod - async def aiter_lines(self) -> AsyncIterator[str]: - pass - - @abstractmethod - async def aiter_raw(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]: - """ - A byte-iterator over the raw response content. - """ - pass - - @abstractmethod - async def astream_to_file( - self, - file: str | PathLike[str], - *, - chunk_size: int | None = None, - ) -> None: - """ - Stream the output to the given file. - """ - pass - - @abstractmethod - async def aclose(self) -> None: - """ - Close the response and release the connection. - Automatically called if the response body is read to completion. - """ - pass - - # Approximates httpx internal ProxiesTypes and RequestFiles types # while adding support for `PathLike` instances ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]] @@ -343,7 +186,9 @@ def get(self, __key: str) -> str | None: Dict[str, Any], Response, ModelBuilderProtocol, - BinaryResponseContent, + "APIResponse[Any]", + "AsyncAPIResponse[Any]", + "HttpxBinaryResponseContent", ], ) @@ -359,6 +204,7 @@ def get(self, __key: str) -> str | None: @runtime_checkable class InheritsGeneric(Protocol): """Represents a type that has inherited from `Generic` + The `__orig_bases__` property can be used to determine the resolved type variable for a given base class. """ diff --git a/src/anthropic/resources/__init__.py b/src/anthropic/resources/__init__.py index d39ac2ed..1b085e9f 100644 --- a/src/anthropic/resources/__init__.py +++ b/src/anthropic/resources/__init__.py @@ -1,15 +1,33 @@ # File generated from our OpenAPI spec by Stainless. 
-from .beta import Beta, AsyncBeta, BetaWithRawResponse, AsyncBetaWithRawResponse -from .completions import Completions, AsyncCompletions, CompletionsWithRawResponse, AsyncCompletionsWithRawResponse +from .beta import ( + Beta, + AsyncBeta, + BetaWithRawResponse, + AsyncBetaWithRawResponse, + BetaWithStreamingResponse, + AsyncBetaWithStreamingResponse, +) +from .completions import ( + Completions, + AsyncCompletions, + CompletionsWithRawResponse, + AsyncCompletionsWithRawResponse, + CompletionsWithStreamingResponse, + AsyncCompletionsWithStreamingResponse, +) __all__ = [ "Completions", "AsyncCompletions", "CompletionsWithRawResponse", "AsyncCompletionsWithRawResponse", + "CompletionsWithStreamingResponse", + "AsyncCompletionsWithStreamingResponse", "Beta", "AsyncBeta", "BetaWithRawResponse", "AsyncBetaWithRawResponse", + "BetaWithStreamingResponse", + "AsyncBetaWithStreamingResponse", ] diff --git a/src/anthropic/resources/beta/__init__.py b/src/anthropic/resources/beta/__init__.py index ef10de5d..663d4b95 100644 --- a/src/anthropic/resources/beta/__init__.py +++ b/src/anthropic/resources/beta/__init__.py @@ -1,15 +1,33 @@ # File generated from our OpenAPI spec by Stainless. 
-from .beta import Beta, AsyncBeta, BetaWithRawResponse, AsyncBetaWithRawResponse -from .messages import Messages, AsyncMessages, MessagesWithRawResponse, AsyncMessagesWithRawResponse +from .beta import ( + Beta, + AsyncBeta, + BetaWithRawResponse, + AsyncBetaWithRawResponse, + BetaWithStreamingResponse, + AsyncBetaWithStreamingResponse, +) +from .messages import ( + Messages, + AsyncMessages, + MessagesWithRawResponse, + AsyncMessagesWithRawResponse, + MessagesWithStreamingResponse, + AsyncMessagesWithStreamingResponse, +) __all__ = [ "Messages", "AsyncMessages", "MessagesWithRawResponse", "AsyncMessagesWithRawResponse", + "MessagesWithStreamingResponse", + "AsyncMessagesWithStreamingResponse", "Beta", "AsyncBeta", "BetaWithRawResponse", "AsyncBetaWithRawResponse", + "BetaWithStreamingResponse", + "AsyncBetaWithStreamingResponse", ] diff --git a/src/anthropic/resources/beta/beta.py b/src/anthropic/resources/beta/beta.py index ab63ff8b..f4201fd3 100644 --- a/src/anthropic/resources/beta/beta.py +++ b/src/anthropic/resources/beta/beta.py @@ -2,7 +2,14 @@ from __future__ import annotations -from .messages import Messages, AsyncMessages, MessagesWithRawResponse, AsyncMessagesWithRawResponse +from .messages import ( + Messages, + AsyncMessages, + MessagesWithRawResponse, + AsyncMessagesWithRawResponse, + MessagesWithStreamingResponse, + AsyncMessagesWithStreamingResponse, +) from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource @@ -18,6 +25,10 @@ def messages(self) -> Messages: def with_raw_response(self) -> BetaWithRawResponse: return BetaWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> BetaWithStreamingResponse: + return BetaWithStreamingResponse(self) + class AsyncBeta(AsyncAPIResource): @cached_property @@ -28,6 +39,10 @@ def messages(self) -> AsyncMessages: def with_raw_response(self) -> AsyncBetaWithRawResponse: return AsyncBetaWithRawResponse(self) + @cached_property + def 
with_streaming_response(self) -> AsyncBetaWithStreamingResponse: + return AsyncBetaWithStreamingResponse(self) + class BetaWithRawResponse: def __init__(self, beta: Beta) -> None: @@ -37,3 +52,13 @@ def __init__(self, beta: Beta) -> None: class AsyncBetaWithRawResponse: def __init__(self, beta: AsyncBeta) -> None: self.messages = AsyncMessagesWithRawResponse(beta.messages) + + +class BetaWithStreamingResponse: + def __init__(self, beta: Beta) -> None: + self.messages = MessagesWithStreamingResponse(beta.messages) + + +class AsyncBetaWithStreamingResponse: + def __init__(self, beta: AsyncBeta) -> None: + self.messages = AsyncMessagesWithStreamingResponse(beta.messages) diff --git a/src/anthropic/resources/beta/messages.py b/src/anthropic/resources/beta/messages.py index d6408f44..d7ed0290 100644 --- a/src/anthropic/resources/beta/messages.py +++ b/src/anthropic/resources/beta/messages.py @@ -8,11 +8,12 @@ import httpx +from ... import _legacy_response from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven from ..._utils import required_args, maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource -from ..._response import to_raw_response_wrapper, async_to_raw_response_wrapper +from ..._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper from ..._streaming import Stream, AsyncStream from ...types.beta import Message, MessageParam, MessageStreamEvent, message_create_params from ..._base_client import ( @@ -35,6 +36,10 @@ class Messages(SyncAPIResource): def with_raw_response(self) -> MessagesWithRawResponse: return MessagesWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> MessagesWithStreamingResponse: + return MessagesWithStreamingResponse(self) + @overload def create( self, @@ -660,6 +665,10 @@ class AsyncMessages(AsyncAPIResource): def with_raw_response(self) -> AsyncMessagesWithRawResponse: return AsyncMessagesWithRawResponse(self) + 
@cached_property + def with_streaming_response(self) -> AsyncMessagesWithStreamingResponse: + return AsyncMessagesWithStreamingResponse(self) + @overload async def create( self, @@ -1281,13 +1290,27 @@ def stream( class MessagesWithRawResponse: def __init__(self, messages: Messages) -> None: - self.create = to_raw_response_wrapper( + self.create = _legacy_response.to_raw_response_wrapper( messages.create, ) class AsyncMessagesWithRawResponse: def __init__(self, messages: AsyncMessages) -> None: - self.create = async_to_raw_response_wrapper( + self.create = _legacy_response.async_to_raw_response_wrapper( + messages.create, + ) + + +class MessagesWithStreamingResponse: + def __init__(self, messages: Messages) -> None: + self.create = to_streamed_response_wrapper( + messages.create, + ) + + +class AsyncMessagesWithStreamingResponse: + def __init__(self, messages: AsyncMessages) -> None: + self.create = async_to_streamed_response_wrapper( messages.create, ) diff --git a/src/anthropic/resources/completions.py b/src/anthropic/resources/completions.py index d14052c6..2dcf146b 100644 --- a/src/anthropic/resources/completions.py +++ b/src/anthropic/resources/completions.py @@ -7,12 +7,13 @@ import httpx +from .. 
import _legacy_response from ..types import Completion, completion_create_params from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven from .._utils import required_args, maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource -from .._response import to_raw_response_wrapper, async_to_raw_response_wrapper +from .._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper from .._streaming import Stream, AsyncStream from .._base_client import ( make_request_options, @@ -26,6 +27,10 @@ class Completions(SyncAPIResource): def with_raw_response(self) -> CompletionsWithRawResponse: return CompletionsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> CompletionsWithStreamingResponse: + return CompletionsWithStreamingResponse(self) + @overload def create( self, @@ -361,6 +366,10 @@ class AsyncCompletions(AsyncAPIResource): def with_raw_response(self) -> AsyncCompletionsWithRawResponse: return AsyncCompletionsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> AsyncCompletionsWithStreamingResponse: + return AsyncCompletionsWithStreamingResponse(self) + @overload async def create( self, @@ -693,13 +702,27 @@ async def create( class CompletionsWithRawResponse: def __init__(self, completions: Completions) -> None: - self.create = to_raw_response_wrapper( + self.create = _legacy_response.to_raw_response_wrapper( completions.create, ) class AsyncCompletionsWithRawResponse: def __init__(self, completions: AsyncCompletions) -> None: - self.create = async_to_raw_response_wrapper( + self.create = _legacy_response.async_to_raw_response_wrapper( + completions.create, + ) + + +class CompletionsWithStreamingResponse: + def __init__(self, completions: Completions) -> None: + self.create = to_streamed_response_wrapper( + completions.create, + ) + + +class AsyncCompletionsWithStreamingResponse: + def __init__(self, completions: AsyncCompletions) 
-> None: + self.create = async_to_streamed_response_wrapper( completions.create, ) diff --git a/tests/api_resources/beta/test_messages.py b/tests/api_resources/beta/test_messages.py index a8adf2e3..53184f60 100644 --- a/tests/api_resources/beta/test_messages.py +++ b/tests/api_resources/beta/test_messages.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from typing import Any, cast import pytest @@ -67,13 +68,35 @@ def test_raw_response_create_overload_1(self, client: Anthropic) -> None: ], model="claude-2.1", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" message = response.parse() assert_matches_type(Message, message, path=["response"]) + @parametrize + def test_streaming_response_create_overload_1(self, client: Anthropic) -> None: + with client.beta.messages.with_streaming_response.create( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "In one sentence, what is good about the color blue?", + } + ], + model="claude-2.1", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + message = response.parse() + assert_matches_type(Message, message, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize def test_method_create_overload_2(self, client: Anthropic) -> None: - client.beta.messages.create( + message_stream = client.beta.messages.create( max_tokens=1024, messages=[ { @@ -84,10 +107,11 @@ def test_method_create_overload_2(self, client: Anthropic) -> None: model="claude-2.1", stream=True, ) + message_stream.response.close() @parametrize def test_method_create_with_all_params_overload_2(self, client: Anthropic) -> None: - client.beta.messages.create( + message_stream = client.beta.messages.create( max_tokens=1024, messages=[ { @@ -104,6 +128,7 @@ def test_method_create_with_all_params_overload_2(self, client: Anthropic) -> No top_k=5, top_p=0.7, ) + 
message_stream.response.close() @parametrize def test_raw_response_create_overload_2(self, client: Anthropic) -> None: @@ -118,8 +143,31 @@ def test_raw_response_create_overload_2(self, client: Anthropic) -> None: model="claude-2.1", stream=True, ) + assert response.http_request.headers.get("X-Stainless-Lang") == "python" - response.parse() + stream = response.parse() + stream.close() + + @parametrize + def test_streaming_response_create_overload_2(self, client: Anthropic) -> None: + with client.beta.messages.with_streaming_response.create( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "In one sentence, what is good about the color blue?", + } + ], + model="claude-2.1", + stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = response.parse() + stream.close() + + assert cast(Any, response.is_closed) is True class TestAsyncMessages: @@ -174,13 +222,35 @@ async def test_raw_response_create_overload_1(self, client: AsyncAnthropic) -> N ], model="claude-2.1", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" message = response.parse() assert_matches_type(Message, message, path=["response"]) + @parametrize + async def test_streaming_response_create_overload_1(self, client: AsyncAnthropic) -> None: + async with client.beta.messages.with_streaming_response.create( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "In one sentence, what is good about the color blue?", + } + ], + model="claude-2.1", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + message = await response.parse() + assert_matches_type(Message, message, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize async def test_method_create_overload_2(self, client: AsyncAnthropic) -> None: - await 
client.beta.messages.create( + message_stream = await client.beta.messages.create( max_tokens=1024, messages=[ { @@ -191,10 +261,11 @@ async def test_method_create_overload_2(self, client: AsyncAnthropic) -> None: model="claude-2.1", stream=True, ) + await message_stream.response.aclose() @parametrize async def test_method_create_with_all_params_overload_2(self, client: AsyncAnthropic) -> None: - await client.beta.messages.create( + message_stream = await client.beta.messages.create( max_tokens=1024, messages=[ { @@ -211,6 +282,7 @@ async def test_method_create_with_all_params_overload_2(self, client: AsyncAnthr top_k=5, top_p=0.7, ) + await message_stream.response.aclose() @parametrize async def test_raw_response_create_overload_2(self, client: AsyncAnthropic) -> None: @@ -225,5 +297,28 @@ async def test_raw_response_create_overload_2(self, client: AsyncAnthropic) -> N model="claude-2.1", stream=True, ) + assert response.http_request.headers.get("X-Stainless-Lang") == "python" - response.parse() + stream = response.parse() + await stream.close() + + @parametrize + async def test_streaming_response_create_overload_2(self, client: AsyncAnthropic) -> None: + async with client.beta.messages.with_streaming_response.create( + max_tokens=1024, + messages=[ + { + "role": "user", + "content": "In one sentence, what is good about the color blue?", + } + ], + model="claude-2.1", + stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = await response.parse() + await stream.close() + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_completions.py b/tests/api_resources/test_completions.py index f96563c9..6525627f 100644 --- a/tests/api_resources/test_completions.py +++ b/tests/api_resources/test_completions.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from typing import Any, cast import pytest @@ -51,22 +52,40 @@ def 
test_raw_response_create_overload_1(self, client: Anthropic) -> None: model="claude-2.1", prompt="\n\nHuman: Hello, world!\n\nAssistant:", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" completion = response.parse() assert_matches_type(Completion, completion, path=["response"]) + @parametrize + def test_streaming_response_create_overload_1(self, client: Anthropic) -> None: + with client.completions.with_streaming_response.create( + max_tokens_to_sample=256, + model="claude-2.1", + prompt="\n\nHuman: Hello, world!\n\nAssistant:", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + completion = response.parse() + assert_matches_type(Completion, completion, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize def test_method_create_overload_2(self, client: Anthropic) -> None: - client.completions.create( + completion_stream = client.completions.create( max_tokens_to_sample=256, model="claude-2.1", prompt="\n\nHuman: Hello, world!\n\nAssistant:", stream=True, ) + completion_stream.response.close() @parametrize def test_method_create_with_all_params_overload_2(self, client: Anthropic) -> None: - client.completions.create( + completion_stream = client.completions.create( max_tokens_to_sample=256, model="claude-2.1", prompt="\n\nHuman: Hello, world!\n\nAssistant:", @@ -77,6 +96,7 @@ def test_method_create_with_all_params_overload_2(self, client: Anthropic) -> No top_k=5, top_p=0.7, ) + completion_stream.response.close() @parametrize def test_raw_response_create_overload_2(self, client: Anthropic) -> None: @@ -86,8 +106,26 @@ def test_raw_response_create_overload_2(self, client: Anthropic) -> None: prompt="\n\nHuman: Hello, world!\n\nAssistant:", stream=True, ) + assert response.http_request.headers.get("X-Stainless-Lang") == "python" - response.parse() + stream = response.parse() + stream.close() + + 
@parametrize + def test_streaming_response_create_overload_2(self, client: Anthropic) -> None: + with client.completions.with_streaming_response.create( + max_tokens_to_sample=256, + model="claude-2.1", + prompt="\n\nHuman: Hello, world!\n\nAssistant:", + stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = response.parse() + stream.close() + + assert cast(Any, response.is_closed) is True class TestAsyncCompletions: @@ -126,22 +164,40 @@ async def test_raw_response_create_overload_1(self, client: AsyncAnthropic) -> N model="claude-2.1", prompt="\n\nHuman: Hello, world!\n\nAssistant:", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" completion = response.parse() assert_matches_type(Completion, completion, path=["response"]) + @parametrize + async def test_streaming_response_create_overload_1(self, client: AsyncAnthropic) -> None: + async with client.completions.with_streaming_response.create( + max_tokens_to_sample=256, + model="claude-2.1", + prompt="\n\nHuman: Hello, world!\n\nAssistant:", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + completion = await response.parse() + assert_matches_type(Completion, completion, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize async def test_method_create_overload_2(self, client: AsyncAnthropic) -> None: - await client.completions.create( + completion_stream = await client.completions.create( max_tokens_to_sample=256, model="claude-2.1", prompt="\n\nHuman: Hello, world!\n\nAssistant:", stream=True, ) + await completion_stream.response.aclose() @parametrize async def test_method_create_with_all_params_overload_2(self, client: AsyncAnthropic) -> None: - await client.completions.create( + completion_stream = await client.completions.create( 
max_tokens_to_sample=256, model="claude-2.1", prompt="\n\nHuman: Hello, world!\n\nAssistant:", @@ -152,6 +208,7 @@ async def test_method_create_with_all_params_overload_2(self, client: AsyncAnthr top_k=5, top_p=0.7, ) + await completion_stream.response.aclose() @parametrize async def test_raw_response_create_overload_2(self, client: AsyncAnthropic) -> None: @@ -161,5 +218,23 @@ async def test_raw_response_create_overload_2(self, client: AsyncAnthropic) -> N prompt="\n\nHuman: Hello, world!\n\nAssistant:", stream=True, ) + assert response.http_request.headers.get("X-Stainless-Lang") == "python" - response.parse() + stream = response.parse() + await stream.close() + + @parametrize + async def test_streaming_response_create_overload_2(self, client: AsyncAnthropic) -> None: + async with client.completions.with_streaming_response.create( + max_tokens_to_sample=256, + model="claude-2.1", + prompt="\n\nHuman: Hello, world!\n\nAssistant:", + stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = await response.parse() + await stream.close() + + assert cast(Any, response.is_closed) is True diff --git a/tests/test_client.py b/tests/test_client.py index 53bc3cc5..531e905f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -20,6 +20,8 @@ from anthropic._types import Omit from anthropic._client import Anthropic, AsyncAnthropic from anthropic._models import BaseModel, FinalRequestOptions +from anthropic._response import APIResponse, AsyncAPIResponse +from anthropic._constants import RAW_RESPONSE_HEADER from anthropic._streaming import Stream, AsyncStream from anthropic._exceptions import APIStatusError, APITimeoutError, APIResponseValidationError from anthropic._base_client import ( @@ -226,6 +228,7 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic # to_raw_response_wrapper leaks through the @functools.wraps() decorator. 
# # removing the decorator fixes the leak for reasons we don't understand. + "anthropic/_legacy_response.py", "anthropic/_response.py", # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason. "anthropic/_compat.py", @@ -719,8 +722,9 @@ class Model(BaseModel): respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = self.client.post("/foo", cast_to=Model, stream=True) - assert isinstance(response, Stream) + stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model]) + assert isinstance(stream, Stream) + stream.response.close() @pytest.mark.respx(base_url=base_url) def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None: @@ -768,6 +772,29 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str calculated = client._calculate_retry_timeout(remaining_retries, options, headers) assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType] + @mock.patch("anthropic._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + def test_streaming_response(self) -> None: + response = self.client.post( + "/v1/complete", + body=dict( + max_tokens_to_sample=300, + model="claude-2.1", + prompt="\n\nHuman:Where can I get a good coffee in my neighbourhood?\n\nAssistant:", + ), + cast_to=APIResponse[bytes], + options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, + ) + + assert not cast(Any, response.is_closed) + assert _get_open_connections(self.client) == 1 + + for _ in response.iter_bytes(): + ... 
+ + assert cast(Any, response.is_closed) + assert _get_open_connections(self.client) == 0 + @mock.patch("anthropic._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: @@ -782,7 +809,7 @@ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> No prompt="\n\nHuman:Where can I get a good coffee in my neighbourhood?\n\nAssistant:", ), cast_to=httpx.Response, - options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}}, + options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) assert _get_open_connections(self.client) == 0 @@ -801,7 +828,7 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> Non prompt="\n\nHuman:Where can I get a good coffee in my neighbourhood?\n\nAssistant:", ), cast_to=httpx.Response, - options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}}, + options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) assert _get_open_connections(self.client) == 0 @@ -982,6 +1009,7 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic # to_raw_response_wrapper leaks through the @functools.wraps() decorator. # # removing the decorator fixes the leak for reasons we don't understand. + "anthropic/_legacy_response.py", "anthropic/_response.py", # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason. 
"anthropic/_compat.py", @@ -1488,8 +1516,9 @@ class Model(BaseModel): respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) - response = await self.client.post("/foo", cast_to=Model, stream=True) - assert isinstance(response, AsyncStream) + stream = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model]) + assert isinstance(stream, AsyncStream) + await stream.response.aclose() @pytest.mark.respx(base_url=base_url) @pytest.mark.asyncio @@ -1539,6 +1568,29 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte calculated = client._calculate_retry_timeout(remaining_retries, options, headers) assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType] + @mock.patch("anthropic._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + async def test_streaming_response(self) -> None: + response = await self.client.post( + "/v1/complete", + body=dict( + max_tokens_to_sample=300, + model="claude-2.1", + prompt="\n\nHuman:Where can I get a good coffee in my neighbourhood?\n\nAssistant:", + ), + cast_to=AsyncAPIResponse[bytes], + options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, + ) + + assert not cast(Any, response.is_closed) + assert _get_open_connections(self.client) == 1 + + async for _ in response.iter_bytes(): + ... 
+ + assert cast(Any, response.is_closed) + assert _get_open_connections(self.client) == 0 + @mock.patch("anthropic._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: @@ -1553,7 +1605,7 @@ async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) prompt="\n\nHuman:Where can I get a good coffee in my neighbourhood?\n\nAssistant:", ), cast_to=httpx.Response, - options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}}, + options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) assert _get_open_connections(self.client) == 0 @@ -1572,7 +1624,7 @@ async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) prompt="\n\nHuman:Where can I get a good coffee in my neighbourhood?\n\nAssistant:", ), cast_to=httpx.Response, - options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}}, + options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) assert _get_open_connections(self.client) == 0 diff --git a/tests/test_response.py b/tests/test_response.py new file mode 100644 index 00000000..9b4be8f5 --- /dev/null +++ b/tests/test_response.py @@ -0,0 +1,50 @@ +from typing import List + +import httpx +import pytest + +from anthropic._response import ( + APIResponse, + BaseAPIResponse, + AsyncAPIResponse, + BinaryAPIResponse, + AsyncBinaryAPIResponse, + extract_response_type, +) + + +class ConcreteBaseAPIResponse(APIResponse[bytes]): + ... + + +class ConcreteAPIResponse(APIResponse[List[str]]): + ... + + +class ConcreteAsyncAPIResponse(APIResponse[httpx.Response]): + ... 
+ + +def test_extract_response_type_direct_classes() -> None: + assert extract_response_type(BaseAPIResponse[str]) == str + assert extract_response_type(APIResponse[str]) == str + assert extract_response_type(AsyncAPIResponse[str]) == str + + +def test_extract_response_type_direct_class_missing_type_arg() -> None: + with pytest.raises( + RuntimeError, + match="Expected type to have a type argument at index 0 but it did not", + ): + extract_response_type(AsyncAPIResponse) + + +def test_extract_response_type_concrete_subclasses() -> None: + assert extract_response_type(ConcreteBaseAPIResponse) == bytes + assert extract_response_type(ConcreteAPIResponse) == List[str] + assert extract_response_type(ConcreteAsyncAPIResponse) == httpx.Response + + +def test_extract_response_type_binary_response() -> None: + assert extract_response_type(BinaryAPIResponse) == bytes + assert extract_response_type(AsyncBinaryAPIResponse) == bytes diff --git a/tests/utils.py b/tests/utils.py index 4c9df902..35b0965a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import inspect import traceback import contextlib from typing import Any, TypeVar, Iterator, cast @@ -68,6 +69,8 @@ def assert_matches_type( assert isinstance(value, bool) elif origin == float: assert isinstance(value, float) + elif origin == bytes: + assert isinstance(value, bytes) elif origin == datetime: assert isinstance(value, datetime) elif origin == date: @@ -100,6 +103,8 @@ def assert_matches_type( elif issubclass(origin, BaseModel): assert isinstance(value, type_) assert assert_matches_model(type_, cast(Any, value), path=path) + elif inspect.isclass(origin) and origin.__name__ == "HttpxBinaryResponseContent": + assert value.__class__.__name__ == "HttpxBinaryResponseContent" else: assert None, f"Unhandled field type: {type_}"