From c9fb89933f05b8981e93e7c41584c20d3e1dccff Mon Sep 17 00:00:00 2001 From: Shroominic Date: Wed, 8 Nov 2023 23:20:46 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20Improved=20type=20handling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- funcchain/utils/function_frame.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/funcchain/utils/function_frame.py b/funcchain/utils/function_frame.py index 6a9e795..32d9033 100644 --- a/funcchain/utils/function_frame.py +++ b/funcchain/utils/function_frame.py @@ -1,3 +1,4 @@ +import types from inspect import FrameInfo, currentframe, getouterframes from typing import Union @@ -29,12 +30,15 @@ def get_output_type() -> type: raise ValueError("The funcchain must have a return type annotation") -def parser_for(output_type: type) -> BaseOutputParser: +def parser_for(output_type: type) -> BaseOutputParser | None: """ Get the parser from the type annotation of the parent caller function. """ + if isinstance(output_type, types.UnionType): + return None if getattr(output_type, "__origin__", None) is Union: output_type = output_type.__args__[0] # type: ignore + return None if output_type is str: return StrOutputParser() if output_type is bool: