diff --git a/instructor/dsl/parallel.py b/instructor/dsl/parallel.py index a42dfa418..edc2c6162 100644 --- a/instructor/dsl/parallel.py +++ b/instructor/dsl/parallel.py @@ -16,6 +16,11 @@ T = TypeVar("T", bound=OpenAISchema) +class MissingToolCall(Exception): + def __init__(self, response: Any, *args: Any) -> None: + super().__init__(*args) + self.response = response + class ParallelBase: def __init__(self, *models: type[OpenAISchema]): @@ -37,6 +42,9 @@ def from_response( #! We expect this from the OpenAISchema class, We should address #! this with a protocol or an abstract class... @jxnlco assert mode == Mode.PARALLEL_TOOLS, "Mode must be PARALLEL_TOOLS" + if not response.choices[0].message.tool_calls: + raise MissingToolCall(response) + for tool_call in response.choices[0].message.tool_calls: name = tool_call.function.name arguments = tool_call.function.arguments diff --git a/instructor/dsl/partial.py b/instructor/dsl/partial.py index 95dcb9bb0..00f935f0e 100644 --- a/instructor/dsl/partial.py +++ b/instructor/dsl/partial.py @@ -9,8 +9,14 @@ from __future__ import annotations from jiter import from_json -from pydantic import BaseModel, create_model -from typing import Union +from pydantic import ( + BaseModel, + ValidationError, + ValidatorFunctionWrapHandler, + create_model, + WrapValidator, +) +from typing import Union, Any, Annotated import types import sys from pydantic.fields import FieldInfo @@ -47,6 +53,20 @@ class PartialLiteralMixin: pass +class PartialValidator(WrapValidator): + def __init__(self, **kwargs: Any): + def validate_partial( + v: Any, + validator: ValidatorFunctionWrapHandler, + ) -> Optional[Any]: + try: + return validator(v) + except ValidationError: + return None + + super().__init__(func=validate_partial, **kwargs) + + def _process_generic_arg( arg: Any, make_fields_optional: bool = False, @@ -55,13 +75,17 @@ def _process_generic_arg( if arg_origin is not None: # Handle any nested generic type (Union, List, Dict, etc.) 
nested_args = get_args(arg) - modified_nested_args = tuple( + modified_nested_args = [ _process_generic_arg( t, make_fields_optional=make_fields_optional, ) for t in nested_args - ) + ] + if make_fields_optional and Partial in modified_nested_args: + modified_nested_args.append(PartialValidator()) + + modified_nested_args = tuple(modified_nested_args) # Special handling for Union types (types.UnionType isn't subscriptable) if arg_origin in UNION_ORIGINS: return Union[modified_nested_args] # type: ignore @@ -94,11 +118,10 @@ def _make_field_optional( modified_args = tuple( _process_generic_arg(arg, make_fields_optional=True) for arg in generic_args ) + tmp_annotation: Any = Optional[generic_base[modified_args]] if generic_base else None # type: ignore # Reconstruct the generic type with modified arguments - tmp_field.annotation = ( - Optional[generic_base[modified_args]] if generic_base else None - ) + tmp_field.annotation = tmp_annotation tmp_field.default = None # If the field is a BaseModel, then recursively convert it's # attributes to optionals. 
@@ -109,6 +132,14 @@ def _make_field_optional( tmp_field.annotation = Optional[field.annotation] # type:ignore tmp_field.default = None + # If a field is annotated with Partial, add the PartialValidator which will + # return None if validation fails + if Partial in field.metadata: + tmp_field.annotation = Annotated[ + tmp_field.annotation, + PartialValidator(), + ] # type:ignore + return tmp_field.annotation, tmp_field # type: ignore diff --git a/tests/dsl/test_partial.py b/tests/dsl/test_partial.py index 2241406de..c3727c32c 100644 --- a/tests/dsl/test_partial.py +++ b/tests/dsl/test_partial.py @@ -1,6 +1,8 @@ # type: ignore[all] -from pydantic import BaseModel, Field -from typing import Optional, Union +from uuid import UUID +from datetime import datetime +from pydantic import BaseModel, Field, ValidationError, validator +from typing import Optional, Union, Literal, Annotated from instructor.dsl.partial import Partial, PartialLiteralMixin import pytest import instructor @@ -24,6 +26,7 @@ class SamplePartial(BaseModel): class NestedA(BaseModel): a: str b: Optional[str] + c: Optional[Annotated[datetime, Partial]] class NestedB(BaseModel): @@ -39,6 +42,22 @@ class UnionWithNested(BaseModel): c: NestedB +SampleEnum = Literal["a_value", "b_value", "c_value"] +SampleMixedEnum = Literal["a_value", "b_value", "c_value", 1, 2, 3] + + +class PartialEnums(BaseModel): + a: Annotated[Literal["a_value"], Partial] + b: Annotated[SampleEnum, Partial] + c: Annotated[SampleMixedEnum, Partial] + d: Annotated[Literal["a_value", 10], Partial] + e: Annotated[Literal["a_value"], Partial] + f: Literal["a_value"] + g: Annotated[UUID, Partial] + h: Optional[Annotated[datetime, Partial]] + i: Optional[NestedA] + + def test_partial(): partial = Partial[SamplePartial] assert partial.model_json_schema() == { @@ -192,3 +211,46 @@ def test_union_with_nested(): partial.get_partial_model().model_validate_json( '{"a": [{"b": "b"}, {"d": "d"}], "b": [{"b": "b"}], "c": {"d": "d"}, "e": [1, "a"]}' ) 
+ + +def test_partial_enums(): + # Test that we can annotate enum values with `Partial` and support parsing + # partial values with the partial model + partial = Partial[PartialEnums] + partial_results = ( + '{"a": "a_", "b": "b_", "c": "c_v", "d": 1, "e": "a_", "f": "a_value", "g": "1", "h": "", "i": {"c": ""}}' + ) + partial_validated = partial.get_partial_model().model_validate_json(partial_results) + + assert partial_validated.a is None + assert partial_validated.b is None + assert partial_validated.c is None + assert partial_validated.d is None + assert partial_validated.e is None + assert partial_validated.f == "a_value" + assert partial_validated.g is None + assert partial_validated.h is None + assert partial_validated.i is not None + assert partial_validated.i.c is None + + + with pytest.raises(ValidationError): + partial.model_validate_json(partial_results) + + with pytest.raises(ValidationError): + # "f" is not marked as a partial enum + partial.get_partial_model().model_validate_json('{"f": "a_"}') + + resolved_enum_partial_results = ( + '{"a": "a_value", "b": "b_value", "c": "c_v", "d": 10, "g": "123e4567-e89b-12d3-a456-426655440000", "h": "2024-01-01T00:00:00"}' + ) + resolved_enum_partial_validated = partial.get_partial_model().model_validate_json( + resolved_enum_partial_results + ) + assert resolved_enum_partial_validated.a == "a_value" + assert resolved_enum_partial_validated.b == "b_value" + # this value still isn't fully resolved + assert resolved_enum_partial_validated.c is None + assert resolved_enum_partial_validated.d == 10 + assert resolved_enum_partial_validated.g == UUID("123e4567-e89b-12d3-a456-426655440000") + assert resolved_enum_partial_validated.h == datetime(2024, 1, 1)