From 4a0083554804b35057ea1e3e626f01aef6afa10c Mon Sep 17 00:00:00 2001 From: Baptiste Pasquier Date: Mon, 15 Apr 2024 18:50:06 +0200 Subject: [PATCH] :bug: fix pydantic error in rag_basic --- backend/rag_components/chain_links/rag_basic.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/backend/rag_components/chain_links/rag_basic.py b/backend/rag_components/chain_links/rag_basic.py index bf6b2a7..07b90be 100644 --- a/backend/rag_components/chain_links/rag_basic.py +++ b/backend/rag_components/chain_links/rag_basic.py @@ -1,4 +1,5 @@ """This chain answers the provided question based on documents it retreives.""" + from langchain_core.prompts import ChatPromptTemplate from langchain_core.retrievers import BaseRetriever from langchain_core.runnables import RunnablePassthrough @@ -13,7 +14,7 @@ Context: {relevant_documents} Question: {question} -""" # noqa: E501 +""" # noqa: E501 class Question(BaseModel): @@ -26,9 +27,11 @@ class Response(BaseModel): def rag_basic(llm, retriever: BaseRetriever) -> DocumentedRunnable: chain = ( - {"relevant_documents": fetch_docs_chain(retriever), "question": RunnablePassthrough(Question)} + {"relevant_documents": fetch_docs_chain(retriever), "question": RunnablePassthrough()} | ChatPromptTemplate.from_template(prompt) | llm ) - typed_chain = chain.with_types(input_type=str, output_type=Response) - return DocumentedRunnable(typed_chain, chain_name="Answer questions from documents stored in a vector store", prompt=prompt, user_doc=__doc__) + typed_chain = chain.with_types(input_type=Question, output_type=Response) + return DocumentedRunnable( + typed_chain, chain_name="Answer questions from documents stored in a vector store", prompt=prompt, user_doc=__doc__ + )