diff --git a/src/magentic/chat_model/base.py b/src/magentic/chat_model/base.py index cead038a..a986aca9 100644 --- a/src/magentic/chat_model/base.py +++ b/src/magentic/chat_model/base.py @@ -33,6 +33,8 @@ def complete( messages: Iterable[Message[Any]], functions: None = ..., output_types: None = ..., + *, + stop: list[str] | None = ..., ) -> AssistantMessage[str]: ... @@ -43,6 +45,8 @@ def complete( messages: Iterable[Message[Any]], functions: Iterable[Callable[..., FuncR]], output_types: None = ..., + *, + stop: list[str] | None = ..., ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[str]: ... @@ -53,6 +57,8 @@ def complete( messages: Iterable[Message[Any]], functions: None = ..., output_types: Iterable[type[R]] = ..., + *, + stop: list[str] | None = ..., ) -> AssistantMessage[R]: ... @@ -63,6 +69,8 @@ def complete( messages: Iterable[Message[Any]], functions: Iterable[Callable[..., FuncR]], output_types: Iterable[type[R]], + *, + stop: list[str] | None = ..., ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[R]: ... @@ -72,6 +80,8 @@ def complete( messages: Iterable[Message[Any]], functions: Iterable[Callable[..., FuncR]] | None = None, output_types: Iterable[type[R | str]] | None = None, + *, + stop: list[str] | None = None, ) -> ( AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[R] @@ -87,6 +97,8 @@ async def acomplete( messages: Iterable[Message[Any]], functions: None = ..., output_types: None = ..., + *, + stop: list[str] | None = ..., ) -> AssistantMessage[str]: ... @@ -97,6 +109,8 @@ async def acomplete( messages: Iterable[Message[Any]], functions: Iterable[Callable[..., FuncR]], output_types: None = ..., + *, + stop: list[str] | None = ..., ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[str]: ... @@ -107,6 +121,8 @@ async def acomplete( messages: Iterable[Message[Any]], functions: None = ..., output_types: Iterable[type[R]] = ..., + *, + stop: list[str] | None = ..., ) -> AssistantMessage[R]: ... @@ -117,6 +133,8 @@ async def acomplete( messages: Iterable[Message[Any]], functions: Iterable[Callable[..., FuncR]], output_types: Iterable[type[R]], + *, + stop: list[str] | None = ..., ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[R]: ... 
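The overloads above make `stop` a keyword-only argument of `ChatModel.complete` and `ChatModel.acomplete`. A minimal usage sketch against the concrete OpenAI backend (model name, prompt, and stop sequence are illustrative, not taken from the diff):

```python
from magentic.chat_model.message import UserMessage
from magentic.chat_model.openai_chat_model import OpenaiChatModel

# Hypothetical call showing the new keyword-only `stop` parameter on complete().
model = OpenaiChatModel("gpt-3.5-turbo")  # model name is illustrative
message = model.complete(
    messages=[UserMessage("List three colours, one per line.")],
    stop=["\n\n"],  # generation halts at the first matched stop sequence
)
print(message.content)
```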
@@ -126,6 +144,8 @@ async def acomplete( messages: Iterable[Message[Any]], functions: Iterable[Callable[..., FuncR]] | None = None, output_types: Iterable[type[R | str]] | None = None, + *, + stop: list[str] | None = None, ) -> ( AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[R] diff --git a/src/magentic/chat_model/litellm_chat_model.py b/src/magentic/chat_model/litellm_chat_model.py index 8faedae0..13dc7260 100644 --- a/src/magentic/chat_model/litellm_chat_model.py +++ b/src/magentic/chat_model/litellm_chat_model.py @@ -38,6 +38,7 @@ def litellm_completion( messages: list[ChatCompletionMessageParam], api_base: str | None = None, max_tokens: int | None = None, + stop: list[str] | None = None, temperature: float | None = None, functions: list[dict[str, Any]] | None = None, function_call: Literal["auto", "none"] | dict[str, Any] | None = None, @@ -60,6 +61,7 @@ def litellm_completion( response: CustomStreamWrapper = litellm.completion( # type: ignore[no-untyped-call,unused-ignore] model=model, messages=messages, + stop=stop, stream=True, **kwargs, ) @@ -71,6 +73,7 @@ async def litellm_acompletion( messages: list[ChatCompletionMessageParam], api_base: str | None = None, max_tokens: int | None = None, + stop: list[str] | None = None, temperature: float | None = None, functions: list[dict[str, Any]] | None = None, function_call: Literal["auto", "none"] | dict[str, Any] | None = None, @@ -93,6 +96,7 @@ async def litellm_acompletion( response: AsyncIterator[ModelResponse] = await litellm.acompletion( # type: ignore[no-untyped-call,unused-ignore] model=model, messages=messages, + stop=stop, stream=True, **kwargs, ) @@ -141,6 +145,8 @@ def complete( messages: Iterable[Message[Any]], functions: None = ..., output_types: None = ..., + *, + stop: list[str] | None = ..., ) -> AssistantMessage[str]: ... @@ -150,6 +156,8 @@ def complete( messages: Iterable[Message[Any]], functions: Iterable[Callable[..., FuncR]], output_types: None = ..., + *, + stop: list[str] | None = ..., ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[str]: ... @@ -159,6 +167,8 @@ def complete( messages: Iterable[Message[Any]], functions: None = ..., output_types: Iterable[type[R]] = ..., + *, + stop: list[str] | None = ..., ) -> AssistantMessage[R]: ... @@ -168,6 +178,8 @@ def complete( messages: Iterable[Message[Any]], functions: Iterable[Callable[..., FuncR]], output_types: Iterable[type[R]], + *, + stop: list[str] | None = ..., ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[R]: ... @@ -176,6 +188,8 @@ def complete( messages: Iterable[Message[Any]], functions: Iterable[Callable[..., FuncR]] | None = None, output_types: Iterable[type[R | str]] | None = None, + *, + stop: list[str] | None = None, ) -> ( AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[R] @@ -203,6 +217,7 @@ def complete( messages=[message_to_openai_message(m) for m in messages], api_base=self.api_base, max_tokens=self.max_tokens, + stop=stop, temperature=self.temperature, functions=openai_functions, function_call=( @@ -262,6 +277,8 @@ async def acomplete( messages: Iterable[Message[Any]], functions: None = ..., output_types: None = ..., + *, + stop: list[str] | None = ..., ) -> AssistantMessage[str]: ... @@ -271,6 +288,8 @@ async def acomplete( messages: Iterable[Message[Any]], functions: Iterable[Callable[..., FuncR]], output_types: None = ..., + *, + stop: list[str] | None = ..., ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[str]: ... 
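The `litellm_completion` / `litellm_acompletion` wrappers above forward `stop` straight through to litellm while keeping `stream=True`. A sketch of the equivalent raw litellm call, assuming litellm's standard list-of-strings `stop` parameter (model name and prompt are illustrative):

```python
import litellm

# Illustrative only: what the pass-through above amounts to at the litellm layer.
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Count from 1 to 10."}],
    stop=["7"],   # cut the completion off once "7" would be emitted
    stream=True,  # the magentic wrappers always stream
)
for chunk in response:
    print(chunk)
```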
@@ -280,6 +299,8 @@ async def acomplete( messages: Iterable[Message[Any]], functions: None = ..., output_types: Iterable[type[R]] = ..., + *, + stop: list[str] | None = ..., ) -> AssistantMessage[R]: ... @@ -289,6 +310,8 @@ async def acomplete( messages: Iterable[Message[Any]], functions: Iterable[Callable[..., FuncR]], output_types: Iterable[type[R]], + *, + stop: list[str] | None = ..., ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[R]: ... @@ -297,6 +320,8 @@ async def acomplete( messages: Iterable[Message[Any]], functions: Iterable[Callable[..., FuncR]] | None = None, output_types: Iterable[type[R | str]] | None = None, + *, + stop: list[str] | None = None, ) -> ( AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[R] @@ -324,6 +349,7 @@ async def acomplete( messages=[message_to_openai_message(m) for m in messages], api_base=self.api_base, max_tokens=self.max_tokens, + stop=stop, temperature=self.temperature, functions=openai_functions, function_call=( diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index 2c0f6d4b..9ce14cdc 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -92,6 +92,7 @@ def openai_chatcompletion_create( messages: list[ChatCompletionMessageParam], max_tokens: int | None = None, seed: int | None = None, + stop: list[str] | None = None, temperature: float | None = None, functions: list[dict[str, Any]] | None = None, function_call: Literal["auto", "none"] | dict[str, Any] | None = None, @@ -121,8 +122,9 @@ def openai_chatcompletion_create( messages=messages, max_tokens=max_tokens, seed=seed, - temperature=temperature, + stop=stop, stream=True, + temperature=temperature, **kwargs, ) return response @@ -136,6 +138,7 @@ async def openai_chatcompletion_acreate( messages: list[ChatCompletionMessageParam], max_tokens: int | None = None, seed: int | None = None, + stop: list[str] | None = None, temperature: float | None = None, functions: list[dict[str, Any]] | None = None, function_call: Literal["auto", "none"] | dict[str, Any] | None = None, @@ -164,6 +167,7 @@ async def openai_chatcompletion_acreate( messages=messages, max_tokens=max_tokens, seed=seed, + stop=stop, temperature=temperature, stream=True, **kwargs, @@ -231,6 +235,8 @@ def complete( messages: Iterable[Message[Any]], functions: None = ..., output_types: None = ..., + *, + stop: list[str] | None = ..., ) -> AssistantMessage[str]: ... @@ -240,6 +246,8 @@ def complete( messages: Iterable[Message[Any]], functions: Iterable[Callable[..., FuncR]], output_types: None = ..., + *, + stop: list[str] | None = ..., ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[str]: ... @@ -249,6 +257,8 @@ def complete( messages: Iterable[Message[Any]], functions: None = ..., output_types: Iterable[type[R]] = ..., + *, + stop: list[str] | None = ..., ) -> AssistantMessage[R]: ... @@ -258,6 +268,8 @@ def complete( messages: Iterable[Message[Any]], functions: Iterable[Callable[..., FuncR]], output_types: Iterable[type[R]], + *, + stop: list[str] | None = ..., ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[R]: ... 
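`LitellmChatModel.acomplete` now threads the same `stop` value through to `litellm_acompletion`. A hedged async usage sketch (model name, prompt, and stop sequence are illustrative):

```python
import asyncio

from magentic.chat_model.litellm_chat_model import LitellmChatModel
from magentic.chat_model.message import UserMessage

async def main() -> None:
    model = LitellmChatModel("gpt-3.5-turbo")  # model name is illustrative
    message = await model.acomplete(
        messages=[UserMessage("Write one sentence about the sea.")],
        stop=["."],  # stop at the end of the first sentence
    )
    print(message.content)

asyncio.run(main())
```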
@@ -266,6 +278,8 @@ def complete( messages: Iterable[Message[Any]], functions: Iterable[Callable[..., FuncR]] | None = None, output_types: Iterable[type[R | str]] | None = None, + *, + stop: list[str] | None = None, ) -> ( AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[R] @@ -296,6 +310,7 @@ def complete( messages=[message_to_openai_message(m) for m in messages], max_tokens=self.max_tokens, seed=self.seed, + stop=stop, temperature=self.temperature, functions=openai_functions, function_call=( @@ -360,6 +375,8 @@ async def acomplete( messages: Iterable[Message[Any]], functions: None = ..., output_types: None = ..., + *, + stop: list[str] | None = ..., ) -> AssistantMessage[str]: ... @@ -369,6 +386,8 @@ async def acomplete( messages: Iterable[Message[Any]], functions: Iterable[Callable[..., FuncR]], output_types: None = ..., + *, + stop: list[str] | None = ..., ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[str]: ... @@ -378,6 +397,8 @@ async def acomplete( messages: Iterable[Message[Any]], functions: None = ..., output_types: Iterable[type[R]] = ..., + *, + stop: list[str] | None = ..., ) -> AssistantMessage[R]: ... @@ -387,6 +408,8 @@ async def acomplete( messages: Iterable[Message[Any]], functions: Iterable[Callable[..., FuncR]], output_types: Iterable[type[R]], + *, + stop: list[str] | None = ..., ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[R]: ... @@ -395,6 +418,8 @@ async def acomplete( messages: Iterable[Message[Any]], functions: Iterable[Callable[..., FuncR]] | None = None, output_types: Iterable[type[R | str]] | None = None, + *, + stop: list[str] | None = None, ) -> ( AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[R] @@ -425,6 +450,7 @@ async def acomplete( messages=[message_to_openai_message(m) for m in messages], max_tokens=self.max_tokens, seed=self.seed, + stop=stop, temperature=self.temperature, functions=openai_functions, function_call=( diff --git a/src/magentic/chatprompt.py b/src/magentic/chatprompt.py index 86b42360..630b9a67 100644 --- a/src/magentic/chatprompt.py +++ b/src/magentic/chatprompt.py @@ -45,6 +45,7 @@ def __init__( parameters: Sequence[inspect.Parameter], return_type: type[R], functions: list[Callable[..., Any]] | None = None, + stop: list[str] | None = None, model: ChatModel | None = None, ): self._signature = inspect.Signature( @@ -53,6 +54,7 @@ def __init__( ) self._messages = messages self._functions = functions or [] + self._stop = stop self._model = model self._return_types = [ @@ -98,6 +100,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: messages=self.format(*args, **kwargs), functions=self._functions, output_types=self._return_types, + stop=self._stop, ) return cast(R, message.content) @@ -111,6 +114,7 @@ async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: messages=self.format(*args, **kwargs), functions=self._functions, output_types=self._return_types, + stop=self._stop, ) return cast(R, message.content) @@ -136,6 +140,7 @@ def __call__(self, func: Callable[P, R]) -> ChatPromptFunction[P, R]: def chatprompt( *messages: Message[Any], functions: list[Callable[..., Any]] | None = None, + stop: list[str] | None = None, model: ChatModel | None = None, ) -> ChatPromptDecorator: """Convert a function into an LLM chat prompt template. 
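With `ChatPromptFunction` storing `stop` and forwarding it to `complete`/`acomplete`, the `@chatprompt` decorator signature just shown can be used like this (prompt text and stop sequence are illustrative; this mirrors the pattern exercised in the tests further down):

```python
from magentic import UserMessage, chatprompt

# Hypothetical usage of the new `stop` argument on @chatprompt: the model is
# asked to end with "END", and generation is cut when that token is produced.
@chatprompt(
    UserMessage("Tell a two-sentence story about {topic}, then write END."),
    stop=["END"],
)
def tell_story(topic: str) -> str: ...

print(tell_story("a lighthouse keeper"))
```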
@@ -184,6 +189,7 @@ def decorator( parameters=list(func_signature.parameters.values()), return_type=func_signature.return_annotation, functions=functions, + stop=stop, model=model, ) async_prompt_function = update_wrapper(async_prompt_function, func) @@ -194,6 +200,7 @@ def decorator( parameters=list(func_signature.parameters.values()), return_type=func_signature.return_annotation, functions=functions, + stop=stop, model=model, ) return cast(ChatPromptFunction[P, R], update_wrapper(prompt_function, func)) # type: ignore[redundant-cast] diff --git a/src/magentic/prompt_function.py b/src/magentic/prompt_function.py index bcdb7e06..49fd3cae 100644 --- a/src/magentic/prompt_function.py +++ b/src/magentic/prompt_function.py @@ -1,3 +1,4 @@ +import copy import inspect from functools import update_wrapper from typing import ( @@ -36,6 +37,7 @@ def __init__( parameters: Sequence[inspect.Parameter], return_type: type[R], functions: list[Callable[..., Any]] | None = None, + stop: list[str] | None = None, model: ChatModel | None = None, ): self._signature = inspect.Signature( @@ -44,6 +46,7 @@ def __init__( ) self._template = template self._functions = functions or [] + self._stop = stop self._model = model self._return_types = [ @@ -56,6 +59,10 @@ def __init__( def functions(self) -> list[Callable[..., Any]]: return self._functions.copy() + @property + def stop(self) -> list[str] | None: + return copy.copy(self._stop) + @property def model(self) -> ChatModel: return self._model or get_chat_model() @@ -80,6 +87,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: messages=[UserMessage(content=self.format(*args, **kwargs))], functions=self._functions, output_types=self._return_types, + stop=self._stop, ) return cast(R, message.content) @@ -93,6 +101,7 @@ async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: messages=[UserMessage(content=self.format(*args, **kwargs))], functions=self._functions, output_types=self._return_types, + stop=self._stop, ) return cast(R, message.content) @@ -116,6 +125,7 @@ def __call__(self, func: Callable[P, R]) -> PromptFunction[P, R]: def prompt( template: str, functions: list[Callable[..., Any]] | None = None, + stop: list[str] | None = None, model: ChatModel | None = None, ) -> PromptDecorator: """Convert a function into an LLM prompt template. 
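`PromptFunction` gains the same `stop` plumbing plus a read-only `stop` property that returns a copy of the configured list. A hedged usage sketch of the `@prompt` decorator (prompt text and stop sequence are illustrative):

```python
from magentic import prompt

# Hypothetical usage of the new `stop` argument on @prompt.
@prompt("Continue this list: apples, bananas,", stop=["\n"])
def continue_list() -> str: ...

print(continue_list.stop)  # ["\n"] -- the property returns a copy of the configured list
print(continue_list())     # completion is truncated at the first newline
```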
@@ -145,6 +155,7 @@ def decorator( parameters=list(func_signature.parameters.values()), return_type=func_signature.return_annotation, functions=functions, + stop=stop, model=model, ) async_prompt_function = update_wrapper(async_prompt_function, func) @@ -155,6 +166,7 @@ def decorator( parameters=list(func_signature.parameters.values()), return_type=func_signature.return_annotation, functions=functions, + stop=stop, model=model, ) return cast(PromptFunction[P, R], update_wrapper(prompt_function, func)) # type: ignore[redundant-cast] diff --git a/tests/test_chatprompt.py b/tests/test_chatprompt.py index 54fc0b85..fd651831 100644 --- a/tests/test_chatprompt.py +++ b/tests/test_chatprompt.py @@ -76,6 +76,7 @@ def test_chatpromptfunction_call(): @chatprompt( UserMessage("Hello {name}."), + stop=["stop"], model=mock_model, ) def say_hello(name: str) -> str | bool: @@ -87,6 +88,7 @@ def say_hello(name: str) -> str | bool: UserMessage("Hello World.") ] assert mock_model.complete.call_args.kwargs["output_types"] == [str, bool] + assert mock_model.complete.call_args.kwargs["stop"] == ["stop"] def test_chatprompt_decorator_docstring(): @@ -106,6 +108,7 @@ async def test_asyncchatpromptfunction_call(): @chatprompt( UserMessage("Hello {name}."), + stop=["stop"], model=mock_model, ) async def say_hello(name: str) -> str | bool: @@ -117,6 +120,7 @@ async def say_hello(name: str) -> str | bool: UserMessage("Hello World.") ] assert mock_model.acomplete.call_args.kwargs["output_types"] == [str, bool] + assert mock_model.acomplete.call_args.kwargs["stop"] == ["stop"] @pytest.mark.asyncio diff --git a/tests/test_prompt_function.py b/tests/test_prompt_function.py index fc075b44..0402b2cc 100644 --- a/tests/test_prompt_function.py +++ b/tests/test_prompt_function.py @@ -2,11 +2,13 @@ from inspect import getdoc from typing import Awaitable +from unittest.mock import AsyncMock, Mock import pytest from pydantic import BaseModel from magentic.chat_model.base import StructuredOutputError +from magentic.chat_model.message import AssistantMessage, UserMessage from magentic.chat_model.openai_chat_model import OpenaiChatModel from magentic.function_call import FunctionCall from magentic.prompt_function import AsyncPromptFunction, PromptFunction, prompt @@ -22,6 +24,27 @@ def func(param: str) -> str: assert func.format("arg") == "Test arg." +def test_promptfunction_call(): + mock_model = Mock() + mock_model.complete.return_value = AssistantMessage(content="Hello!") + + @prompt( + "Hello {name}.", + stop=["stop"], + model=mock_model, + ) + def say_hello(name: str) -> str | bool: + ... + + assert say_hello("World") == "Hello!" + assert mock_model.complete.call_count == 1 + assert mock_model.complete.call_args.kwargs["messages"] == [ + UserMessage("Hello World.") + ] + assert mock_model.complete.call_args.kwargs["output_types"] == [str, bool] + assert mock_model.complete.call_args.kwargs["stop"] == ["stop"] + + @pytest.mark.openai def test_decorator_return_str(): @prompt("What is the capital of {country}? Name only. No punctuation.") @@ -132,6 +155,28 @@ def days_between(start_date: str, end_date: str) -> int: days_between("Jan 4th 2019", "Jul 3rd 2019") +@pytest.mark.asyncio +async def test_async_promptfunction_call(): + mock_model = AsyncMock() + mock_model.acomplete.return_value = AssistantMessage(content="Hello!") + + @prompt( + "Hello {name}.", + stop=["stop"], + model=mock_model, + ) + async def say_hello(name: str) -> str | bool: + ... + + assert await say_hello("World") == "Hello!" 
+ assert mock_model.acomplete.call_count == 1 + assert mock_model.acomplete.call_args.kwargs["messages"] == [ + UserMessage("Hello World.") + ] + assert mock_model.acomplete.call_args.kwargs["output_types"] == [str, bool] + assert mock_model.acomplete.call_args.kwargs["stop"] == ["stop"] + + @pytest.mark.asyncio @pytest.mark.openai async def test_async_decorator_return_str():
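The `test_chatprompt.py` hunks assert against `mock_model` objects whose construction falls outside the diff context. A hedged reconstruction of that setup, mirroring the mocks added in `test_prompt_function.py` above:

```python
from unittest.mock import AsyncMock, Mock

from magentic.chat_model.message import AssistantMessage

# Hypothetical reconstruction of the mock models referenced (but not shown)
# by the test_chatprompt.py hunks above.
mock_model = Mock()  # used by test_chatpromptfunction_call
mock_model.complete.return_value = AssistantMessage(content="Hello!")

mock_model_async = AsyncMock()  # used (as `mock_model`) by test_asyncchatpromptfunction_call
mock_model_async.acomplete.return_value = AssistantMessage(content="Hello!")
```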