Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix streamed list of union basemodels #1141

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 37 additions & 20 deletions instructor/dsl/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
NoReturn,
Optional,
TypeVar,
Union,
)
from collections.abc import AsyncGenerator, Generator, Iterable
from copy import deepcopy
Expand All @@ -38,41 +39,57 @@ def _make_field_optional(
field: FieldInfo,
) -> tuple[Any, FieldInfo]:
tmp_field = deepcopy(field)

annotation = field.annotation

# Handle generics (like List, Dict, etc.)
if get_origin(annotation) is not None:
# Get the generic base (like List, Dict) and its arguments (like User in List[User])
generic_base = get_origin(annotation)
generic_args = get_args(annotation)

# Recursively apply Partial to each of the generic arguments
modified_args = tuple(
(
Partial[arg, MakeFieldsOptional] # type: ignore[valid-type]
if isinstance(arg, type) and issubclass(arg, BaseModel)
# Handle Union types specially
if generic_base is Union:
modified_args = tuple(
Partial[arg, MakeFieldsOptional] if isinstance(arg, type) and issubclass(arg, BaseModel)
else arg
for arg in generic_args
)
for arg in generic_args
)

# Reconstruct the generic type with modified arguments
tmp_field.annotation = (
Optional[generic_base[modified_args]] if generic_base else None
)
tmp_field.default = None
# If the field is a BaseModel, then recursively convert it's
# attributes to optionals.
# Add None to Union options and set default
modified_args = modified_args + (None,) if None not in modified_args else modified_args
tmp_field.annotation = Union[modified_args] # type: ignore
tmp_field.default = None
else:
# For other generics (like List), process their arguments
modified_args = tuple(
_process_annotation(arg)
for arg in generic_args
)
tmp_field.annotation = Optional[generic_base[modified_args]] # type: ignore
tmp_field.default = None
# If the field is a BaseModel, then recursively convert it's attributes to optionals
elif isinstance(annotation, type) and issubclass(annotation, BaseModel):
tmp_field.annotation = Optional[Partial[annotation, MakeFieldsOptional]] # type: ignore[assignment, valid-type]
tmp_field.default = {}
tmp_field.annotation = Optional[Partial[annotation, MakeFieldsOptional]] # type: ignore
tmp_field.default = None
else:
tmp_field.annotation = Optional[field.annotation] # type: ignore[assignment]
tmp_field.annotation = Optional[annotation] # type: ignore
tmp_field.default = None

return tmp_field.annotation, tmp_field # type: ignore

def _process_annotation(annotation: Any) -> Any:
"""Helper function to process nested annotations"""
if get_origin(annotation) is Union:
modified_args = tuple(
Partial[arg, MakeFieldsOptional] if isinstance(arg, type) and issubclass(arg, BaseModel)
else arg
for arg in get_args(annotation)
)
# Add None to Union options
modified_args = modified_args + (None,) if None not in modified_args else modified_args
return Union[modified_args]
elif isinstance(annotation, type) and issubclass(annotation, BaseModel):
return Partial[annotation, MakeFieldsOptional]
return annotation


class PartialBase(Generic[T_Model]):
@classmethod
Expand Down