diff --git a/funcchain/utils/function_frame.py b/funcchain/utils/function_frame.py index 6a9e795..32d9033 100644 --- a/funcchain/utils/function_frame.py +++ b/funcchain/utils/function_frame.py @@ -1,3 +1,4 @@ +import types from inspect import FrameInfo, currentframe, getouterframes from typing import Union @@ -29,12 +30,15 @@ def get_output_type() -> type: raise ValueError("The funcchain must have a return type annotation") -def parser_for(output_type: type) -> BaseOutputParser: +def parser_for(output_type: type) -> BaseOutputParser | None: """ Get the parser from the type annotation of the parent caller function. """ + if isinstance(output_type, types.UnionType): + return None if getattr(output_type, "__origin__", None) is Union: output_type = output_type.__args__[0] # type: ignore + return None if output_type is str: return StrOutputParser() if output_type is bool: