diff --git a/instructor/dsl/partial.py b/instructor/dsl/partial.py index eaf986713..a6cf7b7f9 100644 --- a/instructor/dsl/partial.py +++ b/instructor/dsl/partial.py @@ -83,12 +83,14 @@ def get_partial_model(cls) -> type[T_Model]: cls, BaseModel ), f"{cls.__name__} must be a subclass of BaseModel" + model_name = ( + cls.__name__ + if cls.__name__.startswith("Partial") + else f"Partial{cls.__name__}" + ) + return create_model( - __model_name=( - cls.__name__ - if cls.__name__.startswith("Partial") - else f"Partial{cls.__name__}" - ), + model_name, __base__=cls, __module__=cls.__module__, **{ @@ -289,12 +291,14 @@ def _wrap_models(field: FieldInfo) -> tuple[object, FieldInfo]: tmp_field.annotation = Partial[annotation] return tmp_field.annotation, tmp_field + model_name = ( + wrapped_class.__name__ + if wrapped_class.__name__.startswith("Partial") + else f"Partial{wrapped_class.__name__}" + ) + return create_model( - __model_name=( - wrapped_class.__name__ - if wrapped_class.__name__.startswith("Partial") - else f"Partial{wrapped_class.__name__}" - ), + model_name, __base__=(wrapped_class, PartialBase), __module__=wrapped_class.__module__, **{