Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support individual partial enum values #1228

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
8 changes: 8 additions & 0 deletions instructor/dsl/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@

T = TypeVar("T", bound=OpenAISchema)

class MissingToolCall(Exception):
    """Raised when a response that should contain tool calls has none.

    The offending response is kept on the exception so a caller that catches
    it can still inspect what was returned.
    """

    def __init__(self, response: Any, *args: Any) -> None:
        """Create the exception.

        Args:
            response: The response object that contained no tool calls.
                Stored unchanged on ``self.response``.
            *args: Optional positional arguments forwarded to ``Exception``
                (typically a message). When omitted, a generic message is
                used so the exception never renders as an empty string.
        """
        if not args:
            # Without this, `raise MissingToolCall(response)` would surface
            # in tracebacks and logs with no explanation at all.
            args = ("Response did not contain any tool calls",)
        super().__init__(*args)
        self.response = response


class ParallelBase:
def __init__(self, *models: type[OpenAISchema]):
Expand All @@ -37,6 +42,9 @@ def from_response(
#! We expect this from the OpenAISchema class, We should address
#! this with a protocol or an abstract class... @jxnlco
assert mode == Mode.PARALLEL_TOOLS, "Mode must be PARALLEL_TOOLS"
if not response.choices[0].message.tool_calls:
raise MissingToolCall(response)

for tool_call in response.choices[0].message.tool_calls:
name = tool_call.function.name
arguments = tool_call.function.arguments
Expand Down
45 changes: 38 additions & 7 deletions instructor/dsl/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,14 @@
from __future__ import annotations

from jiter import from_json
from pydantic import BaseModel, create_model
from typing import Union
from pydantic import (
BaseModel,
ValidationError,
ValidatorFunctionWrapHandler,
create_model,
WrapValidator,
)
from typing import Union, Any, Annotated
import types
import sys
from pydantic.fields import FieldInfo
Expand Down Expand Up @@ -47,6 +53,20 @@
pass


class PartialValidator(WrapValidator):
    """Wrap validator that yields ``None`` instead of failing validation.

    The wrapped handler is called on the incoming value. If it raises a
    ``ValidationError`` the error is swallowed and ``None`` is returned, so a
    value that does not validate (for example an incomplete one) is treated
    as absent rather than as an error.
    """

    def __init__(self, **kwargs: Any):
        """Build the validator.

        Args:
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``WrapValidator.__init__`` alongside ``func``.
        """

        def validate_partial(
            v: Any,
            validator: ValidatorFunctionWrapHandler,
        ) -> Any | None:
            # The return type is spelled `X | None` rather than `Optional[X]`
            # to satisfy Ruff's UP007 rule; this relies on the module's
            # `from __future__ import annotations` so the annotation is
            # never evaluated at runtime.
            try:
                return validator(v)
            except ValidationError:
                # Deliberate best-effort: a value that fails validation is
                # reported as "not available" instead of raising.
                return None

        super().__init__(func=validate_partial, **kwargs)


def _process_generic_arg(
arg: Any,
make_fields_optional: bool = False,
Expand All @@ -55,13 +75,17 @@
if arg_origin is not None:
# Handle any nested generic type (Union, List, Dict, etc.)
nested_args = get_args(arg)
modified_nested_args = tuple(
modified_nested_args = [
_process_generic_arg(
t,
make_fields_optional=make_fields_optional,
)
for t in nested_args
)
]
if make_fields_optional and Partial in modified_nested_args:
modified_nested_args.append(PartialValidator())

modified_nested_args = tuple(modified_nested_args)
# Special handling for Union types (types.UnionType isn't subscriptable)
if arg_origin in UNION_ORIGINS:
return Union[modified_nested_args] # type: ignore
Expand Down Expand Up @@ -94,11 +118,10 @@
modified_args = tuple(
_process_generic_arg(arg, make_fields_optional=True) for arg in generic_args
)
tmp_annotation: Any = Optional[generic_base[modified_args]] if generic_base else None # type: ignore

Check failure on line 121 in instructor/dsl/partial.py

View workflow job for this annotation

GitHub Actions / Pyright (macos-latest, 3.11)

Unnecessary "# type: ignore" comment (reportUnnecessaryTypeIgnoreComment)

# Reconstruct the generic type with modified arguments
tmp_field.annotation = (
Optional[generic_base[modified_args]] if generic_base else None
)
tmp_field.annotation = tmp_annotation
tmp_field.default = None
# If the field is a BaseModel, then recursively convert its
# attributes to optionals.
Expand All @@ -109,6 +132,14 @@
tmp_field.annotation = Optional[field.annotation] # type:ignore
tmp_field.default = None

# If a field is annotated with Partial, add the PartialValidator which will
# return None if validation fails
if Partial in field.metadata:
tmp_field.annotation = Annotated[
tmp_field.annotation,

Check failure on line 139 in instructor/dsl/partial.py

View workflow job for this annotation

GitHub Actions / Pyright (macos-latest, 3.11)

Type of "annotation" is partially unknown   Type of "annotation" is "type[Any] | type[Partial[BaseModel]] | type[None] | Unknown | None" (reportUnknownMemberType)

Check failure on line 139 in instructor/dsl/partial.py

View workflow job for this annotation

GitHub Actions / Pyright (macos-latest, 3.11)

Expected class but received "type[Any] | type[Partial[BaseModel]] | type[None] | Unknown | None"   "None" is not a class (reportGeneralTypeIssues)
PartialValidator(),
] # type:ignore

return tmp_field.annotation, tmp_field # type: ignore


Expand Down
66 changes: 64 additions & 2 deletions tests/dsl/test_partial.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# type: ignore[all]
from pydantic import BaseModel, Field
from typing import Optional, Union
from uuid import UUID
from datetime import datetime
from pydantic import BaseModel, Field, ValidationError, validator

Check failure on line 4 in tests/dsl/test_partial.py

View workflow job for this annotation

GitHub Actions / Ruff (ubuntu-latest)

Ruff (F401)

tests/dsl/test_partial.py:4:57: F401 `pydantic.validator` imported but unused
from typing import Optional, Union, Literal, Annotated
from instructor.dsl.partial import Partial, PartialLiteralMixin
import pytest
import instructor
Expand All @@ -24,6 +26,7 @@
class NestedA(BaseModel):
a: str
b: Optional[str]
c: Optional[Annotated[datetime, Partial]]


class NestedB(BaseModel):
Expand All @@ -39,6 +42,22 @@
c: NestedB


SampleEnum = Literal["a_value", "b_value", "c_value"]
SampleMixedEnum = Literal["a_value", "b_value", "c_value", 1, 2, 3]


class PartialEnums(BaseModel):
a: Annotated[Literal["a_value"], Partial]
b: Annotated[SampleEnum, Partial]
c: Annotated[SampleMixedEnum, Partial]
d: Annotated[Literal["a_value", 10], Partial]
e: Annotated[Literal["a_value"], Partial]
f: Literal["a_value"]
g: Annotated[UUID, Partial]
h: Optional[Annotated[datetime, Partial]]
i: Optional[NestedA]


def test_partial():
partial = Partial[SamplePartial]
assert partial.model_json_schema() == {
Expand Down Expand Up @@ -192,3 +211,46 @@
partial.get_partial_model().model_validate_json(
'{"a": [{"b": "b"}, {"d": "d"}], "b": [{"b": "b"}], "c": {"d": "d"}, "e": [1, "a"]}'
)


def test_partial_enums():
# Test that we can annotate enum values with `Partial` and support parsing
# partial values with the partial model
partial = Partial[PartialEnums]
partial_results = (
'{"a": "a_", "b": "b_", "c": "c_v", "d": 1, "e": "a_", "f": "a_value", "g": "1", "h": "", "i": {"c": ""}}'
)
partial_validated = partial.get_partial_model().model_validate_json(partial_results)

assert partial_validated.a is None
assert partial_validated.b is None
assert partial_validated.c is None
assert partial_validated.d is None
assert partial_validated.e is None
assert partial_validated.f == "a_value"
assert partial_validated.g is None
assert partial_validated.h is None
assert partial_validated.i is not None
assert partial_validated.i.c is None


with pytest.raises(ValidationError):
partial.model_validate_json(partial_results)

with pytest.raises(ValidationError):
# "f" is not marked as a partial enum
partial.get_partial_model().model_validate_json('{"f": "a_"}')

resolved_enum_partial_results = (
'{"a": "a_value", "b": "b_value", "c": "c_v", "d": 10, "g": "123e4567-e89b-12d3-a456-426655440000", "h": "2024-01-01T00:00:00"}'
)
resolved_enum_partial_validated = partial.get_partial_model().model_validate_json(
resolved_enum_partial_results
)
assert resolved_enum_partial_validated.a == "a_value"
assert resolved_enum_partial_validated.b == "b_value"
# this value still isn't fully resolved
assert resolved_enum_partial_validated.c is None
assert resolved_enum_partial_validated.d == 10
assert resolved_enum_partial_validated.g == UUID("123e4567-e89b-12d3-a456-426655440000")
assert resolved_enum_partial_validated.h == datetime(2024, 1, 1)
Loading