From 7a460107579d1f8016f02b57871b1b5d6ef0477b Mon Sep 17 00:00:00 2001
From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com>
Date: Sat, 11 Jan 2025 20:46:58 -0800
Subject: [PATCH] Make Chat class public and add docs (#401)

* Make chat module private. Make Chat class top-level importable

* Use Self from typing_extensions for Chat

* Add Chat.add_system_message

* Deprecate Chat.from_prompt to remove Chat dep on prompt in future

* Add TODO comments for Chat class

* Add Chat docs page

* Add Agent section to Chat docs

* Reduce get_current_weather return to one line

---
 README.md                          |   7 +-
 docs/chat-prompting.md             |   2 +
 docs/chat.md                       | 141 +++++++++++++++++++++++++++++
 docs/function-calling.md           |   7 +-
 docs/index.md                      |   7 +-
 mkdocs.yml                         |   1 +
 src/magentic/__init__.py           |   1 +
 src/magentic/{chat.py => _chat.py} |  35 +++++--
 src/magentic/prompt_chain.py       |  23 ++++-
 tests/test_chat.py                 |  16 +---
 10 files changed, 193 insertions(+), 47 deletions(-)
 create mode 100644 docs/chat.md
 rename src/magentic/{chat.py => _chat.py} (83%)

diff --git a/README.md b/README.md
index 4d46106f..f8de0b42 100644
--- a/README.md
+++ b/README.md
@@ -154,12 +154,7 @@ from magentic import prompt_chain
 def get_current_weather(location, unit="fahrenheit"):
     """Get the current weather in a given location"""
     # Pretend to query an API
-    return {
-        "location": location,
-        "temperature": "72",
-        "unit": unit,
-        "forecast": ["sunny", "windy"],
-    }
+    return {"temperature": "72", "forecast": ["sunny", "windy"]}
 
 
 @prompt_chain(
diff --git a/docs/chat-prompting.md b/docs/chat-prompting.md
index 661d8618..11139578 100644
--- a/docs/chat-prompting.md
+++ b/docs/chat-prompting.md
@@ -1,5 +1,7 @@
 # Chat Prompting
 
+This page covers the `@chatprompt` decorator, which can be used to define an LLM query template (message templates, optional LLM, functions, and return type). To manage an ongoing conversation with an LLM, see [Chat](chat.md).
+
 ## @chatprompt
 
 The `@chatprompt` decorator works just like `@prompt` but allows you to pass chat messages as a template rather than a single text prompt. This can be used to provide a system message or for few-shot prompting where you provide example responses to guide the model's output. Format fields denoted by curly braces `{example}` will be filled in all messages (except `FunctionResultMessage`).
diff --git a/docs/chat.md b/docs/chat.md
new file mode 100644
index 00000000..549a4e2d
--- /dev/null
+++ b/docs/chat.md
@@ -0,0 +1,141 @@
+# Chat
+
+This page covers the `Chat` class, which can be used to manage an ongoing conversation with an LLM chat model. To define a reusable LLM query template, use the `@chatprompt` decorator instead; see [Chat Prompting](chat-prompting.md).
+
+The `Chat` class keeps track of the messages exchanged between the user and the model and allows submitting the conversation to the model to get a response.
+
+## Basic Usage
+
+```python
+from magentic import Chat, OpenaiChatModel, UserMessage
+
+# Create a new Chat instance
+chat = Chat(
+    messages=[UserMessage("Say hello")],
+    model=OpenaiChatModel("gpt-4o"),
+)
+
+# Append a new user message
+chat = chat.add_user_message("Actually, say goodbye!")
+print(chat.messages)
+# [UserMessage('Say hello'), UserMessage('Actually, say goodbye!')]
+
+# Submit the chat to the LLM to get a response
+chat = chat.submit()
+print(chat.last_message.content)
+# 'Hello! Just kidding—goodbye!'
+```
+
+Note that all methods of `Chat` return a new `Chat` instance with the updated messages. This allows branching the conversation and keeping track of multiple conversation paths.
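+
+For example, because each method returns a new `Chat`, an earlier instance can be reused to explore two different conversation paths. A minimal sketch (the model's responses will vary):
+
+```python
+from magentic import Chat, UserMessage
+
+base = Chat(messages=[UserMessage("Suggest a name for a pet")])
+
+# Branch the conversation; `base` itself is unchanged
+cat_chat = base.add_user_message("It's a cat.").submit()
+dog_chat = base.add_user_message("It's a dog.").submit()
+```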
+
+The following methods are available to manually add messages to the chat by providing just the content of the message:
+
+- `add_system_message`: Adds a system message to the chat.
+- `add_user_message`: Adds a user message to the chat.
+- `add_assistant_message`: Adds an assistant message to the chat.
+
+There is also a generic `add_message` method for adding a `Message` object to the chat, and a `submit` method that submits the chat to the LLM and appends an `AssistantMessage` containing the model's response.
+
+## Function Calling
+
+Function calling can be done with the `Chat` class by providing the list of functions when creating the instance, similar to the `@chatprompt` decorator. Likewise, structured outputs can be returned by setting the `output_types` parameter.
+
+If the last message in the chat is an `AssistantMessage` containing a `FunctionCall` or `ParallelFunctionCall`, calling the `exec_function_call` method will execute the function call(s) and append the result(s) to the chat. Then, if needed, the chat can be submitted to the LLM again to get another response.
+
+```python hl_lines="23-25"
+from magentic import (
+    AssistantMessage,
+    Chat,
+    FunctionCall,
+    OpenaiChatModel,
+    UserMessage,
+)
+
+
+def get_current_weather(location, unit="fahrenheit"):
+    """Get the current weather in a given location"""
+    # Pretend to query an API
+    return {"temperature": "72", "forecast": ["sunny", "windy"]}
+
+
+chat = Chat(
+    messages=[UserMessage("What's the weather like in Boston?")],
+    functions=[get_current_weather],
+    # `FunctionCall` must be in output_types to get `FunctionCall` outputs
+    output_types=[FunctionCall, str],
+    model=OpenaiChatModel("gpt-4o"),
+)
+chat = chat.submit()
+print(chat.messages)
+# [UserMessage("What's the weather like in Boston?"),
+#  AssistantMessage(FunctionCall(<function get_current_weather at 0x...>, 'Boston'))]
+
+# Execute the function call and append the result to the chat
+chat = chat.exec_function_call()
+print(chat.messages)
+# [UserMessage("What's the weather like in Boston?"),
+#  AssistantMessage(FunctionCall(<function get_current_weather at 0x...>, 'Boston')),
+#  FunctionResultMessage({'temperature': '72', 'forecast': ['sunny', 'windy']},
+#                        FunctionCall(<function get_current_weather at 0x...>, 'Boston'))]
+
+# Submit the chat again to get the final LLM response
+chat = chat.submit()
+print(chat.messages)
+# [UserMessage("What's the weather like in Boston?"),
+#  AssistantMessage(FunctionCall(<function get_current_weather at 0x...>, 'Boston')),
+#  FunctionResultMessage({'temperature': '72', 'forecast': ['sunny', 'windy']},
+#                        FunctionCall(<function get_current_weather at 0x...>, 'Boston')),
+#  AssistantMessage("The current weather in Boston is 72°F, and it's sunny with windy conditions.")]
+```
+
+## Streaming
+
+Streaming types such as `StreamedStr`, `StreamedOutput`, and `Iterable[T]` can be provided in the `output_types` parameter. When the `.submit()` method is called, an `AssistantMessage` containing the streamed type is appended to the chat immediately, so its content can be accessed and consumed as it streams. For more information on streaming types, see [Streaming](streaming.md).
+
+```python
+from magentic import Chat, UserMessage, StreamedStr
+
+chat = Chat(
+    messages=[UserMessage("Tell me about the Golden Gate Bridge.")],
+    output_types=[StreamedStr],
+)
+chat = chat.submit()
+print(type(chat.last_message.content))
+# <class 'magentic.streaming.StreamedStr'>
+
+for chunk in chat.last_message.content:
+    print(chunk, end="")
+# (streamed) 'The Golden Gate Bridge is an iconic suspension bridge...
+```
+
+## Asyncio
+
+The `Chat` class also supports asynchronous usage through the following methods, demonstrated in the sketch below:
+
+- `asubmit`: Asynchronously submit the chat to the LLM.
+- `aexec_function_call`: Asynchronously execute the function call in the chat. This is required to handle the `AsyncParallelFunctionCall` output type.
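+
+A minimal sketch of submitting a chat asynchronously (assuming a configured model; like its sync counterpart, `asubmit` returns a new `Chat`):
+
+```python
+import asyncio
+
+from magentic import Chat, UserMessage
+
+
+async def main() -> None:
+    chat = Chat(messages=[UserMessage("Say hello")])
+    # Request an LLM response without blocking the event loop
+    chat = await chat.asubmit()
+    print(chat.last_message.content)
+
+
+asyncio.run(main())
+```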
+
+## Agent
+
+A very basic form of an agent can be created by running a loop that submits the chat to the LLM and executes function calls until some stop condition is met.
+
+```python
+from magentic import Chat, FunctionCall, ParallelFunctionCall, UserMessage
+
+
+def get_current_weather(location, unit="fahrenheit"):
+    """Get the current weather in a given location"""
+    # Pretend to query an API
+    return {"temperature": "72", "forecast": ["sunny", "windy"]}
+
+
+chat = Chat(
+    messages=[UserMessage("What's the weather like in Boston?")],
+    functions=[get_current_weather],
+    output_types=[FunctionCall, str],
+).submit()
+while isinstance(chat.last_message.content, FunctionCall | ParallelFunctionCall):
+    chat = chat.exec_function_call().submit()
+print(chat.last_message.content)
+# 'The current weather in Boston is 72°F, with sunny and windy conditions.'
+```
diff --git a/docs/function-calling.md b/docs/function-calling.md
index fd1691bd..2b38a058 100644
--- a/docs/function-calling.md
+++ b/docs/function-calling.md
@@ -71,12 +71,7 @@ from magentic import prompt_chain
 def get_current_weather(location, unit="fahrenheit"):
     """Get the current weather in a given location"""
     # Pretend to query an API
-    return {
-        "location": location,
-        "temperature": "72",
-        "unit": unit,
-        "forecast": ["sunny", "windy"],
-    }
+    return {"temperature": "72", "forecast": ["sunny", "windy"]}
 
 
 @prompt_chain(
diff --git a/docs/index.md b/docs/index.md
index d71c1aaa..0d985f64 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -155,12 +155,7 @@ from magentic import prompt_chain
 def get_current_weather(location, unit="fahrenheit"):
     """Get the current weather in a given location"""
     # Pretend to query an API
-    return {
-        "location": location,
-        "temperature": "72",
-        "unit": unit,
-        "forecast": ["sunny", "windy"],
-    }
+    return {"temperature": "72", "forecast": ["sunny", "windy"]}
 
 
 @prompt_chain(
diff --git a/mkdocs.yml b/mkdocs.yml
index 86ee6d37..475a5fc3 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -83,6 +83,7 @@ nav:
   - Overview: index.md
   - structured-outputs.md
   - chat-prompting.md
+  - chat.md
   - function-calling.md
   - formatting.md
   - asyncio.md
diff --git a/src/magentic/__init__.py b/src/magentic/__init__.py
index 4dbc58ea..dfecd68c 100644
--- a/src/magentic/__init__.py
+++ b/src/magentic/__init__.py
@@ -1,3 +1,4 @@
+from ._chat import Chat as Chat
 from ._pydantic import ConfigDict as ConfigDict
 from ._pydantic import with_config as with_config
 from ._streamed_response import AsyncStreamedResponse as AsyncStreamedResponse
diff --git a/src/magentic/chat.py b/src/magentic/_chat.py
similarity index 83%
rename from src/magentic/chat.py
rename to src/magentic/_chat.py
index 18046298..64a1b94f 100644
--- a/src/magentic/chat.py
+++ b/src/magentic/_chat.py
@@ -1,6 +1,8 @@
 import inspect
 from collections.abc import Callable, Iterable, Sequence
-from typing import Any, ParamSpec, TypeVar
+from typing import Any, ParamSpec
+
+from typing_extensions import Self, deprecated
 
 from magentic.backend import get_chat_model
 from magentic.chat_model.base import ChatModel
@@ -8,7 +10,9 @@
     AssistantMessage,
     FunctionResultMessage,
     Message,
+    SystemMessage,
     UserMessage,
+    UserMessageContentBlock,
 )
 from magentic.function_call import (
     AsyncParallelFunctionCall,
@@ -19,8 +23,6 @@
 from magentic.streaming import async_iter, azip
 
 P = ParamSpec("P")
-# TODO: Use `Self` from typing_extensions
-Self = TypeVar("Self", bound="Chat")
 
 
 class Chat:
@@ -39,6 +41,7 @@ class Chat:
     def __init__(
         self,
        messages: Sequence[Message[Any]] | None = None,
+        *,
         functions: Iterable[Callable[..., Any]] | None = None,
         output_types: Iterable[type[Any]] | None = None,
         model: ChatModel | None = None,
@@ -49,6 +52,10 @@ def __init__(
         self._model = model
 
     @classmethod
+    @deprecated(
+        "Chat.from_prompt will be removed in a future version."
+        " Instead, use the regular init method, `Chat(messages, functions, output_types, model)`."
+    )
     def from_prompt(
         cls: type[Self],
         prompt: BasePromptFunction[P, Any],
@@ -75,7 +82,7 @@ def last_message(self) -> Message[Any]:
     def model(self) -> ChatModel:
         return self._model or get_chat_model()
 
-    def add_message(self: Self, message: Message[Any]) -> Self:
+    def add_message(self, message: Message[Any]) -> Self:
         """Add a message to the chat."""
         return type(self)(
             messages=[*self._messages, message],
@@ -84,15 +91,22 @@ def add_message(self: Self, message: Message[Any]) -> Self:
             model=self._model,  # Keep `None` value if unset
         )
 
-    def add_user_message(self: Self, content: str) -> Self:
+    def add_system_message(self, content: str) -> Self:
+        """Add a system message to the chat."""
+        return self.add_message(SystemMessage(content=content))
+
+    def add_user_message(
+        self, content: str | Sequence[str | UserMessageContentBlock]
+    ) -> Self:
         """Add a user message to the chat."""
         return self.add_message(UserMessage(content=content))
 
-    def add_assistant_message(self: Self, content: Any) -> Self:
+    def add_assistant_message(self, content: Any) -> Self:
         """Add an assistant message to the chat."""
         return self.add_message(AssistantMessage(content=content))
 
-    def submit(self: Self) -> Self:
+    # TODO: Allow restricting functions and/or output types here
+    def submit(self) -> Self:
         """Request an LLM message to be added to the chat."""
         output_message: AssistantMessage[Any] = self.model.complete(
             messages=self._messages,
@@ -101,7 +115,7 @@ def submit(self: Self) -> Self:
         )
         return self.add_message(output_message)
 
-    async def asubmit(self: Self) -> Self:
+    async def asubmit(self) -> Self:
         """Async version of `submit`."""
         output_message: AssistantMessage[Any] = await self.model.acomplete(
             messages=self._messages,
@@ -110,7 +124,8 @@ async def asubmit(self: Self) -> Self:
         )
         return self.add_message(output_message)
 
-    def exec_function_call(self: Self) -> Self:
+    # TODO: Add optional error handling to this method, with param to toggle
+    def exec_function_call(self) -> Self:
         """If the last message is a function call, execute it and add the result."""
         if isinstance(self.last_message.content, FunctionCall):
             function_call = self.last_message.content
@@ -133,7 +148,7 @@ def exec_function_call(self: Self) -> Self:
         msg = "Last message is not a function call."
         raise TypeError(msg)
 
-    async def aexec_function_call(self: Self) -> Self:
+    async def aexec_function_call(self) -> Self:
         """Async version of `exec_function_call`."""
         if isinstance(self.last_message.content, FunctionCall):
             function_call = self.last_message.content
diff --git a/src/magentic/prompt_chain.py b/src/magentic/prompt_chain.py
index 0d6434bd..cfe586ab 100644
--- a/src/magentic/prompt_chain.py
+++ b/src/magentic/prompt_chain.py
@@ -3,8 +3,9 @@
 from functools import wraps
 from typing import Any, ParamSpec, TypeVar, cast
 
-from magentic.chat import Chat
+from magentic._chat import Chat
 from magentic.chat_model.base import ChatModel
+from magentic.chat_model.message import UserMessage
 from magentic.function_call import FunctionCall
 from magentic.logger import logfire
 from magentic.prompt_function import AsyncPromptFunction, PromptFunction
@@ -45,8 +46,15 @@ async def awrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
                 f"Calling async prompt-chain {func.__name__}",
                 **func_signature.bind(*args, **kwargs).arguments,
             ):
-                chat = await Chat.from_prompt(
-                    async_prompt_function, *args, **kwargs
+                chat = await Chat(
+                    messages=[
+                        UserMessage(
+                            content=async_prompt_function.format(*args, **kwargs)
+                        )
+                    ],
+                    functions=async_prompt_function.functions,
+                    output_types=async_prompt_function.return_types,
+                    model=async_prompt_function._model,  # Keep `None` value if unset
                 ).asubmit()
                 num_calls = 0
                 while isinstance(chat.last_message.content, FunctionCall):
@@ -79,7 +87,14 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
                 f"Calling prompt-chain {func.__name__}",
                 **func_signature.bind(*args, **kwargs).arguments,
             ):
-                chat = Chat.from_prompt(prompt_function, *args, **kwargs).submit()
+                chat = Chat(
+                    messages=[
+                        UserMessage(content=prompt_function.format(*args, **kwargs))
+                    ],
+                    functions=prompt_function.functions,
+                    output_types=prompt_function.return_types,
+                    model=prompt_function._model,  # Keep `None` value if unset
+                ).submit()
                 num_calls = 0
                 while isinstance(chat.last_message.content, FunctionCall):
                     if max_calls is not None and num_calls >= max_calls:
diff --git a/tests/test_chat.py b/tests/test_chat.py
index 549d396d..24b90725 100644
--- a/tests/test_chat.py
+++ b/tests/test_chat.py
@@ -2,7 +2,7 @@
 
 import pytest
 
-from magentic.chat import Chat
+from magentic._chat import Chat
 from magentic.chat_model.message import (
     AssistantMessage,
     FunctionResultMessage,
@@ -13,26 +13,12 @@
     FunctionCall,
     ParallelFunctionCall,
 )
-from magentic.prompt_function import prompt
 from magentic.streaming import async_iter
 
 if TYPE_CHECKING:
     from collections.abc import Awaitable
 
 
-def test_chat_from_prompt():
-    """Test creating a chat from a prompt function."""
-
-    def plus(a: int, b: int) -> int:
-        return a + b
-
-    @prompt("What is {a} plus {b}?", functions=[plus])
-    def add_text_numbers(a: str, b: str) -> int: ...
-
-    chat = Chat.from_prompt(add_text_numbers, "one", "two")
-    assert chat.messages == [UserMessage(content="What is one plus two?")]
-
-
 def test_chat_add_message():
     chat1 = Chat()
     chat2 = chat1.add_message(UserMessage(content="Hello"))