Skip to content

Commit

Permalink
refactor protocol model extraction to only check extra parameters (#436)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Jul 2, 2024
1 parent a0bf68c commit 9776ec5
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 26 deletions.
53 changes: 39 additions & 14 deletions ragna/core/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@
import enum
import functools
import inspect
from typing import AsyncIterable, AsyncIterator, Iterator, Optional, Type, Union
from typing import (
AsyncIterable,
AsyncIterator,
Iterator,
Optional,
Type,
Union,
get_type_hints,
)

import pydantic
import pydantic.utils
Expand Down Expand Up @@ -42,30 +50,47 @@ def __repr__(self) -> str:
def _protocol_models(
cls,
) -> dict[tuple[Type[Component], str], Type[pydantic.BaseModel]]:
# This method dynamically builds a pydantic.BaseModel for the extra parameters
# of each method that is listed in the __ragna_protocol_methods__ class
# variable. These models are used by ragna.core.Chat._unpack_chat_params to
# validate and distribute the **params passed by the user.

# Walk up the MRO until we find the __ragna_protocol_methods__ variable. This is
# the indicator that we found the protocol class. We use this as a reference for
# which params of a protocol method are part of the protocol (think positional
# parameters) and which are requested by the concrete class (think keyword
# parameters).
protocol_cls, protocol_methods = next(
(cls_, cls_.__ragna_protocol_methods__) # type: ignore[attr-defined]
for cls_ in cls.__mro__
if "__ragna_protocol_methods__" in cls_.__dict__
)
models = {}
for method_name in protocol_methods:
num_protocol_params = len(
inspect.signature(getattr(protocol_cls, method_name)).parameters
)
method = getattr(cls, method_name)
concrete_params = inspect.signature(method).parameters
protocol_params = inspect.signature(
getattr(protocol_cls, method_name)
).parameters
extra_param_names = concrete_params.keys() - protocol_params.keys()

models[(cls, method_name)] = pydantic.create_model( # type: ignore[call-overload]
params = iter(inspect.signature(method).parameters.values())
annotations = get_type_hints(method)
# Skip over the protocol parameters in order for the model below to only
# comprise concrete parameters.
for _ in range(num_protocol_params):
next(params)

models[(cls, method_name)] = pydantic.create_model(
# type: ignore[call-overload]
f"{cls.__name__}.{method_name}",
**{
(param := concrete_params[param_name]).name: (
param.annotation,
param.default
if param.default is not inspect.Parameter.empty
else ...,
param.name: (
annotations[param.name],
(
param.default
if param.default is not inspect.Parameter.empty
else ...
),
)
for param_name in extra_param_names
for param in params
},
)
return models
Expand Down
105 changes: 101 additions & 4 deletions ragna/core/_rag.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import contextlib
import datetime
import inspect
import itertools
import uuid
from collections import defaultdict
from typing import (
Any,
AsyncIterator,
Expand All @@ -19,6 +22,7 @@
)

import pydantic
import pydantic_core
from starlette.concurrency import iterate_in_threadpool, run_in_threadpool

from ._components import Assistant, Component, Message, MessageRole, SourceStorage
Expand Down Expand Up @@ -251,27 +255,120 @@ def _parse_documents(self, documents: Iterable[Any]) -> list[Document]:
def _unpack_chat_params(
self, params: dict[str, Any]
) -> dict[Callable, dict[str, Any]]:
# This method does two things:
# 1. Validate the **params against the signatures of the protocol methods of the
# used components. This makes sure that
# - No parameter is passed that isn't used by at least one component
# - No parameter is missing that is needed by at least one component
# - No parameter is passed in the wrong type
# 2. Prepare the distribution of the parameters to the protocol method that
# requested them. The actual distribution happens in self._run and
# self._run_gen, but is only a dictionary lookup by then.
component_models = {
getattr(component, name): model
for component in [self.source_storage, self.assistant]
for (_, name), model in component._protocol_models().items()
}

ChatModel = merge_models(
str(self.params["chat_id"]),
f"{self.__module__}.{type(self).__name__}-{self.params['chat_id']}",
SpecialChatParams,
*component_models.values(),
config=pydantic.ConfigDict(extra="forbid"),
)

chat_params = ChatModel.model_validate(params, strict=True).model_dump(
exclude_none=True
)
with self._format_validation_error(ChatModel):
chat_model = ChatModel.model_validate(params, strict=True)

chat_params = chat_model.model_dump(exclude_none=True)
return {
fn: model(**chat_params).model_dump()
for fn, model in component_models.items()
}

@contextlib.contextmanager
def _format_validation_error(
self, model_cls: type[pydantic.BaseModel]
) -> Iterator[None]:
try:
yield
except pydantic.ValidationError as validation_error:
errors = defaultdict(list)
for error in validation_error.errors():
errors[error["type"]].append(error)

parts = [
f"Validating the Chat parameters resulted in {validation_error.error_count()} errors:",
"",
]

def format_error(
error: pydantic_core.ErrorDetails,
*,
annotation: bool = False,
value: bool = False,
) -> str:
param = cast(str, error["loc"][0])

formatted_error = f"- {param}"
if annotation:
annotation_ = cast(
type, model_cls.model_fields[param].annotation
).__name__
formatted_error += f": {annotation_}"

if value:
value_ = error["input"]
formatted_error += (
f" = {value_!r}" if annotation else f"={value_!r}"
)

return formatted_error

if "extra_forbidden" in errors:
parts.extend(
[
"The following parameters are unknown:",
"",
*[
format_error(error, value=True)
for error in errors["extra_forbidden"]
],
"",
]
)

if "missing" in errors:
parts.extend(
[
"The following parameters are missing:",
"",
*[
format_error(error, annotation=True)
for error in errors["missing"]
],
"",
]
)

type_errors = ["string_type", "int_type", "float_type", "bool_type"]
if any(type_error in errors for type_error in type_errors):
parts.extend(
[
"The following parameters have the wrong type:",
"",
*[
format_error(error, annotation=True, value=True)
for error in itertools.chain.from_iterable(
errors[type_error] for type_error in type_errors
)
],
"",
]
)

raise RagnaException("\n".join(parts))

async def _run(
self, fn: Union[Callable[..., T], Callable[..., Awaitable[T]]], *args: Any
) -> T:
Expand Down
79 changes: 71 additions & 8 deletions tests/core/test_rag.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import pydantic
import pytest

from ragna import Rag, assistants, source_storages
from ragna.core import LocalDocument
from ragna.core import Assistant, LocalDocument, RagnaException


@pytest.fixture()
Expand All @@ -14,20 +13,84 @@ def demo_document(tmp_path, request):


class TestChat:
def chat(self, documents, **params):
def chat(
self,
documents,
source_storage=source_storages.RagnaDemoSourceStorage,
assistant=assistants.RagnaDemoAssistant,
**params,
):
return Rag().chat(
documents=documents,
source_storage=source_storages.RagnaDemoSourceStorage,
assistant=assistants.RagnaDemoAssistant,
source_storage=source_storage,
assistant=assistant,
**params,
)

def test_extra_params(self, demo_document):
with pytest.raises(pydantic.ValidationError, match="not_supported_parameter"):
def test_params_validation_unknown(self, demo_document):
params = {
"bool_param": False,
"int_param": 1,
"float_param": 0.5,
"string_param": "arbitrary_value",
}
with pytest.raises(RagnaException, match="unknown") as exc_info:
self.chat(documents=[demo_document], **params)

msg = str(exc_info.value)
for param, value in params.items():
assert f"{param}={value!r}" in msg

def test_params_validation_missing(self, demo_document):
class ValidationAssistant(Assistant):
def answer(
self,
prompt,
sources,
bool_param: bool,
int_param: int,
float_param: float,
string_param: str,
):
pass

with pytest.raises(RagnaException, match="missing") as exc_info:
self.chat(documents=[demo_document], assistant=ValidationAssistant)

msg = str(exc_info.value)
for param, annotation in ValidationAssistant.answer.__annotations__.items():
assert f"{param}: {annotation.__name__}" in msg

def test_params_validation_wrong_type(self, demo_document):
class ValidationAssistant(Assistant):
def answer(
self,
prompt,
sources,
bool_param: bool,
int_param: int,
float_param: float,
string_param: str,
):
pass

params = {
"bool_param": 1,
"int_param": 0.5,
"float_param": "arbitrary_value",
"string_param": False,
}

with pytest.raises(RagnaException, match="wrong type") as exc_info:
self.chat(
documents=[demo_document], not_supported_parameter="arbitrary_value"
documents=[demo_document], assistant=ValidationAssistant, **params
)

msg = str(exc_info.value)
for param, value in params.items():
annotation = ValidationAssistant.answer.__annotations__[param]
assert f"{param}: {annotation.__name__} = {value!r}" in msg

def test_document_path(self, demo_document):
chat = self.chat(documents=[demo_document.path])

Expand Down

0 comments on commit 9776ec5

Please sign in to comment.