diff --git a/02-household-queries/retrieval.py b/02-household-queries/retrieval.py
index 0cc2ed5..94851d9 100644
--- a/02-household-queries/retrieval.py
+++ b/02-household-queries/retrieval.py
@@ -3,6 +3,11 @@
 from langchain.chains import RetrievalQA
 
 
+def create_retriever(vectordb):
+    retrieve_k = int(os.environ.get("RETRIEVE_K", "1"))
+    return vectordb.as_retriever(search_kwargs={"k": retrieve_k})
+
+
 def retrieval_call(llm, vectordb, question):
     # Create the retrieval chain
     template = """
@@ -14,8 +19,7 @@ def retrieval_call(llm, vectordb, question):
     print("\n## PROMPT TEMPLATE: ", llm_prompt)
     prompt = PromptTemplate.from_template(llm_prompt)
 
-    retrieve_k = int(os.environ.get("RETRIEVE_K", "1"))
-    retriever = vectordb.as_retriever(search_kwargs={"k": retrieve_k})
+    retriever = create_retriever(vectordb)
     retrieval_chain = RetrievalQA.from_chain_type(
         llm=llm,
         retriever=retriever,
diff --git a/02-household-queries/run.py b/02-household-queries/run.py
index d3c016e..1b1862a 100644
--- a/02-household-queries/run.py
+++ b/02-household-queries/run.py
@@ -1,4 +1,5 @@
 import os
+import json
 import dotenv
 from langchain_community.embeddings import (
     SentenceTransformerEmbeddings,
@@ -9,7 +10,7 @@
 from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
 
 from ingest import ingest_call
-from retrieval import retrieval_call
+from retrieval import create_retriever, retrieval_call
 from llm import ollama_client
 
 dotenv.load_dotenv()
@@ -69,11 +70,62 @@
     persist_directory="./chroma_db",
 )
 
+
+def load_training_json():
+    with open("question_answer_citations.json", encoding="utf-8") as data_file:
+        json_data = json.load(data_file)
+        # print(json.dumps(json_data, indent=2))
+        return json_data
+
+
+def compute_percent_retrieved(retrieved_cards, guru_cards):
+    missed_cards = set(guru_cards) - set(retrieved_cards)
+    return (len(guru_cards) - len(missed_cards)) / len(guru_cards)
+
+
+def count_extra_cards(retrieved_cards, guru_cards):
+    extra_cards = set(retrieved_cards) - set(guru_cards)
+    return len(extra_cards)
+
+
+def evaluate_retrieval():
+    qa = load_training_json()
+    results = []
+    retriever = create_retriever(vectordb)
+    for qa_dict in qa[1:]:
+        orig_question = qa_dict["orig_question"]
+        question = qa_dict.get("question", orig_question)
+        # print(f"\nQUESTION {qa_dict['id']}: {question}")
+        guru_cards = qa_dict.get("guru_cards", [])
+        # print(f" Desired CARDS : {guru_cards}")
+
+        retrieval = retriever.invoke(question)
+        retrieved_cards = [doc.metadata["source"] for doc in retrieval]
+        results.append(
+            {
+                "id": qa_dict["id"],
+                "question": question,
+                "guru_cards": guru_cards,
+                "retrieved_cards": retrieved_cards,
+                "recall": compute_percent_retrieved(retrieved_cards, guru_cards),
+                "extra_cards": count_extra_cards(retrieved_cards, guru_cards),
+            }
+        )
+    print(retriever)
+    print(
+        "EVALUATION RESULTS:\n", "\n".join([json.dumps(r, indent=2) for r in results])
+    )
+    print("\nTable:")
+    for res in results:
+        print(res["id"], "|", res["recall"], "|", res["extra_cards"])
+
+
 print("""
 Initialize DB and retrieve?
 1. Retrieve only (default)
 2. Ingest and retrieve
 3. Ingest only
+4. Evaluate retrieval
 """)
 run_option = input()
 if run_option == "2":
@@ -81,5 +133,7 @@
     retrieval_call(llm=llm, vectordb=vectordb)
 elif run_option == "3":
     ingest_call(vectordb=vectordb)
+elif run_option == "4":
+    evaluate_retrieval()
 else:
     retrieval_call(llm=llm, vectordb=vectordb)