diff --git a/libs/langgraph/langgraph/pregel/__init__.py b/libs/langgraph/langgraph/pregel/__init__.py index 2cd7f0984..ba4533830 100644 --- a/libs/langgraph/langgraph/pregel/__init__.py +++ b/libs/langgraph/langgraph/pregel/__init__.py @@ -11,7 +11,6 @@ AsyncIterator, Callable, Dict, - Generator, Iterator, Mapping, Optional, @@ -19,7 +18,6 @@ Type, Union, cast, - get_type_hints, overload, ) from uuid import UUID, uuid5 @@ -118,6 +116,7 @@ patch_config, patch_configurable, ) +from langgraph.utils.fields import get_enhanced_type_hints from langgraph.utils.pydantic import create_model from langgraph.utils.queue import AsyncQueue, SyncQueue # type: ignore[attr-defined] @@ -309,44 +308,6 @@ def validate(self) -> Self: @property def config_specs(self) -> list[ConfigurableFieldSpec]: - # TODO: shouldn't this be in langchain_core? - def get_enhanced_type_hints( - type: Type[Any], - ) -> Generator[tuple[str, Any, Any, Optional[str]], None]: - """Attempt to extract default values and descriptions from provided config spec""" - for name, typ in get_type_hints(type).items(): - default = None - description = None - - # Pydantic models - try: - if hasattr(type, "__fields__") and name in type.__fields__: - field = type.__fields__[name] - - if ( - hasattr(field, "description") - and field.description is not None - ): - description = field.description - - if hasattr(field, "default") and field.default is not None: - default = field.default - - except (AttributeError, KeyError, TypeError): - pass - - # TypedDict, dataclass - try: - if hasattr(type, "__dict__"): - type_dict = getattr(type, "__dict__") - - if name in type_dict: - default = type_dict[name] - except (AttributeError, KeyError, TypeError): - pass - - yield name, typ, default, description - return [ spec for spec in get_unique_config_specs( diff --git a/libs/langgraph/langgraph/pregel/utils.py b/libs/langgraph/langgraph/pregel/utils.py index 66464ef9a..0c7030bb0 100644 --- a/libs/langgraph/langgraph/pregel/utils.py +++ b/libs/langgraph/langgraph/pregel/utils.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Generator, Optional, Type, get_type_hints from langchain_core.runnables import RunnableLambda, RunnableSequence from langchain_core.runnables.utils import get_function_nonlocals diff --git a/libs/langgraph/langgraph/utils/fields.py b/libs/langgraph/langgraph/utils/fields.py index f4786cb34..503d4c2d4 100644 --- a/libs/langgraph/langgraph/utils/fields.py +++ b/libs/langgraph/langgraph/utils/fields.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, Optional, Type, Union +from typing import Any, Generator, Optional, Type, Union, get_type_hints from typing_extensions import Annotated, NotRequired, ReadOnly, Required, get_origin @@ -106,3 +106,38 @@ def get_field_default(name: str, type_: Any, schema: Type[Any]) -> Any: if _is_optional_type(type_): return None return ... + + +def get_enhanced_type_hints( + type: Type[Any], +) -> Generator[tuple[str, Any, Any, Optional[str]], None]: + """Attempt to extract default values and descriptions from provided type, used for config schema.""" + for name, typ in get_type_hints(type).items(): + default = None + description = None + + # Pydantic models + try: + if hasattr(type, "__fields__") and name in type.__fields__: + field = type.__fields__[name] + + if hasattr(field, "description") and field.description is not None: + description = field.description + + if hasattr(field, "default") and field.default is not None: + default = field.default + + except (AttributeError, KeyError, TypeError): + pass + + # TypedDict, dataclass + try: + if hasattr(type, "__dict__"): + type_dict = getattr(type, "__dict__") + + if name in type_dict: + default = type_dict[name] + except (AttributeError, KeyError, TypeError): + pass + + yield name, typ, default, description