From 2d823c03b2009aa35b2c8ed3ba96247f16ae2b23 Mon Sep 17 00:00:00 2001
From: Michael Struwig
Date: Tue, 19 Dec 2023 16:19:44 +0200
Subject: [PATCH 1/7] WIP: Add stop bindings to @prompt decorator.

---
 src/magentic/chat_model/base.py              |  5 +++++
 src/magentic/chat_model/openai_chat_model.py |  4 ++++
 src/magentic/prompt_function.py              | 10 ++++++++++
 3 files changed, 19 insertions(+)

diff --git a/src/magentic/chat_model/base.py b/src/magentic/chat_model/base.py
index cead038a..31b1cfb6 100644
--- a/src/magentic/chat_model/base.py
+++ b/src/magentic/chat_model/base.py
@@ -32,6 +32,7 @@ def complete(
         self,
         messages: Iterable[Message[Any]],
         functions: None = ...,
+        stop: list[str] | None = None,
         output_types: None = ...,
     ) -> AssistantMessage[str]:
         ...
@@ -42,6 +43,7 @@ def complete(
         self,
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]],
+        stop: list[str] | None = None,
         output_types: None = ...,
     ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[str]:
         ...
@@ -52,6 +54,7 @@ def complete(
         self,
         messages: Iterable[Message[Any]],
         functions: None = ...,
+        stop: list[str] | None = None,
         output_types: Iterable[type[R]] = ...,
     ) -> AssistantMessage[R]:
         ...
@@ -62,6 +65,7 @@ def complete(
         self,
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]],
+        stop: list[str],
         output_types: Iterable[type[R]],
     ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[R]:
         ...
@@ -71,6 +75,7 @@ def complete(
         self,
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]] | None = None,
+        stop: list[str] | None = None,
         output_types: Iterable[type[R | str]] | None = None,
     ) -> (
         AssistantMessage[FunctionCall[FuncR]]
diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py
index 2c0f6d4b..d4bb31b2 100644
--- a/src/magentic/chat_model/openai_chat_model.py
+++ b/src/magentic/chat_model/openai_chat_model.py
@@ -93,6 +93,7 @@ def openai_chatcompletion_create(
     max_tokens: int | None = None,
     seed: int | None = None,
     temperature: float | None = None,
+    stop: list[str] | None = None,
     functions: list[dict[str, Any]] | None = None,
     function_call: Literal["auto", "none"] | dict[str, Any] | None = None,
 ) -> Iterator[ChatCompletionChunk]:
@@ -123,6 +124,7 @@
         seed=seed,
         temperature=temperature,
         stream=True,
+        stop=stop,
         **kwargs,
     )
     return response
@@ -266,6 +268,7 @@ def complete(
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]] | None = None,
         output_types: Iterable[type[R | str]] | None = None,
+        stop: Iterable[str] | None = None,
     ) -> (
         AssistantMessage[FunctionCall[FuncR]]
         | AssistantMessage[R]
@@ -297,6 +300,7 @@ def complete(
             max_tokens=self.max_tokens,
             seed=self.seed,
             temperature=self.temperature,
+            stop=stop,
             functions=openai_functions,
             function_call=(
                 {"name": openai_functions[0]["name"]}
diff --git a/src/magentic/prompt_function.py b/src/magentic/prompt_function.py
index bcdb7e06..1a718da0 100644
--- a/src/magentic/prompt_function.py
+++ b/src/magentic/prompt_function.py
@@ -1,3 +1,4 @@
+import copy
 import inspect
 from functools import update_wrapper
 from typing import (
@@ -36,6 +37,7 @@ def __init__(
         parameters: Sequence[inspect.Parameter],
         return_type: type[R],
         functions: list[Callable[..., Any]] | None = None,
+        stop: list[str] | None = None,
         model: ChatModel | None = None,
     ):
         self._signature = inspect.Signature(
@@ -44,6 +46,7 @@ def __init__(
         )
         self._template = template
         self._functions = functions or []
+        self._stop = stop
         self._model = model
 
         self._return_types = [
@@ -56,6 +59,10 @@ def __init__(
     def functions(self) -> list[Callable[..., Any]]:
         return self._functions.copy()
 
+    @property
+    def stop(self) -> list[str] | None:
+        return copy.copy(self._stop)
+
     @property
     def model(self) -> ChatModel:
         return self._model or get_chat_model()
@@ -78,6 +85,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
         """Query the LLM with the formatted prompt template."""
         message = self.model.complete(
             messages=[UserMessage(content=self.format(*args, **kwargs))],
+            stop=self._stop,
             functions=self._functions,
             output_types=self._return_types,
         )
@@ -116,6 +124,7 @@ def __call__(self, func: Callable[P, R]) -> PromptFunction[P, R]:
 def prompt(
     template: str,
     functions: list[Callable[..., Any]] | None = None,
+    stop: list[str] | None = None,
     model: ChatModel | None = None,
 ) -> PromptDecorator:
     """Convert a function into an LLM prompt template.
@@ -155,6 +164,7 @@ def decorator(
             parameters=list(func_signature.parameters.values()),
             return_type=func_signature.return_annotation,
             functions=functions,
+            stop=stop,
             model=model,
         )
         return cast(PromptFunction[P, R], update_wrapper(prompt_function, func))  # type: ignore[redundant-cast]

From 41a3a5081d573963c5885ef778b4088e16e35001 Mon Sep 17 00:00:00 2001
From: Michael Struwig
Date: Thu, 21 Dec 2023 12:30:06 +0200
Subject: [PATCH 2/7] WIP: Add stop bindings to chat prompt.

---
 src/magentic/chatprompt.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/magentic/chatprompt.py b/src/magentic/chatprompt.py
index 86b42360..3e5716a8 100644
--- a/src/magentic/chatprompt.py
+++ b/src/magentic/chatprompt.py
@@ -45,6 +45,7 @@ def __init__(
         parameters: Sequence[inspect.Parameter],
         return_type: type[R],
         functions: list[Callable[..., Any]] | None = None,
+        stop: list[str] | None = None,
         model: ChatModel | None = None,
     ):
         self._signature = inspect.Signature(
@@ -54,6 +55,7 @@ def __init__(
         )
         self._messages = messages
         self._functions = functions or []
         self._model = model
+        self._stop = stop
 
         self._return_types = [
             type_
@@ -97,6 +99,7 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
         message = self.model.complete(
             messages=self.format(*args, **kwargs),
             functions=self._functions,
+            stop=self._stop,
             output_types=self._return_types,
         )
         return cast(R, message.content)
@@ -136,6 +139,7 @@ def __call__(self, func: Callable[P, R]) -> ChatPromptFunction[P, R]:
 def chatprompt(
     *messages: Message[Any],
     functions: list[Callable[..., Any]] | None = None,
+    stop: list[str] | None = None,
     model: ChatModel | None = None,
 ) -> ChatPromptDecorator:
     """Convert a function into an LLM chat prompt template.
@@ -194,6 +198,7 @@ def decorator(
             parameters=list(func_signature.parameters.values()),
             return_type=func_signature.return_annotation,
             functions=functions,
+            stop=stop,
             model=model,
         )
         return cast(ChatPromptFunction[P, R], update_wrapper(prompt_function, func))  # type: ignore[redundant-cast]

From 9efa88a6aeb63f1debfc43e66083f31a7dee2def Mon Sep 17 00:00:00 2001
From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com>
Date: Sun, 7 Jan 2024 17:06:24 -0800
Subject: [PATCH 3/7] Put stop param alphabetical

---
 src/magentic/chat_model/openai_chat_model.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py
index d4bb31b2..a8de1f6a 100644
--- a/src/magentic/chat_model/openai_chat_model.py
+++ b/src/magentic/chat_model/openai_chat_model.py
@@ -92,8 +92,8 @@ def openai_chatcompletion_create(
     messages: list[ChatCompletionMessageParam],
     max_tokens: int | None = None,
     seed: int | None = None,
-    temperature: float | None = None,
     stop: list[str] | None = None,
+    temperature: float | None = None,
     functions: list[dict[str, Any]] | None = None,
     function_call: Literal["auto", "none"] | dict[str, Any] | None = None,
 ) -> Iterator[ChatCompletionChunk]:
@@ -122,9 +122,9 @@
         messages=messages,
         max_tokens=max_tokens,
         seed=seed,
-        temperature=temperature,
-        stream=True,
         stop=stop,
+        stream=True,
+        temperature=temperature,
         **kwargs,
     )
     return response
@@ -299,8 +299,8 @@ def complete(
             messages=[message_to_openai_message(m) for m in messages],
             max_tokens=self.max_tokens,
             seed=self.seed,
-            temperature=self.temperature,
             stop=stop,
+            temperature=self.temperature,
             functions=openai_functions,
             function_call=(
                 {"name": openai_functions[0]["name"]}

From a4fef1fe9a50785eaf30d132ef14cfd0bd035ab2 Mon Sep 17 00:00:00 2001
From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com>
Date: Sun, 7 Jan 2024 17:16:26 -0800
Subject: [PATCH 4/7] Make stop kwarg-only in ChatModel.complete

---
 src/magentic/chat_model/base.py              | 15 ++++++++++-----
 src/magentic/chat_model/openai_chat_model.py | 11 ++++++++++-
 src/magentic/chatprompt.py                   |  4 ++--
 src/magentic/prompt_function.py              |  2 +-
 4 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/src/magentic/chat_model/base.py b/src/magentic/chat_model/base.py
index 31b1cfb6..d07aa9cd 100644
--- a/src/magentic/chat_model/base.py
+++ b/src/magentic/chat_model/base.py
@@ -32,8 +32,9 @@ def complete(
         self,
         messages: Iterable[Message[Any]],
         functions: None = ...,
-        stop: list[str] | None = None,
         output_types: None = ...,
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[str]:
         ...
@@ -43,8 +44,9 @@ def complete(
         self,
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]],
-        stop: list[str] | None = None,
         output_types: None = ...,
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[str]:
         ...
@@ -54,8 +56,9 @@ def complete(
         self,
         messages: Iterable[Message[Any]],
         functions: None = ...,
-        stop: list[str] | None = None,
         output_types: Iterable[type[R]] = ...,
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[R]:
         ...
@@ -65,8 +68,9 @@ def complete(
         self,
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]],
-        stop: list[str],
         output_types: Iterable[type[R]],
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[R]:
         ...
@@ -75,8 +79,9 @@ def complete(
         self,
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]] | None = None,
-        stop: list[str] | None = None,
         output_types: Iterable[type[R | str]] | None = None,
+        *,
+        stop: list[str] | None = None,
     ) -> (
         AssistantMessage[FunctionCall[FuncR]]
         | AssistantMessage[R]
diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py
index a8de1f6a..e6d245fc 100644
--- a/src/magentic/chat_model/openai_chat_model.py
+++ b/src/magentic/chat_model/openai_chat_model.py
@@ -233,6 +233,8 @@ def complete(
         messages: Iterable[Message[Any]],
         functions: None = ...,
         output_types: None = ...,
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[str]:
         ...
@@ -242,6 +244,8 @@ def complete(
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]],
         output_types: None = ...,
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[str]:
         ...
@@ -251,6 +255,8 @@ def complete(
         messages: Iterable[Message[Any]],
         functions: None = ...,
         output_types: Iterable[type[R]] = ...,
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[R]:
         ...
@@ -260,6 +266,8 @@ def complete(
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]],
         output_types: Iterable[type[R]],
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[R]:
         ...
@@ -268,7 +276,8 @@ def complete(
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]] | None = None,
         output_types: Iterable[type[R | str]] | None = None,
-        stop: Iterable[str] | None = None,
+        *,
+        stop: list[str] | None = None,
     ) -> (
         AssistantMessage[FunctionCall[FuncR]]
         | AssistantMessage[R]
diff --git a/src/magentic/chatprompt.py b/src/magentic/chatprompt.py
index 3e5716a8..2e22168e 100644
--- a/src/magentic/chatprompt.py
+++ b/src/magentic/chatprompt.py
@@ -54,8 +54,8 @@ def __init__(
         )
         self._messages = messages
         self._functions = functions or []
-        self._model = model
         self._stop = stop
+        self._model = model
 
         self._return_types = [
             type_
@@ -99,8 +99,8 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
         message = self.model.complete(
             messages=self.format(*args, **kwargs),
             functions=self._functions,
-            stop=self._stop,
             output_types=self._return_types,
+            stop=self._stop,
         )
         return cast(R, message.content)
diff --git a/src/magentic/prompt_function.py b/src/magentic/prompt_function.py
index 1a718da0..7755da6e 100644
--- a/src/magentic/prompt_function.py
+++ b/src/magentic/prompt_function.py
@@ -85,9 +85,9 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
         """Query the LLM with the formatted prompt template."""
         message = self.model.complete(
             messages=[UserMessage(content=self.format(*args, **kwargs))],
-            stop=self._stop,
             functions=self._functions,
             output_types=self._return_types,
+            stop=self._stop,
         )
         return cast(R, message.content)

From 65b6fe7905f399d9afee96bd804c09074b5a18b9 Mon Sep 17 00:00:00 2001
From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com>
Date: Sun, 7 Jan 2024 18:42:48 -0800
Subject: [PATCH 5/7] Add stop param for async code

---
 src/magentic/chat_model/base.py              | 10 ++++++++++
 src/magentic/chat_model/openai_chat_model.py | 13 +++++++++++++
 src/magentic/chatprompt.py                   |  2 ++
 src/magentic/prompt_function.py              |  2 ++
 4 files changed, 27 insertions(+)

diff --git a/src/magentic/chat_model/base.py b/src/magentic/chat_model/base.py
index d07aa9cd..a986aca9 100644
--- a/src/magentic/chat_model/base.py
+++ b/src/magentic/chat_model/base.py
@@ -97,6 +97,8 @@ async def acomplete(
         messages: Iterable[Message[Any]],
         functions: None = ...,
         output_types: None = ...,
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[str]:
         ...
@@ -107,6 +109,8 @@ async def acomplete(
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]],
         output_types: None = ...,
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[str]:
         ...
@@ -117,6 +121,8 @@ async def acomplete(
         messages: Iterable[Message[Any]],
         functions: None = ...,
         output_types: Iterable[type[R]] = ...,
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[R]:
         ...
@@ -127,6 +133,8 @@ async def acomplete(
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]],
         output_types: Iterable[type[R]],
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[R]:
         ...
@@ -136,6 +144,8 @@ async def acomplete(
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]] | None = None,
         output_types: Iterable[type[R | str]] | None = None,
+        *,
+        stop: list[str] | None = None,
     ) -> (
         AssistantMessage[FunctionCall[FuncR]]
         | AssistantMessage[R]
diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py
index e6d245fc..9ce14cdc 100644
--- a/src/magentic/chat_model/openai_chat_model.py
+++ b/src/magentic/chat_model/openai_chat_model.py
@@ -138,6 +138,7 @@ async def openai_chatcompletion_acreate(
     messages: list[ChatCompletionMessageParam],
     max_tokens: int | None = None,
     seed: int | None = None,
+    stop: list[str] | None = None,
     temperature: float | None = None,
     functions: list[dict[str, Any]] | None = None,
     function_call: Literal["auto", "none"] | dict[str, Any] | None = None,
@@ -166,6 +167,7 @@ async def openai_chatcompletion_acreate(
         messages=messages,
         max_tokens=max_tokens,
         seed=seed,
+        stop=stop,
         temperature=temperature,
         stream=True,
         **kwargs,
@@ -373,6 +375,8 @@ async def acomplete(
         messages: Iterable[Message[Any]],
         functions: None = ...,
         output_types: None = ...,
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[str]:
         ...
@@ -382,6 +386,8 @@ async def acomplete(
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]],
         output_types: None = ...,
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[str]:
         ...
@@ -391,6 +397,8 @@ async def acomplete(
         messages: Iterable[Message[Any]],
         functions: None = ...,
         output_types: Iterable[type[R]] = ...,
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[R]:
         ...
@@ -400,6 +408,8 @@ async def acomplete(
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]],
         output_types: Iterable[type[R]],
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[R]:
         ...
@@ -408,6 +418,8 @@ async def acomplete(
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]] | None = None,
         output_types: Iterable[type[R | str]] | None = None,
+        *,
+        stop: list[str] | None = None,
     ) -> (
         AssistantMessage[FunctionCall[FuncR]]
         | AssistantMessage[R]
@@ -438,6 +450,7 @@ async def acomplete(
             messages=[message_to_openai_message(m) for m in messages],
             max_tokens=self.max_tokens,
             seed=self.seed,
+            stop=stop,
             temperature=self.temperature,
             functions=openai_functions,
             function_call=(
diff --git a/src/magentic/chatprompt.py b/src/magentic/chatprompt.py
index 2e22168e..630b9a67 100644
--- a/src/magentic/chatprompt.py
+++ b/src/magentic/chatprompt.py
@@ -114,6 +114,7 @@ async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
             messages=self.format(*args, **kwargs),
             functions=self._functions,
             output_types=self._return_types,
+            stop=self._stop,
         )
         return cast(R, message.content)
@@ -188,6 +189,7 @@ def decorator(
             parameters=list(func_signature.parameters.values()),
             return_type=func_signature.return_annotation,
             functions=functions,
+            stop=stop,
             model=model,
         )
         async_prompt_function = update_wrapper(async_prompt_function, func)
diff --git a/src/magentic/prompt_function.py b/src/magentic/prompt_function.py
index 7755da6e..49fd3cae 100644
--- a/src/magentic/prompt_function.py
+++ b/src/magentic/prompt_function.py
@@ -101,6 +101,7 @@ async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
             messages=[UserMessage(content=self.format(*args, **kwargs))],
             functions=self._functions,
             output_types=self._return_types,
+            stop=self._stop,
         )
         return cast(R, message.content)
@@ -154,6 +155,7 @@ def decorator(
             parameters=list(func_signature.parameters.values()),
             return_type=func_signature.return_annotation,
             functions=functions,
+            stop=stop,
             model=model,
         )
         async_prompt_function = update_wrapper(async_prompt_function, func)

From 979aa478d7178b4bf3532418823d6d7f30e1913e Mon Sep 17 00:00:00 2001
From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com>
Date: Sun, 7 Jan 2024 18:48:51 -0800
Subject: [PATCH 6/7] Add stop param to LiteLlmChatModel

---
 src/magentic/chat_model/litellm_chat_model.py | 26 +++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/src/magentic/chat_model/litellm_chat_model.py b/src/magentic/chat_model/litellm_chat_model.py
index 8faedae0..13dc7260 100644
--- a/src/magentic/chat_model/litellm_chat_model.py
+++ b/src/magentic/chat_model/litellm_chat_model.py
@@ -38,6 +38,7 @@ def litellm_completion(
     messages: list[ChatCompletionMessageParam],
     api_base: str | None = None,
     max_tokens: int | None = None,
+    stop: list[str] | None = None,
     temperature: float | None = None,
     functions: list[dict[str, Any]] | None = None,
     function_call: Literal["auto", "none"] | dict[str, Any] | None = None,
@@ -60,6 +61,7 @@ def litellm_completion(
     response: CustomStreamWrapper = litellm.completion(  # type: ignore[no-untyped-call,unused-ignore]
         model=model,
         messages=messages,
+        stop=stop,
         stream=True,
         **kwargs,
     )
@@ -71,6 +73,7 @@ async def litellm_acompletion(
     messages: list[ChatCompletionMessageParam],
     api_base: str | None = None,
     max_tokens: int | None = None,
+    stop: list[str] | None = None,
     temperature: float | None = None,
     functions: list[dict[str, Any]] | None = None,
     function_call: Literal["auto", "none"] | dict[str, Any] | None = None,
@@ -93,6 +96,7 @@ async def litellm_acompletion(
     response: AsyncIterator[ModelResponse] = await litellm.acompletion(  # type: ignore[no-untyped-call,unused-ignore]
         model=model,
         messages=messages,
+        stop=stop,
         stream=True,
         **kwargs,
     )
@@ -141,6 +145,8 @@ def complete(
         messages: Iterable[Message[Any]],
         functions: None = ...,
         output_types: None = ...,
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[str]:
         ...
@@ -150,6 +156,8 @@ def complete(
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]],
         output_types: None = ...,
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[str]:
         ...
@@ -159,6 +167,8 @@ def complete(
        messages: Iterable[Message[Any]],
         functions: None = ...,
         output_types: Iterable[type[R]] = ...,
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[R]:
         ...
@@ -168,6 +178,8 @@ def complete(
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]],
         output_types: Iterable[type[R]],
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[R]:
         ...
@@ -176,6 +188,8 @@ def complete(
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]] | None = None,
         output_types: Iterable[type[R | str]] | None = None,
+        *,
+        stop: list[str] | None = None,
     ) -> (
         AssistantMessage[FunctionCall[FuncR]]
         | AssistantMessage[R]
@@ -203,6 +217,7 @@ def complete(
             messages=[message_to_openai_message(m) for m in messages],
             api_base=self.api_base,
             max_tokens=self.max_tokens,
+            stop=stop,
             temperature=self.temperature,
             functions=openai_functions,
             function_call=(
@@ -262,6 +277,8 @@ async def acomplete(
         messages: Iterable[Message[Any]],
         functions: None = ...,
         output_types: None = ...,
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[str]:
         ...
@@ -271,6 +288,8 @@ async def acomplete(
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]],
         output_types: None = ...,
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[str]:
         ...
@@ -280,6 +299,8 @@ async def acomplete(
         messages: Iterable[Message[Any]],
         functions: None = ...,
         output_types: Iterable[type[R]] = ...,
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[R]:
         ...
@@ -289,6 +310,8 @@ async def acomplete(
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]],
         output_types: Iterable[type[R]],
+        *,
+        stop: list[str] | None = ...,
     ) -> AssistantMessage[FunctionCall[FuncR]] | AssistantMessage[R]:
         ...
@@ -297,6 +320,8 @@ async def acomplete(
         messages: Iterable[Message[Any]],
         functions: Iterable[Callable[..., FuncR]] | None = None,
         output_types: Iterable[type[R | str]] | None = None,
+        *,
+        stop: list[str] | None = None,
     ) -> (
         AssistantMessage[FunctionCall[FuncR]]
         | AssistantMessage[R]
@@ -324,6 +349,7 @@ async def acomplete(
             messages=[message_to_openai_message(m) for m in messages],
             api_base=self.api_base,
             max_tokens=self.max_tokens,
+            stop=stop,
             temperature=self.temperature,
             functions=openai_functions,
             function_call=(

From 98b63c414ac96b1cb21f326d8fd9664420fcaf71 Mon Sep 17 00:00:00 2001
From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com>
Date: Sun, 7 Jan 2024 22:28:41 -0800
Subject: [PATCH 7/7] Add tests for prompt functions calling ChatModel methods

---
 tests/test_chatprompt.py      |  4 ++++
 tests/test_prompt_function.py | 45 +++++++++++++++++++++++++++++++++++
 2 files changed, 49 insertions(+)

diff --git a/tests/test_chatprompt.py b/tests/test_chatprompt.py
index 54fc0b85..fd651831 100644
--- a/tests/test_chatprompt.py
+++ b/tests/test_chatprompt.py
@@ -76,6 +76,7 @@ def test_chatpromptfunction_call():
     @chatprompt(
         UserMessage("Hello {name}."),
+        stop=["stop"],
         model=mock_model,
     )
     def say_hello(name: str) -> str | bool:
         ...
@@ -87,6 +88,7 @@
         UserMessage("Hello World.")
     ]
     assert mock_model.complete.call_args.kwargs["output_types"] == [str, bool]
+    assert mock_model.complete.call_args.kwargs["stop"] == ["stop"]
 
 
 def test_chatprompt_decorator_docstring():
@@ -106,6 +108,7 @@ async def test_asyncchatpromptfunction_call():
     @chatprompt(
         UserMessage("Hello {name}."),
+        stop=["stop"],
         model=mock_model,
     )
     async def say_hello(name: str) -> str | bool:
         ...
@@ -117,6 +120,7 @@ async def say_hello(name: str) -> str | bool:
         UserMessage("Hello World.")
     ]
     assert mock_model.acomplete.call_args.kwargs["output_types"] == [str, bool]
+    assert mock_model.acomplete.call_args.kwargs["stop"] == ["stop"]
 
 
 @pytest.mark.asyncio
diff --git a/tests/test_prompt_function.py b/tests/test_prompt_function.py
index fc075b44..0402b2cc 100644
--- a/tests/test_prompt_function.py
+++ b/tests/test_prompt_function.py
@@ -2,11 +2,13 @@
 from inspect import getdoc
 from typing import Awaitable
+from unittest.mock import AsyncMock, Mock
 
 import pytest
 from pydantic import BaseModel
 
 from magentic.chat_model.base import StructuredOutputError
+from magentic.chat_model.message import AssistantMessage, UserMessage
 from magentic.chat_model.openai_chat_model import OpenaiChatModel
 from magentic.function_call import FunctionCall
 from magentic.prompt_function import AsyncPromptFunction, PromptFunction, prompt
@@ -22,6 +24,27 @@ def func(param: str) -> str:
     assert func.format("arg") == "Test arg."
 
 
+def test_promptfunction_call():
+    mock_model = Mock()
+    mock_model.complete.return_value = AssistantMessage(content="Hello!")
+
+    @prompt(
+        "Hello {name}.",
+        stop=["stop"],
+        model=mock_model,
+    )
+    def say_hello(name: str) -> str | bool:
+        ...
+
+    assert say_hello("World") == "Hello!"
+    assert mock_model.complete.call_count == 1
+    assert mock_model.complete.call_args.kwargs["messages"] == [
+        UserMessage("Hello World.")
+    ]
+    assert mock_model.complete.call_args.kwargs["output_types"] == [str, bool]
+    assert mock_model.complete.call_args.kwargs["stop"] == ["stop"]
+
+
 @pytest.mark.openai
 def test_decorator_return_str():
     @prompt("What is the capital of {country}? Name only. No punctuation.")
@@ -132,6 +155,28 @@ def days_between(start_date: str, end_date: str) -> int:
     days_between("Jan 4th 2019", "Jul 3rd 2019")
 
 
+@pytest.mark.asyncio
+async def test_async_promptfunction_call():
+    mock_model = AsyncMock()
+    mock_model.acomplete.return_value = AssistantMessage(content="Hello!")
+
+    @prompt(
+        "Hello {name}.",
+        stop=["stop"],
+        model=mock_model,
+    )
+    async def say_hello(name: str) -> str | bool:
+        ...
+
+    assert await say_hello("World") == "Hello!"
+    assert mock_model.acomplete.call_count == 1
+    assert mock_model.acomplete.call_args.kwargs["messages"] == [
+        UserMessage("Hello World.")
+    ]
+    assert mock_model.acomplete.call_args.kwargs["output_types"] == [str, bool]
+    assert mock_model.acomplete.call_args.kwargs["stop"] == ["stop"]
+
+
 @pytest.mark.asyncio
 @pytest.mark.openai
 async def test_async_decorator_return_str():
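
After this series, `stop` is accepted by the `@prompt` and `@chatprompt` decorators and forwarded through `ChatModel.complete` / `ChatModel.acomplete` (and the LiteLLM equivalents) to the underlying completion call. A minimal usage sketch of the new parameter; the model name, stop sequence, and function name below are illustrative, and a configured OpenAI API key is assumed:

    from magentic.chat_model.openai_chat_model import OpenaiChatModel
    from magentic.prompt_function import prompt

    @prompt(
        "Count from 1 to 10, separated by commas.",
        stop=["5"],  # generation halts where "5" would appear; the stop sequence itself is not returned
        model=OpenaiChatModel("gpt-3.5-turbo"),  # illustrative model name
    )
    def count_to_ten() -> str:
        ...

    print(count_to_ten())  # e.g. "1, 2, 3, 4, " -- truncated at the stop sequence

Because `stop` is plumbed through `ChatModel.complete` rather than handled in the decorator, the same keyword works unchanged for async prompt functions and for `@chatprompt`, as the tests in PATCH 7/7 exercise with mocked models.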