Skip to content

Commit

Permalink
Handle PydanticUndefined, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dqbd committed Dec 10, 2024
1 parent 1f68bd0 commit 5a30fc6
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 2 deletions.
8 changes: 7 additions & 1 deletion libs/langgraph/langgraph/utils/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def get_field_default(name: str, type_: Any, schema: Type[Any]) -> Any:

def get_enhanced_type_hints(
type: Type[Any],
) -> Generator[tuple[str, Any, Any, Optional[str]], None]:
) -> Generator[tuple[str, Any, Any, Optional[str]], None, None]:
"""Attempt to extract default values and descriptions from provided type, used for config schema."""
for name, typ in get_type_hints(type).items():
default = None
Expand All @@ -126,6 +126,12 @@ def get_enhanced_type_hints(

if hasattr(field, "default") and field.default is not None:
default = field.default
if (
hasattr(default, "__class__")
and getattr(default.__class__, "__name__", "")
== "PydanticUndefinedType"
):
default = None

except (AttributeError, KeyError, TypeError):
pass
Expand Down
60 changes: 59 additions & 1 deletion libs/langgraph/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@

from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.utils.fields import _is_optional_type, get_field_default
from langgraph.utils.fields import (
_is_optional_type,
get_enhanced_type_hints,
get_field_default,
)
from langgraph.utils.runnable import is_async_callable, is_async_generator

pytestmark = pytest.mark.anyio
Expand Down Expand Up @@ -227,3 +231,57 @@ class MyGrandChildDict(MyChildDict, total=False):
assert get_field_default("val_12", gcannos["val_12"], MyGrandChildDict) is None
assert get_field_default("val_9", gcannos["val_9"], MyGrandChildDict) is None
assert get_field_default("val_13", gcannos["val_13"], MyGrandChildDict) == ...


def test_enhanced_type_hints() -> None:
from dataclasses import dataclass
from typing import Annotated

from pydantic import BaseModel, Field

class MyTypedDict(TypedDict):
val_1: str
val_2: int = 42
val_3: str = "default"

hints = list(get_enhanced_type_hints(MyTypedDict))
assert len(hints) == 3
assert hints[0] == ("val_1", str, None, None)
assert hints[1] == ("val_2", int, 42, None)
assert hints[2] == ("val_3", str, "default", None)

@dataclass
class MyDataclass:
val_1: str
val_2: int = 42
val_3: str = "default"

hints = list(get_enhanced_type_hints(MyDataclass))
assert len(hints) == 3
assert hints[0] == ("val_1", str, None, None)
assert hints[1] == ("val_2", int, 42, None)
assert hints[2] == ("val_3", str, "default", None)

class MyPydanticModel(BaseModel):
val_1: str
val_2: int = 42
val_3: str = Field(default="default", description="A description")

hints = list(get_enhanced_type_hints(MyPydanticModel))
assert len(hints) == 3
assert hints[0] == ("val_1", str, None, None)
assert hints[1] == ("val_2", int, 42, None)
assert hints[2] == ("val_3", str, "default", "A description")

class MyPydanticModelWithAnnotated(BaseModel):
val_1: Annotated[str, Field(description="A description")]
val_2: Annotated[int, Field(default=42)]
val_3: Annotated[
str, Field(default="default", description="Another description")
]

hints = list(get_enhanced_type_hints(MyPydanticModelWithAnnotated))
assert len(hints) == 3
assert hints[0] == ("val_1", str, None, "A description")
assert hints[1] == ("val_2", int, 42, None)
assert hints[2] == ("val_3", str, "default", "Another description")

0 comments on commit 5a30fc6

Please sign in to comment.