Skip to content

Commit

Permalink
🐛 fix pydantic error in rag_basic
Browse files Browse the repository at this point in the history
  • Loading branch information
baptiste-pasquier committed Apr 15, 2024
1 parent 7868535 commit 4a00835
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions backend/rag_components/chain_links/rag_basic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""This chain answers the provided question based on documents it retreives."""

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import RunnablePassthrough
Expand All @@ -13,7 +14,7 @@
Context: {relevant_documents}
Question: {question}
""" # noqa: E501
""" # noqa: E501


class Question(BaseModel):
Expand All @@ -26,9 +27,11 @@ class Response(BaseModel):

def rag_basic(llm, retriever: BaseRetriever) -> DocumentedRunnable:
chain = (
{"relevant_documents": fetch_docs_chain(retriever), "question": RunnablePassthrough(Question)}
{"relevant_documents": fetch_docs_chain(retriever), "question": RunnablePassthrough()}
| ChatPromptTemplate.from_template(prompt)
| llm
)
typed_chain = chain.with_types(input_type=str, output_type=Response)
return DocumentedRunnable(typed_chain, chain_name="Answer questions from documents stored in a vector store", prompt=prompt, user_doc=__doc__)
typed_chain = chain.with_types(input_type=Question, output_type=Response)
return DocumentedRunnable(
typed_chain, chain_name="Answer questions from documents stored in a vector store", prompt=prompt, user_doc=__doc__
)

0 comments on commit 4a00835

Please sign in to comment.