Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support individual partial enum values #1228

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
35 changes: 25 additions & 10 deletions instructor/dsl/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from __future__ import annotations

from jiter import from_json
from pydantic import BaseModel, create_model
from typing import Union
from pydantic import BaseModel, create_model, BeforeValidator
from typing import Literal, Union, Any, Annotated
import types
import sys
from pydantic.fields import FieldInfo
Expand Down Expand Up @@ -47,6 +47,16 @@
pass


class PartialLiteralValidator(BeforeValidator):
def __init__(self, literal_type: Any, **kwargs: Any):
def validate_literal(v: Any) -> Optional[Any]:

Check failure on line 52 in instructor/dsl/partial.py

View workflow job for this annotation

GitHub Actions / Ruff (ubuntu-latest)

Ruff (UP007)

instructor/dsl/partial.py:52:41: UP007 Use `X | Y` for type annotations
if v in get_args(literal_type):
return v
return None

super().__init__(func=validate_literal, **kwargs)


def _process_generic_arg(
arg: Any,
make_fields_optional: bool = False,
Expand Down Expand Up @@ -91,15 +101,20 @@
generic_base = get_origin(annotation)
generic_args = get_args(annotation)

modified_args = tuple(
_process_generic_arg(arg, make_fields_optional=True) for arg in generic_args
)
if generic_base is Literal and Partial in field.metadata:
literal_types: set[type[Any]] = {type(arg) for arg in generic_args}
tmp_field.annotation = Annotated[Optional[Union[tuple(literal_types)]], PartialLiteralValidator(annotation)] # type: ignore
tmp_field.default = None
else:
modified_args = tuple(
_process_generic_arg(arg, make_fields_optional=True)
for arg in generic_args
)
tmp_annotation: Any = Optional[generic_base[modified_args]] if generic_base else None # type: ignore

# Reconstruct the generic type with modified arguments
tmp_field.annotation = (
Optional[generic_base[modified_args]] if generic_base else None
)
tmp_field.default = None
# Reconstruct the generic type with modified arguments
tmp_field.annotation = tmp_annotation
tmp_field.default = None
# If the field is a BaseModel, then recursively convert it's
# attributes to optionals.
elif isinstance(annotation, type) and issubclass(annotation, BaseModel):
Expand Down
53 changes: 51 additions & 2 deletions tests/dsl/test_partial.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# type: ignore[all]
from pydantic import BaseModel, Field
from typing import Optional, Union
from pydantic import BaseModel, Field, ValidationError, validator

Check failure on line 2 in tests/dsl/test_partial.py

View workflow job for this annotation

GitHub Actions / Ruff (ubuntu-latest)

Ruff (F401)

tests/dsl/test_partial.py:2:57: F401 `pydantic.validator` imported but unused
from typing import Optional, Union, Literal, Annotated
from instructor.dsl.partial import Partial, PartialLiteralMixin
import pytest
import instructor
Expand Down Expand Up @@ -39,6 +39,19 @@
c: NestedB


SampleEnum = Literal["a_value", "b_value", "c_value"]
SampleMixedEnum = Literal["a_value", "b_value", "c_value", 1, 2, 3]


class PartialEnums(BaseModel):
a: Annotated[Literal["a_value"], Partial]
b: Annotated[SampleEnum, Partial]
c: Annotated[SampleMixedEnum, Partial]
d: Annotated[Literal["a_value", 10], Partial]
e: Annotated[Literal["a_value"], Partial]
f: Literal["a_value"]


def test_partial():
partial = Partial[SamplePartial]
assert partial.model_json_schema() == {
Expand Down Expand Up @@ -192,3 +205,39 @@
partial.get_partial_model().model_validate_json(
'{"a": [{"b": "b"}, {"d": "d"}], "b": [{"b": "b"}], "c": {"d": "d"}, "e": [1, "a"]}'
)


def test_partial_enums():
# Test that we can annotate enum values with `Partial` and support parsing
# partial values with the partial model
partial = Partial[PartialEnums]
partial_results = (
'{"a": "a_", "b": "b_", "c": "c_v", "d": 1, "e": "a_", "f": "a_value"}'
)
partial_validated = partial.get_partial_model().model_validate_json(partial_results)

assert partial_validated.a is None
assert partial_validated.b is None
assert partial_validated.c is None
assert partial_validated.d is None
assert partial_validated.e is None
assert partial_validated.f == "a_value"

with pytest.raises(ValidationError):
partial.model_validate_json(partial_results)

with pytest.raises(ValidationError):
# "f" is not marked as a partil enum
partial.get_partial_model().model_validate_json('{"f": "a_"}')

resolved_enum_partial_results = (
'{"a": "a_value", "b": "b_value", "c": "c_v", "d": 10}'
)
resolved_enum_partial_validated = partial.get_partial_model().model_validate_json(
resolved_enum_partial_results
)
assert resolved_enum_partial_validated.a == "a_value"
assert resolved_enum_partial_validated.b == "b_value"
# this value still isn't fully resolved
assert resolved_enum_partial_validated.c is None
assert resolved_enum_partial_validated.d == 10
Loading