Skip to content

Commit

Permalink
Add tests for prompt functions calling ChatModel methods
Browse files Browse the repository at this point in the history
  • Loading branch information
jackmpcollins committed Jan 8, 2024
1 parent 979aa47 commit 98b63c4
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
4 changes: 4 additions & 0 deletions tests/test_chatprompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def test_chatpromptfunction_call():

@chatprompt(
UserMessage("Hello {name}."),
stop=["stop"],
model=mock_model,
)
def say_hello(name: str) -> str | bool:
Expand All @@ -87,6 +88,7 @@ def say_hello(name: str) -> str | bool:
UserMessage("Hello World.")
]
assert mock_model.complete.call_args.kwargs["output_types"] == [str, bool]
assert mock_model.complete.call_args.kwargs["stop"] == ["stop"]


def test_chatprompt_decorator_docstring():
Expand All @@ -106,6 +108,7 @@ async def test_asyncchatpromptfunction_call():

@chatprompt(
UserMessage("Hello {name}."),
stop=["stop"],
model=mock_model,
)
async def say_hello(name: str) -> str | bool:
Expand All @@ -117,6 +120,7 @@ async def say_hello(name: str) -> str | bool:
UserMessage("Hello World.")
]
assert mock_model.acomplete.call_args.kwargs["output_types"] == [str, bool]
assert mock_model.acomplete.call_args.kwargs["stop"] == ["stop"]


@pytest.mark.asyncio
Expand Down
45 changes: 45 additions & 0 deletions tests/test_prompt_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from inspect import getdoc
from typing import Awaitable
from unittest.mock import AsyncMock, Mock

import pytest
from pydantic import BaseModel

from magentic.chat_model.base import StructuredOutputError
from magentic.chat_model.message import AssistantMessage, UserMessage
from magentic.chat_model.openai_chat_model import OpenaiChatModel
from magentic.function_call import FunctionCall
from magentic.prompt_function import AsyncPromptFunction, PromptFunction, prompt
Expand All @@ -22,6 +24,27 @@ def func(param: str) -> str:
assert func.format("arg") == "Test arg."


def test_promptfunction_call():
mock_model = Mock()
mock_model.complete.return_value = AssistantMessage(content="Hello!")

@prompt(
"Hello {name}.",
stop=["stop"],
model=mock_model,
)
def say_hello(name: str) -> str | bool:
...

assert say_hello("World") == "Hello!"
assert mock_model.complete.call_count == 1
assert mock_model.complete.call_args.kwargs["messages"] == [
UserMessage("Hello World.")
]
assert mock_model.complete.call_args.kwargs["output_types"] == [str, bool]
assert mock_model.complete.call_args.kwargs["stop"] == ["stop"]


@pytest.mark.openai
def test_decorator_return_str():
@prompt("What is the capital of {country}? Name only. No punctuation.")
Expand Down Expand Up @@ -132,6 +155,28 @@ def days_between(start_date: str, end_date: str) -> int:
days_between("Jan 4th 2019", "Jul 3rd 2019")


@pytest.mark.asyncio
async def test_async_promptfunction_call():
mock_model = AsyncMock()
mock_model.acomplete.return_value = AssistantMessage(content="Hello!")

@prompt(
"Hello {name}.",
stop=["stop"],
model=mock_model,
)
async def say_hello(name: str) -> str | bool:
...

assert await say_hello("World") == "Hello!"
assert mock_model.acomplete.call_count == 1
assert mock_model.acomplete.call_args.kwargs["messages"] == [
UserMessage("Hello World.")
]
assert mock_model.acomplete.call_args.kwargs["output_types"] == [str, bool]
assert mock_model.acomplete.call_args.kwargs["stop"] == ["stop"]


@pytest.mark.asyncio
@pytest.mark.openai
async def test_async_decorator_return_str():
Expand Down

0 comments on commit 98b63c4

Please sign in to comment.