From d913ba35eb3a95e80154b9d35e4c0a9f4a8dfeb1 Mon Sep 17 00:00:00 2001 From: Robert Craigie Date: Mon, 13 Jan 2025 15:20:57 +0000 Subject: [PATCH] feat(beta): add streaming helpers for beta messages (#819) --- src/anthropic/lib/streaming/__init__.py | 13 + src/anthropic/lib/streaming/_beta_messages.py | 385 ++++++++++++++++++ src/anthropic/lib/streaming/_beta_types.py | 65 +++ .../resources/beta/messages/messages.py | 121 ++++++ 4 files changed, 584 insertions(+) create mode 100644 src/anthropic/lib/streaming/_beta_messages.py create mode 100644 src/anthropic/lib/streaming/_beta_types.py diff --git a/src/anthropic/lib/streaming/__init__.py b/src/anthropic/lib/streaming/__init__.py index 0ab41209..103fff58 100644 --- a/src/anthropic/lib/streaming/__init__.py +++ b/src/anthropic/lib/streaming/__init__.py @@ -11,3 +11,16 @@ MessageStreamManager as MessageStreamManager, AsyncMessageStreamManager as AsyncMessageStreamManager, ) +from ._beta_types import ( + BetaTextEvent as BetaTextEvent, + BetaInputJsonEvent as BetaInputJsonEvent, + BetaMessageStopEvent as BetaMessageStopEvent, + BetaMessageStreamEvent as BetaMessageStreamEvent, + BetaContentBlockStopEvent as BetaContentBlockStopEvent, +) +from ._beta_messages import ( + BetaMessageStream as BetaMessageStream, + BetaAsyncMessageStream as BetaAsyncMessageStream, + BetaMessageStreamManager as BetaMessageStreamManager, + BetaAsyncMessageStreamManager as BetaAsyncMessageStreamManager, +) diff --git a/src/anthropic/lib/streaming/_beta_messages.py b/src/anthropic/lib/streaming/_beta_messages.py new file mode 100644 index 00000000..48e419e9 --- /dev/null +++ b/src/anthropic/lib/streaming/_beta_messages.py @@ -0,0 +1,385 @@ +from __future__ import annotations + +from types import TracebackType +from typing import TYPE_CHECKING, Any, Callable, cast +from typing_extensions import Self, Iterator, Awaitable, AsyncIterator, assert_never + +import httpx + +from ..._utils import consume_sync_iterator, consume_async_iterator +from ..._models import build, construct_type +from ._beta_types import ( + BetaTextEvent, + BetaInputJsonEvent, + BetaMessageStopEvent, + BetaMessageStreamEvent, + BetaContentBlockStopEvent, +) +from ..._streaming import Stream, AsyncStream +from ...types.beta import BetaMessage, BetaContentBlock, BetaRawMessageStreamEvent + + +class BetaMessageStream: + text_stream: Iterator[str] + """Iterator over just the text deltas in the stream. + + ```py + for text in stream.text_stream: + print(text, end="", flush=True) + print() + ``` + """ + + def __init__(self, raw_stream: Stream[BetaRawMessageStreamEvent]) -> None: + self._raw_stream = raw_stream + self.text_stream = self.__stream_text__() + self._iterator = self.__stream__() + self.__final_message_snapshot: BetaMessage | None = None + + @property + def response(self) -> httpx.Response: + return self._raw_stream.response + + def __next__(self) -> BetaMessageStreamEvent: + return self._iterator.__next__() + + def __iter__(self) -> Iterator[BetaMessageStreamEvent]: + for item in self._iterator: + yield item + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.close() + + def close(self) -> None: + """ + Close the response and release the connection. + + Automatically called if the response body is read to completion. + """ + self._raw_stream.close() + + def get_final_message(self) -> BetaMessage: + """Waits until the stream has been read to completion and returns + the accumulated `Message` object. + """ + self.until_done() + assert self.__final_message_snapshot is not None + return self.__final_message_snapshot + + def get_final_text(self) -> str: + """Returns all `text` content blocks concatenated together. + + > [!NOTE] + > Currently the API will only respond with a single content block. + + Will raise an error if no `text` content blocks were returned. + """ + message = self.get_final_message() + text_blocks: list[str] = [] + for block in message.content: + if block.type == "text": + text_blocks.append(block.text) + + if not text_blocks: + raise RuntimeError("Expected to have received at least 1 text block") + + return "".join(text_blocks) + + def until_done(self) -> None: + """Blocks until the stream has been consumed""" + consume_sync_iterator(self) + + # properties + @property + def current_message_snapshot(self) -> BetaMessage: + assert self.__final_message_snapshot is not None + return self.__final_message_snapshot + + def __stream__(self) -> Iterator[BetaMessageStreamEvent]: + for sse_event in self._raw_stream: + self.__final_message_snapshot = accumulate_event( + event=sse_event, + current_snapshot=self.__final_message_snapshot, + ) + + events_to_fire = build_events(event=sse_event, message_snapshot=self.current_message_snapshot) + for event in events_to_fire: + yield event + + def __stream_text__(self) -> Iterator[str]: + for chunk in self: + if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": + yield chunk.delta.text + + +class BetaMessageStreamManager: + """Wrapper over MessageStream that is returned by `.stream()`. + + ```py + with client.beta.messages.stream(...) as stream: + for chunk in stream: + ... + ``` + """ + + def __init__( + self, + api_request: Callable[[], Stream[BetaRawMessageStreamEvent]], + ) -> None: + self.__stream: BetaMessageStream | None = None + self.__api_request = api_request + + def __enter__(self) -> BetaMessageStream: + raw_stream = self.__api_request() + self.__stream = BetaMessageStream(raw_stream) + return self.__stream + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + if self.__stream is not None: + self.__stream.close() + + +class BetaAsyncMessageStream: + text_stream: AsyncIterator[str] + """Async iterator over just the text deltas in the stream. + + ```py + async for text in stream.text_stream: + print(text, end="", flush=True) + print() + ``` + """ + + def __init__(self, raw_stream: AsyncStream[BetaRawMessageStreamEvent]) -> None: + self._raw_stream = raw_stream + self.text_stream = self.__stream_text__() + self._iterator = self.__stream__() + self.__final_message_snapshot: BetaMessage | None = None + + @property + def response(self) -> httpx.Response: + return self._raw_stream.response + + async def __anext__(self) -> BetaMessageStreamEvent: + return await self._iterator.__anext__() + + async def __aiter__(self) -> AsyncIterator[BetaMessageStreamEvent]: + async for item in self._iterator: + yield item + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.close() + + async def close(self) -> None: + """ + Close the response and release the connection. + + Automatically called if the response body is read to completion. + """ + await self._raw_stream.close() + + async def get_final_message(self) -> BetaMessage: + """Waits until the stream has been read to completion and returns + the accumulated `Message` object. + """ + await self.until_done() + assert self.__final_message_snapshot is not None + return self.__final_message_snapshot + + async def get_final_text(self) -> str: + """Returns all `text` content blocks concatenated together. + + > [!NOTE] + > Currently the API will only respond with a single content block. + + Will raise an error if no `text` content blocks were returned. + """ + message = await self.get_final_message() + text_blocks: list[str] = [] + for block in message.content: + if block.type == "text": + text_blocks.append(block.text) + + if not text_blocks: + raise RuntimeError("Expected to have received at least 1 text block") + + return "".join(text_blocks) + + async def until_done(self) -> None: + """Waits until the stream has been consumed""" + await consume_async_iterator(self) + + # properties + @property + def current_message_snapshot(self) -> BetaMessage: + assert self.__final_message_snapshot is not None + return self.__final_message_snapshot + + async def __stream__(self) -> AsyncIterator[BetaMessageStreamEvent]: + async for sse_event in self._raw_stream: + self.__final_message_snapshot = accumulate_event( + event=sse_event, + current_snapshot=self.__final_message_snapshot, + ) + + events_to_fire = build_events(event=sse_event, message_snapshot=self.current_message_snapshot) + for event in events_to_fire: + yield event + + async def __stream_text__(self) -> AsyncIterator[str]: + async for chunk in self: + if chunk.type == "content_block_delta" and chunk.delta.type == "text_delta": + yield chunk.delta.text + + +class BetaAsyncMessageStreamManager: + """Wrapper over BetaAsyncMessageStream that is returned by `.stream()` + so that an async context manager can be used without `await`ing the + original client call. + + ```py + async with client.beta.messages.stream(...) as stream: + async for chunk in stream: + ... + ``` + """ + + def __init__( + self, + api_request: Awaitable[AsyncStream[BetaRawMessageStreamEvent]], + ) -> None: + self.__stream: BetaAsyncMessageStream | None = None + self.__api_request = api_request + + async def __aenter__(self) -> BetaAsyncMessageStream: + raw_stream = await self.__api_request + self.__stream = BetaAsyncMessageStream(raw_stream) + return self.__stream + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + if self.__stream is not None: + await self.__stream.close() + + +def build_events( + *, + event: BetaRawMessageStreamEvent, + message_snapshot: BetaMessage, +) -> list[BetaMessageStreamEvent]: + events_to_fire: list[BetaMessageStreamEvent] = [] + + if event.type == "message_start": + events_to_fire.append(event) + elif event.type == "message_delta": + events_to_fire.append(event) + elif event.type == "message_stop": + events_to_fire.append(build(BetaMessageStopEvent, type="message_stop", message=message_snapshot)) + elif event.type == "content_block_start": + events_to_fire.append(event) + elif event.type == "content_block_delta": + events_to_fire.append(event) + + content_block = message_snapshot.content[event.index] + if event.delta.type == "text_delta" and content_block.type == "text": + events_to_fire.append( + build( + BetaTextEvent, + type="text", + text=event.delta.text, + snapshot=content_block.text, + ) + ) + elif event.delta.type == "input_json_delta" and content_block.type == "tool_use": + events_to_fire.append( + build( + BetaInputJsonEvent, + type="input_json", + partial_json=event.delta.partial_json, + snapshot=content_block.input, + ) + ) + elif event.type == "content_block_stop": + content_block = message_snapshot.content[event.index] + + events_to_fire.append( + build(BetaContentBlockStopEvent, type="content_block_stop", index=event.index, content_block=content_block), + ) + else: + # we only want exhaustive checking for linters, not at runtime + if TYPE_CHECKING: # type: ignore[unreachable] + assert_never(event) + + return events_to_fire + + +JSON_BUF_PROPERTY = "__json_buf" + + +def accumulate_event( + *, + event: BetaRawMessageStreamEvent, + current_snapshot: BetaMessage | None, +) -> BetaMessage: + if current_snapshot is None: + if event.type == "message_start": + return BetaMessage.construct(**cast(Any, event.message.to_dict())) + + raise RuntimeError(f'Unexpected event order, got {event.type} before "message_start"') + + if event.type == "content_block_start": + # TODO: check index + current_snapshot.content.append( + cast( + BetaContentBlock, + construct_type(type_=BetaContentBlock, value=event.content_block.model_dump()), + ), + ) + elif event.type == "content_block_delta": + content = current_snapshot.content[event.index] + if content.type == "text" and event.delta.type == "text_delta": + content.text += event.delta.text + elif content.type == "tool_use" and event.delta.type == "input_json_delta": + from jiter import from_json + + # we need to keep track of the raw JSON string as well so that we can + # re-parse it for each delta, for now we just store it as an untyped + # property on the snapshot + json_buf = cast(bytes, getattr(content, JSON_BUF_PROPERTY, b"")) + json_buf += bytes(event.delta.partial_json, "utf-8") + + if json_buf: + content.input = from_json(json_buf, partial_mode=True) + + setattr(content, JSON_BUF_PROPERTY, json_buf) + elif event.type == "message_delta": + current_snapshot.stop_reason = event.delta.stop_reason + current_snapshot.stop_sequence = event.delta.stop_sequence + current_snapshot.usage.output_tokens = event.usage.output_tokens + + return current_snapshot diff --git a/src/anthropic/lib/streaming/_beta_types.py b/src/anthropic/lib/streaming/_beta_types.py new file mode 100644 index 00000000..a2a0bf6b --- /dev/null +++ b/src/anthropic/lib/streaming/_beta_types.py @@ -0,0 +1,65 @@ +from typing import Union +from typing_extensions import Literal + +from ..._models import BaseModel +from ...types.beta import ( + BetaMessage, + BetaContentBlock, + BetaRawMessageStopEvent, + BetaRawMessageDeltaEvent, + BetaRawMessageStartEvent, + BetaRawContentBlockStopEvent, + BetaRawContentBlockDeltaEvent, + BetaRawContentBlockStartEvent, +) + + +class BetaTextEvent(BaseModel): + type: Literal["text"] + + text: str + """The text delta""" + + snapshot: str + """The entire accumulated text""" + + +class BetaInputJsonEvent(BaseModel): + type: Literal["input_json"] + + partial_json: str + """A partial JSON string delta + + e.g. `'"San Francisco,'` + """ + + snapshot: object + """The currently accumulated parsed object. + + + e.g. `{'location': 'San Francisco, CA'}` + """ + + +class BetaMessageStopEvent(BetaRawMessageStopEvent): + type: Literal["message_stop"] + + message: BetaMessage + + +class BetaContentBlockStopEvent(BetaRawContentBlockStopEvent): + type: Literal["content_block_stop"] + + content_block: BetaContentBlock + + +BetaMessageStreamEvent = Union[ + BetaTextEvent, + BetaInputJsonEvent, + BetaRawMessageStartEvent, + BetaRawMessageDeltaEvent, + BetaMessageStopEvent, + BetaRawContentBlockStartEvent, + BetaRawContentBlockDeltaEvent, + BetaContentBlockStopEvent, +] diff --git a/src/anthropic/resources/beta/messages/messages.py b/src/anthropic/resources/beta/messages/messages.py index 62582d47..7aa89187 100644 --- a/src/anthropic/resources/beta/messages/messages.py +++ b/src/anthropic/resources/beta/messages/messages.py @@ -3,6 +3,7 @@ from __future__ import annotations from typing import List, Union, Iterable +from functools import partial from itertools import chain from typing_extensions import Literal, overload @@ -32,6 +33,7 @@ from ...._streaming import Stream, AsyncStream from ....types.beta import message_create_params, message_count_tokens_params from ...._base_client import make_request_options +from ....lib.streaming import BetaMessageStreamManager, BetaAsyncMessageStreamManager from ....types.model_param import ModelParam from ....types.beta.beta_message import BetaMessage from ....types.anthropic_beta_param import AnthropicBetaParam @@ -922,6 +924,67 @@ def create( stream_cls=Stream[BetaRawMessageStreamEvent], ) + def stream( + self, + *, + max_tokens: int, + messages: Iterable[BetaMessageParam], + model: ModelParam, + metadata: BetaMetadataParam | NotGiven = NOT_GIVEN, + stop_sequences: List[str] | NotGiven = NOT_GIVEN, + system: Union[str, Iterable[BetaTextBlockParam]] | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + tool_choice: BetaToolChoiceParam | NotGiven = NOT_GIVEN, + tools: Iterable[BetaToolUnionParam] | NotGiven = NOT_GIVEN, + top_k: int | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + betas: List[AnthropicBetaParam] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> BetaMessageStreamManager: + """Create a Message stream""" + if not is_given(timeout) and self._client.timeout == DEFAULT_TIMEOUT: + timeout = 600 + + extra_headers = { + "X-Stainless-Stream-Helper": "beta.messages", + **strip_not_given({"anthropic-beta": ",".join(str(e) for e in betas) if is_given(betas) else NOT_GIVEN}), + **(extra_headers or {}), + } + make_request = partial( + self._post, + "/v1/messages?beta=true", + body=maybe_transform( + { + "max_tokens": max_tokens, + "messages": messages, + "model": model, + "metadata": metadata, + "stop_sequences": stop_sequences, + "system": system, + "temperature": temperature, + "top_k": top_k, + "top_p": top_p, + "tools": tools, + "tool_choice": tool_choice, + "stream": True, + }, + message_create_params.MessageCreateParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=BetaMessage, + stream=True, + stream_cls=Stream[BetaRawMessageStreamEvent], + ) + return BetaMessageStreamManager(make_request) + + def count_tokens( self, *, @@ -2030,6 +2093,64 @@ async def create( stream_cls=AsyncStream[BetaRawMessageStreamEvent], ) + def stream( + self, + *, + max_tokens: int, + messages: Iterable[BetaMessageParam], + model: ModelParam, + metadata: BetaMetadataParam | NotGiven = NOT_GIVEN, + stop_sequences: List[str] | NotGiven = NOT_GIVEN, + system: Union[str, Iterable[BetaTextBlockParam]] | NotGiven = NOT_GIVEN, + temperature: float | NotGiven = NOT_GIVEN, + tool_choice: BetaToolChoiceParam | NotGiven = NOT_GIVEN, + tools: Iterable[BetaToolUnionParam] | NotGiven = NOT_GIVEN, + top_k: int | NotGiven = NOT_GIVEN, + top_p: float | NotGiven = NOT_GIVEN, + betas: List[AnthropicBetaParam] | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> BetaAsyncMessageStreamManager: + if not is_given(timeout) and self._client.timeout == DEFAULT_TIMEOUT: + timeout = 600 + + extra_headers = { + "X-Stainless-Stream-Helper": "beta.messages", + **strip_not_given({"anthropic-beta": ",".join(str(e) for e in betas) if is_given(betas) else NOT_GIVEN}), + **(extra_headers or {}), + } + request = self._post( + "/v1/messages", + body=maybe_transform( + { + "max_tokens": max_tokens, + "messages": messages, + "model": model, + "metadata": metadata, + "stop_sequences": stop_sequences, + "system": system, + "temperature": temperature, + "top_k": top_k, + "top_p": top_p, + "tools": tools, + "tool_choice": tool_choice, + "stream": True, + }, + message_create_params.MessageCreateParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=BetaMessage, + stream=True, + stream_cls=AsyncStream[BetaRawMessageStreamEvent], + ) + return BetaAsyncMessageStreamManager(request) + async def count_tokens( self, *,