diff --git a/libs/langgraph/langgraph/utils/fields.py b/libs/langgraph/langgraph/utils/fields.py index 503d4c2d4..009e5aee1 100644 --- a/libs/langgraph/langgraph/utils/fields.py +++ b/libs/langgraph/langgraph/utils/fields.py @@ -110,7 +110,7 @@ def get_field_default(name: str, type_: Any, schema: Type[Any]) -> Any: def get_enhanced_type_hints( type: Type[Any], -) -> Generator[tuple[str, Any, Any, Optional[str]], None]: +) -> Generator[tuple[str, Any, Any, Optional[str]], None, None]: """Attempt to extract default values and descriptions from provided type, used for config schema.""" for name, typ in get_type_hints(type).items(): default = None @@ -126,6 +126,12 @@ def get_enhanced_type_hints( if hasattr(field, "default") and field.default is not None: default = field.default + if ( + hasattr(default, "__class__") + and getattr(default.__class__, "__name__", "") + == "PydanticUndefinedType" + ): + default = None except (AttributeError, KeyError, TypeError): pass diff --git a/libs/langgraph/tests/test_utils.py b/libs/langgraph/tests/test_utils.py index e8ea94fff..616f1a78f 100644 --- a/libs/langgraph/tests/test_utils.py +++ b/libs/langgraph/tests/test_utils.py @@ -21,7 +21,11 @@ from langgraph.graph import END, StateGraph from langgraph.graph.graph import CompiledGraph -from langgraph.utils.fields import _is_optional_type, get_field_default +from langgraph.utils.fields import ( + _is_optional_type, + get_enhanced_type_hints, + get_field_default, +) from langgraph.utils.runnable import is_async_callable, is_async_generator pytestmark = pytest.mark.anyio @@ -227,3 +231,57 @@ class MyGrandChildDict(MyChildDict, total=False): assert get_field_default("val_12", gcannos["val_12"], MyGrandChildDict) is None assert get_field_default("val_9", gcannos["val_9"], MyGrandChildDict) is None assert get_field_default("val_13", gcannos["val_13"], MyGrandChildDict) == ... + + +def test_enhanced_type_hints() -> None: + from dataclasses import dataclass + from typing import Annotated + + from pydantic import BaseModel, Field + + class MyTypedDict(TypedDict): + val_1: str + val_2: int = 42 + val_3: str = "default" + + hints = list(get_enhanced_type_hints(MyTypedDict)) + assert len(hints) == 3 + assert hints[0] == ("val_1", str, None, None) + assert hints[1] == ("val_2", int, 42, None) + assert hints[2] == ("val_3", str, "default", None) + + @dataclass + class MyDataclass: + val_1: str + val_2: int = 42 + val_3: str = "default" + + hints = list(get_enhanced_type_hints(MyDataclass)) + assert len(hints) == 3 + assert hints[0] == ("val_1", str, None, None) + assert hints[1] == ("val_2", int, 42, None) + assert hints[2] == ("val_3", str, "default", None) + + class MyPydanticModel(BaseModel): + val_1: str + val_2: int = 42 + val_3: str = Field(default="default", description="A description") + + hints = list(get_enhanced_type_hints(MyPydanticModel)) + assert len(hints) == 3 + assert hints[0] == ("val_1", str, None, None) + assert hints[1] == ("val_2", int, 42, None) + assert hints[2] == ("val_3", str, "default", "A description") + + class MyPydanticModelWithAnnotated(BaseModel): + val_1: Annotated[str, Field(description="A description")] + val_2: Annotated[int, Field(default=42)] + val_3: Annotated[ + str, Field(default="default", description="Another description") + ] + + hints = list(get_enhanced_type_hints(MyPydanticModelWithAnnotated)) + assert len(hints) == 3 + assert hints[0] == ("val_1", str, None, "A description") + assert hints[1] == ("val_2", int, 42, None) + assert hints[2] == ("val_3", str, "default", "Another description")