diff --git a/ragna/core/_components.py b/ragna/core/_components.py index d237c1b8..d98932a7 100644 --- a/ragna/core/_components.py +++ b/ragna/core/_components.py @@ -4,7 +4,15 @@ import enum import functools import inspect -from typing import AsyncIterable, AsyncIterator, Iterator, Optional, Type, Union +from typing import ( + AsyncIterable, + AsyncIterator, + Iterator, + Optional, + Type, + Union, + get_type_hints, +) import pydantic import pydantic.utils @@ -42,6 +50,16 @@ def __repr__(self) -> str: def _protocol_models( cls, ) -> dict[tuple[Type[Component], str], Type[pydantic.BaseModel]]: + # This method dynamically builds a pydantic.BaseModel for the extra parameters + # of each method that is listed in the __ragna_protocol_methods__ class + # variable. These models are used by ragna.core.Chat._unpack_chat_params to + # validate and distribute the **params passed by the user. + + # Walk up the MRO until we find the __ragna_protocol_methods__ variable. This is + # the indicator that we found the protocol class. We use this as a reference for + # which params of a protocol method are part of the protocol (think positional + # parameters) and which are requested by the concrete class (think keyword + # parameters). protocol_cls, protocol_methods = next( (cls_, cls_.__ragna_protocol_methods__) # type: ignore[attr-defined] for cls_ in cls.__mro__ @@ -49,23 +67,30 @@ def _protocol_models( ) models = {} for method_name in protocol_methods: + num_protocol_params = len( + inspect.signature(getattr(protocol_cls, method_name)).parameters + ) method = getattr(cls, method_name) - concrete_params = inspect.signature(method).parameters - protocol_params = inspect.signature( - getattr(protocol_cls, method_name) - ).parameters - extra_param_names = concrete_params.keys() - protocol_params.keys() - - models[(cls, method_name)] = pydantic.create_model( # type: ignore[call-overload] + params = iter(inspect.signature(method).parameters.values()) + annotations = get_type_hints(method) + # Skip over the protocol parameters in order for the model below to only + # comprise concrete parameters. + for _ in range(num_protocol_params): + next(params) + + models[(cls, method_name)] = pydantic.create_model( + # type: ignore[call-overload] f"{cls.__name__}.{method_name}", **{ - (param := concrete_params[param_name]).name: ( - param.annotation, - param.default - if param.default is not inspect.Parameter.empty - else ..., + param.name: ( + annotations[param.name], + ( + param.default + if param.default is not inspect.Parameter.empty + else ... + ), ) - for param_name in extra_param_names + for param in params }, ) return models diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index 6cdff127..1490b673 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -1,8 +1,11 @@ from __future__ import annotations +import contextlib import datetime import inspect +import itertools import uuid +from collections import defaultdict from typing import ( Any, AsyncIterator, @@ -19,6 +22,7 @@ ) import pydantic +import pydantic_core from starlette.concurrency import iterate_in_threadpool, run_in_threadpool from ._components import Assistant, Component, Message, MessageRole, SourceStorage @@ -251,6 +255,15 @@ def _parse_documents(self, documents: Iterable[Any]) -> list[Document]: def _unpack_chat_params( self, params: dict[str, Any] ) -> dict[Callable, dict[str, Any]]: + # This method does two things: + # 1. Validate the **params against the signatures of the protocol methods of the + # used components. This makes sure that + # - No parameter is passed that isn't used by at least one component + # - No parameter is missing that is needed by at least one component + # - No parameter is passed in the wrong type + # 2. Prepare the distribution of the parameters to the protocol method that + # requested them. The actual distribution happens in self._run and + # self._run_gen, but is only a dictionary lookup by then. component_models = { getattr(component, name): model for component in [self.source_storage, self.assistant] @@ -258,20 +271,104 @@ def _unpack_chat_params( } ChatModel = merge_models( - str(self.params["chat_id"]), + f"{self.__module__}.{type(self).__name__}-{self.params['chat_id']}", SpecialChatParams, *component_models.values(), config=pydantic.ConfigDict(extra="forbid"), ) - chat_params = ChatModel.model_validate(params, strict=True).model_dump( - exclude_none=True - ) + with self._format_validation_error(ChatModel): + chat_model = ChatModel.model_validate(params, strict=True) + + chat_params = chat_model.model_dump(exclude_none=True) return { fn: model(**chat_params).model_dump() for fn, model in component_models.items() } + @contextlib.contextmanager + def _format_validation_error( + self, model_cls: type[pydantic.BaseModel] + ) -> Iterator[None]: + try: + yield + except pydantic.ValidationError as validation_error: + errors = defaultdict(list) + for error in validation_error.errors(): + errors[error["type"]].append(error) + + parts = [ + f"Validating the Chat parameters resulted in {validation_error.error_count()} errors:", + "", + ] + + def format_error( + error: pydantic_core.ErrorDetails, + *, + annotation: bool = False, + value: bool = False, + ) -> str: + param = cast(str, error["loc"][0]) + + formatted_error = f"- {param}" + if annotation: + annotation_ = cast( + type, model_cls.model_fields[param].annotation + ).__name__ + formatted_error += f": {annotation_}" + + if value: + value_ = error["input"] + formatted_error += ( + f" = {value_!r}" if annotation else f"={value_!r}" + ) + + return formatted_error + + if "extra_forbidden" in errors: + parts.extend( + [ + "The following parameters are unknown:", + "", + *[ + format_error(error, value=True) + for error in errors["extra_forbidden"] + ], + "", + ] + ) + + if "missing" in errors: + parts.extend( + [ + "The following parameters are missing:", + "", + *[ + format_error(error, annotation=True) + for error in errors["missing"] + ], + "", + ] + ) + + type_errors = ["string_type", "int_type", "float_type", "bool_type"] + if any(type_error in errors for type_error in type_errors): + parts.extend( + [ + "The following parameters have the wrong type:", + "", + *[ + format_error(error, annotation=True, value=True) + for error in itertools.chain.from_iterable( + errors[type_error] for type_error in type_errors + ) + ], + "", + ] + ) + + raise RagnaException("\n".join(parts)) + async def _run( self, fn: Union[Callable[..., T], Callable[..., Awaitable[T]]], *args: Any ) -> T: diff --git a/tests/core/test_rag.py b/tests/core/test_rag.py index 307f235e..558dbcfe 100644 --- a/tests/core/test_rag.py +++ b/tests/core/test_rag.py @@ -1,8 +1,7 @@ -import pydantic import pytest from ragna import Rag, assistants, source_storages -from ragna.core import LocalDocument +from ragna.core import Assistant, LocalDocument, RagnaException @pytest.fixture() @@ -14,20 +13,84 @@ def demo_document(tmp_path, request): class TestChat: - def chat(self, documents, **params): + def chat( + self, + documents, + source_storage=source_storages.RagnaDemoSourceStorage, + assistant=assistants.RagnaDemoAssistant, + **params, + ): return Rag().chat( documents=documents, - source_storage=source_storages.RagnaDemoSourceStorage, - assistant=assistants.RagnaDemoAssistant, + source_storage=source_storage, + assistant=assistant, **params, ) - def test_extra_params(self, demo_document): - with pytest.raises(pydantic.ValidationError, match="not_supported_parameter"): + def test_params_validation_unknown(self, demo_document): + params = { + "bool_param": False, + "int_param": 1, + "float_param": 0.5, + "string_param": "arbitrary_value", + } + with pytest.raises(RagnaException, match="unknown") as exc_info: + self.chat(documents=[demo_document], **params) + + msg = str(exc_info.value) + for param, value in params.items(): + assert f"{param}={value!r}" in msg + + def test_params_validation_missing(self, demo_document): + class ValidationAssistant(Assistant): + def answer( + self, + prompt, + sources, + bool_param: bool, + int_param: int, + float_param: float, + string_param: str, + ): + pass + + with pytest.raises(RagnaException, match="missing") as exc_info: + self.chat(documents=[demo_document], assistant=ValidationAssistant) + + msg = str(exc_info.value) + for param, annotation in ValidationAssistant.answer.__annotations__.items(): + assert f"{param}: {annotation.__name__}" in msg + + def test_params_validation_wrong_type(self, demo_document): + class ValidationAssistant(Assistant): + def answer( + self, + prompt, + sources, + bool_param: bool, + int_param: int, + float_param: float, + string_param: str, + ): + pass + + params = { + "bool_param": 1, + "int_param": 0.5, + "float_param": "arbitrary_value", + "string_param": False, + } + + with pytest.raises(RagnaException, match="wrong type") as exc_info: self.chat( - documents=[demo_document], not_supported_parameter="arbitrary_value" + documents=[demo_document], assistant=ValidationAssistant, **params ) + msg = str(exc_info.value) + for param, value in params.items(): + annotation = ValidationAssistant.answer.__annotations__[param] + assert f"{param}: {annotation.__name__} = {value!r}" in msg + def test_document_path(self, demo_document): chat = self.chat(documents=[demo_document.path])