Make Chat class public and add docs (#401)
* Make chat module private. Make Chat class top-level importable

* Use Self from typing_extensions for Chat

* Add Chat.add_system_message

* Deprecate Chat.from_prompt to remove Chat dep on prompt in future

* Add TODO comments for Chat class

* Add Chat docs page

* Add Agent section to Chat docs

* Reduce get_current_weather return to one line
jackmpcollins authored Jan 12, 2025
1 parent fd0d03e commit 7a46010
Showing 10 changed files with 193 additions and 47 deletions.
7 changes: 1 addition & 6 deletions README.md

```diff
@@ -154,12 +154,7 @@ from magentic import prompt_chain
 def get_current_weather(location, unit="fahrenheit"):
     """Get the current weather in a given location"""
     # Pretend to query an API
-    return {
-        "location": location,
-        "temperature": "72",
-        "unit": unit,
-        "forecast": ["sunny", "windy"],
-    }
+    return {"temperature": "72", "forecast": ["sunny", "windy"]}
 
 
 @prompt_chain(
```
2 changes: 2 additions & 0 deletions docs/chat-prompting.md

```diff
@@ -1,5 +1,7 @@
 # Chat Prompting
 
+This page covers the `@chatprompt` decorator which can be used to define an LLM query template (message templates, optional LLM, functions, and return type). To manage an ongoing conversation with an LLM, see [Chat](chat.md).
+
 ## @chatprompt
 
 The `@chatprompt` decorator works just like `@prompt` but allows you to pass chat messages as a template rather than a single text prompt. This can be used to provide a system message or for few-shot prompting where you provide example responses to guide the model's output. Format fields denoted by curly braces `{example}` will be filled in all messages (except `FunctionResultMessage`).
```
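
As context for the `@chatprompt` paragraph above, a minimal usage sketch based on the magentic docs (illustrative, not part of this commit's diff):

```python
from magentic import chatprompt, SystemMessage, UserMessage


@chatprompt(
    SystemMessage("You are a movie buff."),
    UserMessage("What is your favorite quote from {movie}?"),
)
def get_movie_quote(movie: str) -> str: ...


print(get_movie_quote("Iron Man"))
# e.g. 'I am Iron Man.'
```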
141 changes: 141 additions & 0 deletions docs/chat.md
@@ -0,0 +1,141 @@
# Chat

This page covers the `Chat` class, which can be used to manage an ongoing conversation with an LLM chat model. To define a reusable LLM query template, use the `@chatprompt` decorator instead; see [Chat Prompting](chat-prompting.md).

The `Chat` class represents an ongoing conversation with an LLM. It keeps track of the messages exchanged between the user and the model and allows submitting the conversation to the model to get a response.

## Basic Usage

```python
from magentic import Chat, OpenaiChatModel, UserMessage

# Create a new Chat instance
chat = Chat(
messages=[UserMessage("Say hello")],
model=OpenaiChatModel("gpt-4o"),
)

# Append a new user message
chat = chat.add_user_message("Actually, say goodbye!")
print(chat.messages)
# [UserMessage('Say hello'), UserMessage('Actually, say goodbye!')]

# Submit the chat to the LLM to get a response
chat = chat.submit()
print(chat.last_message.content)
# 'Hello! Just kidding—goodbye!'
```

Note that all methods of `Chat` return a new `Chat` instance with the updated messages. This allows branching the conversation and keeping track of multiple conversation paths.
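
For example, a minimal sketch of branching (no LLM calls are made here, and the message contents are illustrative):

```python
from magentic import Chat, UserMessage

base = Chat(messages=[UserMessage("Suggest a name for my pet.")])

# Each branch extends `base` without modifying it
cat_branch = base.add_user_message("It's a cat.")
dog_branch = base.add_user_message("It's a dog.")

print(len(base.messages), len(cat_branch.messages), len(dog_branch.messages))
# 1 2 2
```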

The following methods are available to manually add messages to the chat by providing just the content of the message:

- `add_system_message`: Adds a system message to the chat.
- `add_user_message`: Adds a user message to the chat.
- `add_assistant_message`: Adds an assistant message to the chat.

There is also a generic `add_message` method for adding a `Message` object directly to the chat. The `submit` method submits the chat to the LLM model and appends an `AssistantMessage` containing the model's response.
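
For example, a minimal sketch that builds up a chat without submitting it (the message contents are illustrative):

```python
from magentic import AssistantMessage, Chat

chat = (
    Chat()
    .add_system_message("You are a poet.")
    .add_user_message("Write a haiku about the sea.")
)
# `add_message` accepts any `Message` object directly
chat = chat.add_message(AssistantMessage("Salt wind combs the waves."))
print(len(chat.messages))
# 3
```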

## Function Calling

Function calling can be done with the `Chat` class by providing the list of functions when creating the instance, as with the `@chatprompt` decorator. Structured outputs can likewise be returned by setting the `output_types` parameter.

If the last message in the chat is an `AssistantMessage` containing a `FunctionCall` or `ParallelFunctionCall`, calling the `exec_function_call` method will execute the function call(s) and append the result(s) to the chat. Then, if needed, the chat can be submitted to the LLM again to get another response.

```python hl_lines="23-25"
from magentic import (
AssistantMessage,
Chat,
FunctionCall,
OpenaiChatModel,
UserMessage,
)


def get_current_weather(location, unit="fahrenheit"):
"""Get the current weather in a given location"""
# Pretend to query an API
return {"temperature": "72", "forecast": ["sunny", "windy"]}


chat = Chat(
messages=[UserMessage("What's the weather like in Boston?")],
functions=[get_current_weather],
# `FunctionCall` must be in output_types to get `FunctionCall` outputs
output_types=[FunctionCall, str],
model=OpenaiChatModel("gpt-4o"),
)
chat = chat.submit()
print(chat.messages)
# [UserMessage("What's the weather like in Boston?"),
# AssistantMessage(FunctionCall(<function get_current_weather at 0x130a92160>, 'Boston'))]

# Execute the function call and append the result to the chat
chat = chat.exec_function_call()
print(chat.messages)
# [UserMessage("What's the weather like in Boston?"),
# AssistantMessage(FunctionCall(<function get_current_weather at 0x130a92160>, 'Boston')),
# FunctionResultMessage({'temperature': '72', 'forecast': ['sunny', 'windy']},
# FunctionCall(<function get_current_weather at 0x130a92160>, 'Boston'))]

# Submit the chat again to get the final LLM response
chat = chat.submit()
print(chat.messages)
# [UserMessage("What's the weather like in Boston?"),
# AssistantMessage(FunctionCall(<function get_current_weather at 0x130a92160>, 'Boston')),
# FunctionResultMessage({'temperature': '72', 'forecast': ['sunny', 'windy']},
# FunctionCall(<function get_current_weather at 0x130a92160>, 'Boston')),
# AssistantMessage("The current weather in Boston is 72°F, and it's sunny with windy conditions.")]
```

## Streaming

Streaming types such as `StreamedStr`, `StreamedOutput`, and `Iterable[T]` can be provided in the `output_types` parameter. When the `.submit()` method is called, an `AssistantMessage` containing the streamed type is appended to the chat immediately, so the response can be consumed as it streams. For more information on streaming types, see [Streaming](streaming.md).

```python
from magentic import Chat, UserMessage, StreamedStr

chat = Chat(
messages=[UserMessage("Tell me about the Golden Gate Bridge.")],
output_types=[StreamedStr],
)
chat = chat.submit()
print(type(chat.last_message.content))
# <class 'magentic.streaming.StreamedStr'>

for chunk in chat.last_message.content:
print(chunk, end="")
# (streamed) 'The Golden Gate Bridge is an iconic suspension bridge...
```

## Asyncio

The `Chat` class also supports asynchronous usage through the following methods, as sketched below:

- `asubmit`: Asynchronously submit the chat to the LLM model.
- `aexec_function_call`: Asynchronously execute the function call in the chat. This is required to handle the `AsyncParallelFunctionCall` output type.
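
A minimal sketch of asynchronous usage, assuming an LLM backend (e.g. an OpenAI API key) is already configured:

```python
import asyncio

from magentic import Chat, UserMessage


async def main() -> None:
    chat = Chat(messages=[UserMessage("Say hello")])
    chat = await chat.asubmit()
    print(chat.last_message.content)


asyncio.run(main())
```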

## Agent

A very basic form of an agent can be created by running a loop that submits the chat to the LLM and executes function calls until some stop condition is met.

```python
from magentic import Chat, FunctionCall, ParallelFunctionCall, UserMessage


def get_current_weather(location, unit="fahrenheit"):
"""Get the current weather in a given location"""
# Pretend to query an API
return {"temperature": "72", "forecast": ["sunny", "windy"]}


chat = Chat(
messages=[UserMessage("What's the weather like in Boston?")],
functions=[get_current_weather],
output_types=[FunctionCall, str],
).submit()
while isinstance(chat.last_message.content, FunctionCall | ParallelFunctionCall):
chat = chat.exec_function_call().submit()
print(chat.last_message.content)
# 'The current weather in Boston is 72°F, with sunny and windy conditions.'
```
7 changes: 1 addition & 6 deletions docs/function-calling.md

```diff
@@ -71,12 +71,7 @@ from magentic import prompt_chain
 def get_current_weather(location, unit="fahrenheit"):
     """Get the current weather in a given location"""
     # Pretend to query an API
-    return {
-        "location": location,
-        "temperature": "72",
-        "unit": unit,
-        "forecast": ["sunny", "windy"],
-    }
+    return {"temperature": "72", "forecast": ["sunny", "windy"]}
 
 
 @prompt_chain(
```
7 changes: 1 addition & 6 deletions docs/index.md

```diff
@@ -155,12 +155,7 @@ from magentic import prompt_chain
 def get_current_weather(location, unit="fahrenheit"):
     """Get the current weather in a given location"""
     # Pretend to query an API
-    return {
-        "location": location,
-        "temperature": "72",
-        "unit": unit,
-        "forecast": ["sunny", "windy"],
-    }
+    return {"temperature": "72", "forecast": ["sunny", "windy"]}
 
 
 @prompt_chain(
```
1 change: 1 addition & 0 deletions mkdocs.yml

```diff
@@ -83,6 +83,7 @@ nav:
   - Overview: index.md
   - structured-outputs.md
   - chat-prompting.md
+  - chat.md
   - function-calling.md
   - formatting.md
   - asyncio.md
```
1 change: 1 addition & 0 deletions src/magentic/__init__.py

```diff
@@ -1,3 +1,4 @@
+from ._chat import Chat as Chat
 from ._pydantic import ConfigDict as ConfigDict
 from ._pydantic import with_config as with_config
 from ._streamed_response import AsyncStreamedResponse as AsyncStreamedResponse
```
35 changes: 25 additions & 10 deletions src/magentic/chat.py → src/magentic/_chat.py

```diff
@@ -1,14 +1,18 @@
 import inspect
 from collections.abc import Callable, Iterable, Sequence
-from typing import Any, ParamSpec, TypeVar
+from typing import Any, ParamSpec
 
+from typing_extensions import Self, deprecated
+
 from magentic.backend import get_chat_model
 from magentic.chat_model.base import ChatModel
 from magentic.chat_model.message import (
     AssistantMessage,
     FunctionResultMessage,
     Message,
+    SystemMessage,
     UserMessage,
+    UserMessageContentBlock,
 )
 from magentic.function_call import (
     AsyncParallelFunctionCall,
@@ -19,8 +23,6 @@
 from magentic.streaming import async_iter, azip
 
 P = ParamSpec("P")
-# TODO: Use `Self` from typing_extensions
-Self = TypeVar("Self", bound="Chat")
 
 
 class Chat:
@@ -39,6 +41,7 @@ class Chat:
     def __init__(
         self,
         messages: Sequence[Message[Any]] | None = None,
+        *,
         functions: Iterable[Callable[..., Any]] | None = None,
         output_types: Iterable[type[Any]] | None = None,
         model: ChatModel | None = None,
@@ -49,6 +52,10 @@ def __init__(
         self._model = model
 
     @classmethod
+    @deprecated(
+        "Chat.from_prompt will be removed in a future version."
+        " Instead, use the regular init method, `Chat(messages, functions, output_types, model)`."
+    )
     def from_prompt(
         cls: type[Self],
         prompt: BasePromptFunction[P, Any],
@@ -75,7 +82,7 @@ def last_message(self) -> Message[Any]:
     def model(self) -> ChatModel:
         return self._model or get_chat_model()
 
-    def add_message(self: Self, message: Message[Any]) -> Self:
+    def add_message(self, message: Message[Any]) -> Self:
         """Add a message to the chat."""
         return type(self)(
             messages=[*self._messages, message],
@@ -84,15 +91,22 @@ def add_message(self: Self, message: Message[Any]) -> Self:
             model=self._model,  # Keep `None` value if unset
         )
 
-    def add_user_message(self: Self, content: str) -> Self:
+    def add_system_message(self, content: str) -> Self:
+        """Add a system message to the chat."""
+        return self.add_message(SystemMessage(content=content))
+
+    def add_user_message(
+        self, content: str | Sequence[str | UserMessageContentBlock]
+    ) -> Self:
         """Add a user message to the chat."""
         return self.add_message(UserMessage(content=content))
 
-    def add_assistant_message(self: Self, content: Any) -> Self:
+    def add_assistant_message(self, content: Any) -> Self:
         """Add an assistant message to the chat."""
         return self.add_message(AssistantMessage(content=content))
 
-    def submit(self: Self) -> Self:
+    # TODO: Allow restricting functions and/or output types here
+    def submit(self) -> Self:
         """Request an LLM message to be added to the chat."""
         output_message: AssistantMessage[Any] = self.model.complete(
             messages=self._messages,
@@ -101,7 +115,7 @@ def submit(self: Self) -> Self:
         )
         return self.add_message(output_message)
 
-    async def asubmit(self: Self) -> Self:
+    async def asubmit(self) -> Self:
         """Async version of `submit`."""
         output_message: AssistantMessage[Any] = await self.model.acomplete(
             messages=self._messages,
@@ -110,7 +124,8 @@ async def asubmit(self: Self) -> Self:
         )
         return self.add_message(output_message)
 
-    def exec_function_call(self: Self) -> Self:
+    # TODO: Add optional error handling to this method, with param to toggle
+    def exec_function_call(self) -> Self:
         """If the last message is a function call, execute it and add the result."""
         if isinstance(self.last_message.content, FunctionCall):
             function_call = self.last_message.content
@@ -133,7 +148,7 @@ def exec_function_call(self: Self) -> Self:
         msg = "Last message is not a function call."
         raise TypeError(msg)
 
-    async def aexec_function_call(self: Self) -> Self:
+    async def aexec_function_call(self) -> Self:
         """Async version of `exec_function_call`."""
         if isinstance(self.last_message.content, FunctionCall):
             function_call = self.last_message.content
```
23 changes: 19 additions & 4 deletions src/magentic/prompt_chain.py

```diff
@@ -3,8 +3,9 @@
 from functools import wraps
 from typing import Any, ParamSpec, TypeVar, cast
 
-from magentic.chat import Chat
+from magentic._chat import Chat
 from magentic.chat_model.base import ChatModel
+from magentic.chat_model.message import UserMessage
 from magentic.function_call import FunctionCall
 from magentic.logger import logfire
 from magentic.prompt_function import AsyncPromptFunction, PromptFunction
@@ -45,8 +46,15 @@ async def awrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
                 f"Calling async prompt-chain {func.__name__}",
                 **func_signature.bind(*args, **kwargs).arguments,
             ):
-                chat = await Chat.from_prompt(
-                    async_prompt_function, *args, **kwargs
+                chat = await Chat(
+                    messages=[
+                        UserMessage(
+                            content=async_prompt_function.format(*args, **kwargs)
+                        )
+                    ],
+                    functions=async_prompt_function.functions,
+                    output_types=async_prompt_function.return_types,
+                    model=async_prompt_function._model,  # Keep `None` value if unset
                 ).asubmit()
                 num_calls = 0
                 while isinstance(chat.last_message.content, FunctionCall):
@@ -79,7 +87,14 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
                 f"Calling prompt-chain {func.__name__}",
                 **func_signature.bind(*args, **kwargs).arguments,
             ):
-                chat = Chat.from_prompt(prompt_function, *args, **kwargs).submit()
+                chat = Chat(
+                    messages=[
+                        UserMessage(content=prompt_function.format(*args, **kwargs))
+                    ],
+                    functions=prompt_function.functions,
+                    output_types=prompt_function.return_types,
+                    model=prompt_function._model,  # Keep `None` value if unset
+                ).submit()
                 num_calls = 0
                 while isinstance(chat.last_message.content, FunctionCall):
                     if max_calls is not None and num_calls >= max_calls:
```
16 changes: 1 addition & 15 deletions tests/test_chat.py

```diff
@@ -2,7 +2,7 @@
 
 import pytest
 
-from magentic.chat import Chat
+from magentic._chat import Chat
 from magentic.chat_model.message import (
     AssistantMessage,
     FunctionResultMessage,
@@ -13,26 +13,12 @@
     FunctionCall,
     ParallelFunctionCall,
 )
-from magentic.prompt_function import prompt
 from magentic.streaming import async_iter
 
 if TYPE_CHECKING:
     from collections.abc import Awaitable
 
 
-def test_chat_from_prompt():
-    """Test creating a chat from a prompt function."""
-
-    def plus(a: int, b: int) -> int:
-        return a + b
-
-    @prompt("What is {a} plus {b}?", functions=[plus])
-    def add_text_numbers(a: str, b: str) -> int: ...
-
-    chat = Chat.from_prompt(add_text_numbers, "one", "two")
-    assert chat.messages == [UserMessage(content="What is one plus two?")]
-
-
 def test_chat_add_message():
     chat1 = Chat()
     chat2 = chat1.add_message(UserMessage(content="Hello"))
```
