Add LLM-assisted retries (#288)
* Add apply, aapply functions

* Split _iter_streamed_tool_calls from parse_streamed_tool_calls for OpenaiChatModel

* Add _RawMessage class

* Tidy: reuse tool_calls var across Parallel vs not

* Join openai streamed tool call and include in error message

* WIP: include messages in error, retry within prompt

* WIP: Add TODOs

* Switch FunctionResultMessage to use unique_id not FunctionCall

* Rename unique_id -> function_call_id

* Revert "Rename unique_id -> function_call_id"

This reverts commit 5b4ba4d.

* Revert "Switch FunctionResultMessage to use unique_id not FunctionCall"

This reverts commit 80de9fc.

* Add ToolResultMessage, parent of FunctionResultMessage

* Add StringNotAllowedError. Update StructuredOutputError

* Move retry logic into RetryChatModel

* Add max_retries param to chatprompt

* Move error handling into _parse_streamed_tool_calls

* Handle str content in ToolResultMessage

* Explicitly pass ValidationError to StructuredOutputError

* Do not show _RawMessage in error msg

* Rename StructuredOutputError to ToolSchemaParseError

* Tidy by making _make_retry_messages

* Update TODOs

* Address mypy errors for OpenaiChatModel

* Add init to _RawMessage

* Remove unused cached_response from OpenaiChatModel.acomplete

* Update LitellmChatModel for new error handling

* Make AnthropicChatModel parse_streamed_functions private

* Move error handling into _parse_streamed_tool_calls for AnthropicChatModel

* Implement _join_streamed_response_to_message

* Fix: avoid pydantic class name reused error

* Add tests for AnthropicChatModel error handling, and fix it

* Rename test for ToolSchemaParseError

* Add tests for LitellmChatModel error handling

* Fix mypy errors

* Implement serialization of _RawMessage w/ Anthropic. Fix tool use id

* Add tests for RetryChatModel with OpenaiChatModel

* Add tests for RetryChatModel with AnthropicChatModel

* Add tests for RetryChatModel with LitellmChatModel

* Add tests for prompt max_retries param

* Add Retrying docs page
jackmpcollins authored Aug 12, 2024
1 parent e017ce8 commit c7cb858
Showing 19 changed files with 891 additions and 170 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -11,6 +11,7 @@ Easily integrate Large Language Models into your Python code. Simply use the `@p
- [Asyncio]. Simply use `async def` when defining a magentic function.
- [Streaming] structured outputs to use them as they are being generated.
- [Vision] to easily get structured outputs from images.
- [LLM-Assisted Retries] to improve LLM adherence to complex output schemas.
- Multiple LLM providers including OpenAI and Anthropic. See [Configuration].
- [Type Annotations] to work nicely with linters and IDEs.

@@ -187,6 +188,7 @@ LLM-powered functions created using `@prompt`, `@chatprompt` and `@prompt_chain`
[Asyncio]: https://magentic.dev/asyncio
[Streaming]: https://magentic.dev/streaming
[Vision]: https://magentic.dev/vision
[LLM-Assisted Retries]: https://magentic.dev/retrying
[Configuration]: https://magentic.dev/configuration
[Type Annotations]: https://magentic.dev/type-checking

2 changes: 2 additions & 0 deletions docs/index.md
@@ -11,6 +11,7 @@ Easily integrate Large Language Models into your Python code. Simply use the `@p
- [Asyncio]. Simply use `async def` when defining a magentic function.
- [Streaming] structured outputs to use them as they are being generated.
- [Vision] to easily get structured outputs from images.
- [LLM-Assisted Retries] to improve LLM adherence to complex output schemas.
- Multiple LLM providers including OpenAI and Anthropic. See [Configuration].
- [Type Annotations] to work nicely with linters and IDEs.

@@ -187,5 +188,6 @@ LLM-powered functions created using `@prompt`, `@chatprompt` and `@prompt_chain`
[Asyncio]: asyncio.md
[Streaming]: streaming.md
[Vision]: vision.md
[LLM-Assisted Retries]: retrying.md
[Configuration]: configuration.md
[Type Annotations]: type-checking.md
74 changes: 74 additions & 0 deletions docs/retrying.md
@@ -0,0 +1,74 @@
# Retrying

## LLM-Assisted Retries

Occasionally the LLM returns an output that cannot be parsed into any of the output types or function calls that were requested. Additionally, the pydantic models you define might have extra validation that is not represented by the type annotations alone. In these cases, LLM-assisted retries can be used to automatically resubmit the output as well as the associated error message back to the LLM, giving it another opportunity with more information to meet the output schema requirements.

To enable retries, simply set the `max_retries` parameter to a non-zero value in `@prompt` or `@chatprompt`.

In this example:

- the LLM first returns a country that is not Ireland
- the pydantic model validation then fails with the error "Country must be Ireland"
- the original output and a message containing the error are resubmitted to the LLM
- the LLM then meets the output requirement, returning "Ireland"

```python
from typing import Annotated

from magentic import prompt
from pydantic import AfterValidator, BaseModel


def assert_is_ireland(v: str) -> str:
if v != "Ireland":
raise ValueError("Country must be Ireland")
return v


class Country(BaseModel):
name: Annotated[str, AfterValidator(assert_is_ireland)]
capital: str


@prompt(
"Return a country",
max_retries=3,
)
def get_country() -> Country: ...


get_country()
# 05:13:55.607 Calling prompt-function get_country
# 05:13:55.622 LLM-assisted retries enabled. Max 3
# 05:13:55.627 Chat Completion with 'gpt-4o' [LLM]
# 05:13:56.309 streaming response from 'gpt-4o' took 0.11s [LLM]
# 05:13:56.310 Retrying Chat Completion. Attempt 1.
# 05:13:56.322 Chat Completion with 'gpt-4o' [LLM]
# 05:13:57.456 streaming response from 'gpt-4o' took 0.00s [LLM]
#
# Country(name='Ireland', capital='Dublin')
```
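
The same `max_retries` parameter also works with `@chatprompt`. Below is a minimal sketch reusing the `Country` model defined above; the prompt-function name is illustrative only.

```python
from magentic import UserMessage, chatprompt


@chatprompt(
    UserMessage("Return a country"),
    max_retries=3,
)
def get_country_via_chat() -> Country: ...


get_country_via_chat()
# e.g. Country(name='Ireland', capital='Dublin')
```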

LLM-assisted retries are intended to address cases where the LLM failed to generate valid output. Errors due to LLM provider rate limiting, internet connectivity issues, or other problems that cannot be solved by reprompting the LLM should be handled using other methods, for example [jd/tenacity](https://github.com/jd/tenacity) or [hynek/stamina](https://github.com/hynek/stamina) to retry a Python function, as in the sketch below.
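
A minimal sketch of this approach, assuming the `tenacity` and `openai` packages are installed and reusing `get_country` and `Country` from the example above; the wrapper function name is illustrative only.

```python
from openai import RateLimitError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential


# Retry the whole prompt-function call with exponential backoff on rate-limit errors,
# which LLM-assisted retries intentionally do not handle
@retry(
    retry=retry_if_exception_type(RateLimitError),
    stop=stop_after_attempt(3),
    wait=wait_exponential(min=1, max=10),
)
def get_country_with_backoff() -> Country:
    return get_country()
```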

### RetryChatModel

Under the hood, LLM-assisted retries are implemented using the `RetryChatModel`, which wraps any other `ChatModel`, catches exceptions raised while parsing the output, and resubmits the output along with the error message to the LLM. To implement your own retry handling you can follow the pattern of this class. Please file a [GitHub issue](https://github.com/jackmpcollins/magentic/issues) if you encounter exceptions that should be included in the LLM-assisted retries.

To use the `RetryChatModel` directly rather than via the `max_retries` parameter, simply pass it as the `model` argument to the decorator. Extending the example above:

```python
from magentic import OpenaiChatModel
from magentic.chat_model.retry_chat_model import RetryChatModel


@prompt(
"Return a country",
model=RetryChatModel(OpenaiChatModel("gpt-4o-mini"), max_retries=3),
)
def get_country() -> Country: ...


get_country()
```
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -83,6 +83,7 @@ nav:
- asyncio.md
- streaming.md
- vision.md
- retrying.md
- logging-and-tracing.md
- configuration.md
- type-checking.md
164 changes: 107 additions & 57 deletions src/magentic/chat_model/anthropic_chat_model.py
@@ -9,7 +9,7 @@

from magentic.chat_model.base import (
ChatModel,
StructuredOutputError,
ToolSchemaParseError,
avalidate_str_content,
validate_str_content,
)
@@ -28,6 +28,7 @@
ToolResultMessage,
Usage,
UserMessage,
_RawMessage,
)
from magentic.function_call import (
AsyncParallelFunctionCall,
@@ -38,15 +39,20 @@
from magentic.streaming import (
AsyncStreamedStr,
StreamedStr,
aapply,
achain,
agroupby,
apeek,
apply,
async_iter,
peek,
)
from magentic.typing import is_any_origin_subclass, is_origin_subclass

try:
import anthropic
from anthropic.lib.streaming import MessageStreamEvent
from anthropic.lib.streaming._messages import accumulate_event
from anthropic.types import (
ContentBlockDeltaEvent,
ContentBlockStartEvent,
@@ -72,6 +78,12 @@ def message_to_anthropic_message(message: Message[Any]) -> MessageParam:
raise NotImplementedError(type(message))


@message_to_anthropic_message.register(_RawMessage)
def _(message: _RawMessage[Any]) -> MessageParam:
# TODO: Validate the message content
return message.content # type: ignore[no-any-return]


@message_to_anthropic_message.register
def _(message: UserMessage) -> MessageParam:
return {"role": AnthropicMessageRole.USER.value, "content": message.content}
@@ -138,14 +150,18 @@ def _(message: AssistantMessage[Any]) -> MessageParam:

@message_to_anthropic_message.register(ToolResultMessage)
def _(message: ToolResultMessage[Any]) -> MessageParam:
function_schema = function_schema_for_type(type(message.content))
if isinstance(message.content, str):
content = message.content
else:
function_schema = function_schema_for_type(type(message.content))
content = json.loads(function_schema.serialize_args(message.content))
return {
"role": AnthropicMessageRole.USER.value,
"content": [
{
"type": "tool_result",
"tool_use_id": message.tool_call_id,
"content": json.loads(function_schema.serialize_args(message.content)),
"content": content,
}
],
}
@@ -207,38 +223,90 @@ async def aparse_tool_call(self, chunks: AsyncIterable[MessageStreamEvent]) -> T
)


def parse_streamed_tool_calls(
def _iter_streamed_tool_calls(
response: Iterable[MessageStreamEvent],
tool_schemas: Iterable[FunctionToolSchema[T]],
) -> Iterator[T]:
) -> Iterator[Iterator[ContentBlockStartEvent | ContentBlockDeltaEvent]]:
all_tool_call_chunks = (
cast(ContentBlockStartEvent | ContentBlockDeltaEvent, chunk)
for chunk in response
if chunk.type in ("content_block_start", "content_block_delta")
)
for _, tool_call_chunks in groupby(all_tool_call_chunks, lambda x: x.index):
first_chunk = next(tool_call_chunks)
assert first_chunk.type == "content_block_start" # noqa: S101
assert first_chunk.content_block.type == "tool_use" # noqa: S101
tool_schema = select_tool_schema(first_chunk.content_block, tool_schemas)
yield tool_schema.parse_tool_call(tool_call_chunks) # noqa: B031
yield tool_call_chunks


async def aparse_streamed_tool_calls(
async def _aiter_streamed_tool_calls(
response: AsyncIterable[MessageStreamEvent],
tool_schemas: Iterable[AsyncFunctionToolSchema[T]],
) -> AsyncIterator[T]:
) -> AsyncIterator[AsyncIterator[ContentBlockStartEvent | ContentBlockDeltaEvent]]:
all_tool_call_chunks = (
cast(ContentBlockStartEvent | ContentBlockDeltaEvent, chunk)
async for chunk in response
if chunk.type in ("content_block_start", "content_block_delta")
)
async for _, tool_call_chunks in agroupby(all_tool_call_chunks, lambda x: x.index):
first_chunk = await anext(tool_call_chunks)
assert first_chunk.type == "content_block_start" # noqa: S101
assert first_chunk.content_block.type == "tool_use" # noqa: S101
tool_schema = select_tool_schema(first_chunk.content_block, tool_schemas)
yield await tool_schema.aparse_tool_call(tool_call_chunks)
yield tool_call_chunks


def _join_streamed_response_to_message(
response: list[MessageStreamEvent],
) -> _RawMessage[MessageParam]:
snapshot = None
for event in response:
snapshot = accumulate_event(
event=event, # type: ignore[arg-type]
current_snapshot=snapshot,
)
assert snapshot is not None # noqa: S101
snapshot_content = snapshot.model_dump()["content"]
return _RawMessage({"role": snapshot.role, "content": snapshot_content})


def _parse_streamed_tool_calls(
response: Iterable[MessageStreamEvent],
tool_schemas: Iterable[FunctionToolSchema[T]],
) -> Iterator[T]:
cached_response: list[MessageStreamEvent] = []
response = apply(cached_response.append, response)
try:
for tool_call_chunks in _iter_streamed_tool_calls(response):
first_chunk, tool_call_chunks = peek(tool_call_chunks)
assert first_chunk.type == "content_block_start" # noqa: S101
assert first_chunk.content_block.type == "tool_use" # noqa: S101
tool_schema = select_tool_schema(first_chunk.content_block, tool_schemas)
tool_call = tool_schema.parse_tool_call(tool_call_chunks)
yield tool_call
# TODO: Catch/raise unknown tool call error here
except ValidationError as e:
raw_message = _join_streamed_response_to_message(cached_response)
raise ToolSchemaParseError(
output_message=raw_message,
tool_call_id=raw_message.content["content"][0]["id"], # type: ignore[index,unused-ignore]
validation_error=e,
) from e


async def _aparse_streamed_tool_calls(
response: AsyncIterable[MessageStreamEvent],
tool_schemas: Iterable[AsyncFunctionToolSchema[T]],
) -> AsyncIterator[T]:
cached_response: list[MessageStreamEvent] = []
response = aapply(cached_response.append, response)
try:
async for tool_call_chunks in _aiter_streamed_tool_calls(response):
first_chunk, tool_call_chunks = await apeek(tool_call_chunks)
assert first_chunk.type == "content_block_start" # noqa: S101
assert first_chunk.content_block.type == "tool_use" # noqa: S101
tool_schema = select_tool_schema(first_chunk.content_block, tool_schemas)
tool_call = await tool_schema.aparse_tool_call(tool_call_chunks)
yield tool_call
# TODO: Catch/raise unknown tool call error here
except ValidationError as e:
raw_message = _join_streamed_response_to_message(cached_response)
raise ToolSchemaParseError(
output_message=raw_message,
tool_call_id=raw_message.content["content"][0]["id"], # type: ignore[index,unused-ignore]
validation_error=e,
) from e


def _extract_system_message(
@@ -449,11 +517,11 @@ def _response_generator() -> Iterator[MessageStreamEvent]:
response = _response_generator()
usage_ref, response = _create_usage_ref(response)

message_start_chunk = next(response)
assert message_start_chunk.type == "message_start" # noqa: S101
first_chunk = next(response)
if first_chunk.type == "message_start":
first_chunk = next(response)
assert first_chunk.type == "content_block_start" # noqa: S101
response = chain([first_chunk], response)
response = chain([message_start_chunk, first_chunk], response)

if (
first_chunk.type == "content_block_start"
@@ -476,22 +544,14 @@
first_chunk.type == "content_block_start"
and first_chunk.content_block.type == "tool_use"
):
try:
if is_any_origin_subclass(output_types, ParallelFunctionCall):
content = ParallelFunctionCall(
parse_streamed_tool_calls(response, tool_schemas)
)
return AssistantMessage._with_usage(content, usage_ref) # type: ignore[return-value]
# Take only the first tool_call, silently ignore extra chunks
# TODO: Create generator here that raises error or warns if multiple tool_calls
content = next(parse_streamed_tool_calls(response, tool_schemas))
tool_calls = _parse_streamed_tool_calls(response, tool_schemas)
if is_any_origin_subclass(output_types, ParallelFunctionCall):
content = ParallelFunctionCall(tool_calls)
return AssistantMessage._with_usage(content, usage_ref) # type: ignore[return-value]
except ValidationError as e:
msg = (
"Failed to parse model output. You may need to update your prompt"
" to encourage the model to return a specific type."
)
raise StructuredOutputError(msg) from e
# Take only the first tool_call, silently ignore extra chunks
# TODO: Create generator here that raises error or warns if multiple tool_calls
content = next(tool_calls)
return AssistantMessage._with_usage(content, usage_ref) # type: ignore[return-value]

msg = f"Could not determine response type for first chunk: {first_chunk.model_dump_json()}"
raise ValueError(msg)
@@ -568,11 +628,11 @@ async def _response_generator() -> AsyncIterator[MessageStreamEvent]:
response = _response_generator()
usage_ref, response = _create_usage_ref_async(response)

message_start_chunk = await anext(response)
assert message_start_chunk.type == "message_start" # noqa: S101
first_chunk = await anext(response)
if first_chunk.type == "message_start":
first_chunk = await anext(response)
assert first_chunk.type == "content_block_start" # noqa: S101
response = achain(async_iter([first_chunk]), response)
response = achain(async_iter([message_start_chunk, first_chunk]), response)

if (
first_chunk.type == "content_block_start"
@@ -595,24 +655,14 @@
first_chunk.type == "content_block_start"
and first_chunk.content_block.type == "tool_use"
):
try:
if is_any_origin_subclass(output_types, AsyncParallelFunctionCall):
content = AsyncParallelFunctionCall(
aparse_streamed_tool_calls(response, tool_schemas)
)
return AssistantMessage._with_usage(content, usage_ref) # type: ignore[return-value]
# Take only the first tool_call, silently ignore extra chunks
# TODO: Create generator here that raises error or warns if multiple tool_calls
content = await anext(
aparse_streamed_tool_calls(response, tool_schemas)
)
tool_calls = _aparse_streamed_tool_calls(response, tool_schemas)
if is_any_origin_subclass(output_types, AsyncParallelFunctionCall):
content = AsyncParallelFunctionCall(tool_calls)
return AssistantMessage._with_usage(content, usage_ref) # type: ignore[return-value]
except ValidationError as e:
msg = (
"Failed to parse model output. You may need to update your prompt"
" to encourage the model to return a specific type."
)
raise StructuredOutputError(msg) from e
# Take only the first tool_call, silently ignore extra chunks
# TODO: Create generator here that raises error or warns if multiple tool_calls
content = await anext(tool_calls)
return AssistantMessage._with_usage(content, usage_ref) # type: ignore[return-value]

msg = "Could not determine response type"
raise ValueError(msg)