Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not add system prompt part when dynamic system prompt function returns empty value #864

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion docs/agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ Running `mypy` on this will give the following output:

```bash
➤ uv run mypy type_mistakes.py
type_mistakes.py:18: error: Argument 1 to "system_prompt" of "Agent" has incompatible type "Callable[[RunContext[str]], str]"; expected "Callable[[RunContext[User]], str]" [arg-type]
type_mistakes.py:18: error: Argument 1 to "system_prompt" of "Agent" has incompatible type "Callable[[RunContext[str]], str | None]"; expected "Callable[[RunContext[User]], str | None]" [arg-type]
type_mistakes.py:28: error: Argument 1 to "foobar" has incompatible type "bool"; expected "bytes" [arg-type]
Found 2 errors in 1 file (checked 1 source file)
```
Expand All @@ -344,6 +344,7 @@ Generally, system prompts fall into two categories:
2. **Dynamic system prompts**: These depend in some way on context that isn't known until runtime, and should be defined via functions decorated with [`@agent.system_prompt`][pydantic_ai.Agent.system_prompt].

You can add both to a single agent; they're appended in the order they're defined at runtime.
If a dynamic system prompt function returns `None` or any other empty value (such as an empty string), its prompt part won't be added to the messages.

Here's an example using both types of system prompts:

Expand Down
9 changes: 6 additions & 3 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,18 @@ async def _reevaluate_dynamic_prompts(
# Look up the runner by its ref
if runner := self.system_prompt_dynamic_functions.get(part.dynamic_ref):
updated_part_content = await runner.run(run_context)
msg.parts[i] = _messages.SystemPromptPart(
updated_part_content, dynamic_ref=part.dynamic_ref
)
if updated_part_content:
msg.parts[i] = _messages.SystemPromptPart(
updated_part_content, dynamic_ref=part.dynamic_ref
)
Comment on lines +175 to +178
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems wrong to me — with this change, if the function returns None (or '') we just don't update the part, rather than removing it. I think there are a few ways to do this more correctly:

  • option 1: remove the part from the msg.parts list (which would need to be done carefully..)
  • option 2: allow None in msg.parts, I don't love this
  • option 3: set the msg.parts[i] to be a SystemPromptPart with empty string as the value, and modify the handling in the Model implementations to ignore SystemPromptParts with empty content.


async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.ModelRequestPart]:
"""Build the initial messages for the conversation."""
messages: list[_messages.ModelRequestPart] = [_messages.SystemPromptPart(p) for p in self.system_prompts]
for sys_prompt_runner in self.system_prompt_functions:
prompt = await sys_prompt_runner.run(run_context)
if not prompt:
continue
if sys_prompt_runner.dynamic:
messages.append(_messages.SystemPromptPart(prompt, dynamic_ref=sys_prompt_runner.function.__qualname__))
else:
Expand Down
6 changes: 3 additions & 3 deletions pydantic_ai_slim/pydantic_ai/_system_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ def __post_init__(self):
self._takes_ctx = len(inspect.signature(self.function).parameters) > 0
self._is_async = inspect.iscoroutinefunction(self.function)

async def run(self, run_context: RunContext[AgentDepsT]) -> str:
async def run(self, run_context: RunContext[AgentDepsT]) -> str | None:
if self._takes_ctx:
args = (run_context,)
else:
args = ()

if self._is_async:
function = cast(Callable[[Any], Awaitable[str]], self.function)
function = cast(Callable[[Any], Awaitable['str | None']], self.function)
return await function(*args)
else:
function = cast(Callable[[Any], str], self.function)
function = cast(Callable[[Any], 'str | None'], self.function)
return await _utils.run_in_executor(function, *args)
20 changes: 14 additions & 6 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,19 +645,27 @@ def override(

@overload
def system_prompt(
self, func: Callable[[RunContext[AgentDepsT]], str], /
) -> Callable[[RunContext[AgentDepsT]], str]: ...
self,
func: Callable[[RunContext[AgentDepsT]], str | None],
/,
) -> Callable[[RunContext[AgentDepsT]], str | None]: ...

@overload
def system_prompt(
self, func: Callable[[RunContext[AgentDepsT]], Awaitable[str]], /
) -> Callable[[RunContext[AgentDepsT]], Awaitable[str]]: ...
self,
func: Callable[[RunContext[AgentDepsT]], Awaitable[str | None]],
/,
) -> Callable[[RunContext[AgentDepsT]], Awaitable[str | None]]: ...

@overload
def system_prompt(self, func: Callable[[], str], /) -> Callable[[], str]: ...
def system_prompt(self, func: Callable[[], str | None], /) -> Callable[[], str | None]: ...

@overload
def system_prompt(self, func: Callable[[], Awaitable[str]], /) -> Callable[[], Awaitable[str]]: ...
def system_prompt(
self,
func: Callable[[], Awaitable[str | None]],
/,
) -> Callable[[], Awaitable[str | None]]: ...

@overload
def system_prompt(
Expand Down
8 changes: 4 additions & 4 deletions pydantic_ai_slim/pydantic_ai/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@ def replace_with(
"""Retrieval function param spec."""

SystemPromptFunc = Union[
Callable[[RunContext[AgentDepsT]], str],
Callable[[RunContext[AgentDepsT]], Awaitable[str]],
Callable[[], str],
Callable[[], Awaitable[str]],
Callable[[RunContext[AgentDepsT]], Union[str, None]],
Callable[[RunContext[AgentDepsT]], Awaitable[Union[str, None]]],
Callable[[], Union[str, None]],
Callable[[], Awaitable[Union[str, None]]],
]
"""A function that may or may not take `RunContext` as an argument, and may or may not be async.

Expand Down
13 changes: 12 additions & 1 deletion tests/models/test_model_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import AsyncIterator
from dataclasses import asdict
from datetime import timezone
from typing import Union

import pydantic_core
import pytest
Expand Down Expand Up @@ -312,10 +313,20 @@ def quz(x) -> str: # pyright: ignore[reportUnknownParameterType,reportMissingPa


@agent_all.system_prompt
def spam() -> str:
def spam() -> Union[str, None]:
return 'foobar'


@agent_all.system_prompt
def empty1() -> Union[str, None]:
return None


@agent_all.system_prompt
def empty2() -> Union[str, None]:
return ''


def test_register_all():
async def f(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
return ModelResponse(
Expand Down
10 changes: 9 additions & 1 deletion tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1330,9 +1330,13 @@ def test_dynamic_false_no_reevaluate():
dynamic_value = 'A'

@agent.system_prompt
async def func() -> str:
async def func() -> Union[str, None]:
return dynamic_value

@agent.system_prompt
async def empty_func() -> Union[str, None]:
return None

res = agent.run_sync('Hello')

assert res.all_messages() == snapshot(
Expand Down Expand Up @@ -1405,6 +1409,10 @@ def test_dynamic_true_reevaluate_system_prompt():
async def func():
return dynamic_value

@agent.system_prompt(dynamic=True)
async def empty_func():
return None

res = agent.run_sync('Hello')

assert res.all_messages() == snapshot(
Expand Down
12 changes: 9 additions & 3 deletions tests/typed_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class MyDeps:


@typed_agent.system_prompt
async def system_prompt_ok1(ctx: RunContext[MyDeps]) -> str:
async def system_prompt_ok1(ctx: RunContext[MyDeps]) -> Union[str, None]:
return f'{ctx.deps}'


Expand All @@ -32,9 +32,15 @@ def system_prompt_ok2() -> str:
return 'foobar'


@typed_agent.system_prompt
def system_prompt_ok3() -> Union[str, None]:
return None


# we have overloads for every possible signature of system_prompt, so the type of decorated functions is correct
assert_type(system_prompt_ok1, Callable[[RunContext[MyDeps]], Awaitable[str]])
assert_type(system_prompt_ok2, Callable[[], str])
assert_type(system_prompt_ok1, Callable[[RunContext[MyDeps]], Awaitable[Union[str, None]]])
assert_type(system_prompt_ok2, Callable[[], Union[str, None]])
assert_type(system_prompt_ok3, Callable[[], Union[str, None]])


@contextmanager
Expand Down