Skip to content

Commit

Permalink
Merge branch 'main' of github.com:navapbc/labs-gen-ai-experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
ccheng26 committed Apr 5, 2024
2 parents 66f56f6 + 10d4245 commit 45c61ca
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 3 deletions.
8 changes: 6 additions & 2 deletions 02-household-queries/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
from langchain.chains import RetrievalQA


def create_retriever(vectordb):
retrieve_k = int(os.environ.get("RETRIEVE_K", "1"))
return vectordb.as_retriever(search_kwargs={"k": retrieve_k})


def retrieval_call(llm, vectordb, question):
# Create the retrieval chain
template = """
Expand All @@ -14,8 +19,7 @@ def retrieval_call(llm, vectordb, question):
print("\n## PROMPT TEMPLATE: ", llm_prompt)

prompt = PromptTemplate.from_template(llm_prompt)
retrieve_k = int(os.environ.get("RETRIEVE_K", "1"))
retriever = vectordb.as_retriever(search_kwargs={"k": retrieve_k})
retriever = create_retriever(vectordb)
retrieval_chain = RetrievalQA.from_chain_type(
llm=llm,
retriever=retriever,
Expand Down
56 changes: 55 additions & 1 deletion 02-household-queries/run.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import json
import dotenv
from langchain_community.embeddings import (
SentenceTransformerEmbeddings,
Expand All @@ -9,7 +10,7 @@
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings

from ingest import ingest_call
from retrieval import retrieval_call
from retrieval import create_retriever, retrieval_call
from llm import ollama_client

dotenv.load_dotenv()
Expand Down Expand Up @@ -69,17 +70,70 @@
persist_directory="./chroma_db",
)


def load_training_json():
with open("question_answer_citations.json", encoding="utf-8") as data_file:
json_data = json.load(data_file)
# print(json.dumps(json_data, indent=2))
return json_data


def compute_percent_retrieved(retrieved_cards, guru_cards):
missed_cards = set(guru_cards) - set(retrieved_cards)
return (len(guru_cards) - len(missed_cards)) / len(guru_cards)


def count_extra_cards(retrieved_cards, guru_cards):
extra_cards = set(retrieved_cards) - set(guru_cards)
return len(extra_cards)


def evaluate_retrieval():
qa = load_training_json()
results = []
retriever = create_retriever(vectordb)
for qa_dict in qa[1:]:
orig_question = qa_dict["orig_question"]
question = qa_dict.get("question", orig_question)
# print(f"\nQUESTION {qa_dict['id']}: {question}")
guru_cards = qa_dict.get("guru_cards", [])
# print(f" Desired CARDS : {guru_cards}")

retrieval = retriever.invoke(question)
retrieved_cards = [doc.metadata["source"] for doc in retrieval]
results.append(
{
"id": qa_dict["id"],
"question": question,
"guru_cards": guru_cards,
"retrieved_cards": retrieved_cards,
"recall": compute_percent_retrieved(retrieved_cards, guru_cards),
"extra_cards": count_extra_cards(retrieved_cards, guru_cards),
}
)
print(retriever)
print(
"EVALUATION RESULTS:\n", "\n".join([json.dumps(r, indent=2) for r in results])
)
print("\nTable:")
for res in results:
print(res["id"], "|", res["recall"], "|", res["extra_cards"])


print("""
Initialize DB and retrieve?
1. Retrieve only (default)
2. Ingest and retrieve
3. Ingest only
4. Evaluate retrieval
""")
run_option = input()
if run_option == "2":
ingest_call(vectordb=vectordb)
retrieval_call(llm=llm, vectordb=vectordb)
elif run_option == "3":
ingest_call(vectordb=vectordb)
elif run_option == "4":
evaluate_retrieval()
else:
retrieval_call(llm=llm, vectordb=vectordb)

0 comments on commit 45c61ca

Please sign in to comment.