From c913e57c9279b93683411ca161ac4822c8bf09e5 Mon Sep 17 00:00:00 2001 From: Shroominic Date: Sat, 11 Nov 2023 22:36:07 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20Refactor=20chain=20invocation=20?= =?UTF-8?q?logic?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- funcchain/chain.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/funcchain/chain.py b/funcchain/chain.py index 11a0f89..c849bc0 100644 --- a/funcchain/chain.py +++ b/funcchain/chain.py @@ -147,11 +147,14 @@ def chain( """ Get response from chatgpt for provided instructions. """ + chain = create_chain(instruction, system, parser, context, input_kwargs) + with get_openai_callback() as cb: - chain = create_chain(instruction, system, parser, context, input_kwargs).invoke(input_kwargs) + result = chain.invoke(input_kwargs) if cb.total_tokens != 0: log(f"{cb.total_tokens:05}T / {cb.total_cost:.3f}$ - {get_parent_frame(3).function}") - return chain + + return result @retry_parse(5) @@ -165,8 +168,11 @@ async def achain( """ Get response from chatgpt for provided instructions. """ + chain = create_chain(instruction, system, parser, context, input_kwargs) + with get_openai_callback() as cb: - chain = await create_chain(instruction, system, parser, context, input_kwargs).ainvoke(input_kwargs) + result = await chain.ainvoke(input_kwargs) if cb.total_tokens != 0: log(f"{cb.total_tokens:05}T / {cb.total_cost:.3f}$ - {get_parent_frame(3).function}") - return chain + + return result