Skip to content

Commit

Permalink
Eval related tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
Taqi Jaffri committed Dec 6, 2023
1 parent a4f59d2 commit 885692a
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 83 deletions.
8 changes: 4 additions & 4 deletions docugami_kg_rag/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ class LocalIndexState:

# Lengths for the loader are in terms of characters, 1 token ~= 4 chars in English
# Reference: https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
MAX_CHUNK_TEXT_LENGTH = 1024 * 20 # ~5k tokens
MIN_CHUNK_TEXT_LENGTH = 1024 * 8 # ~1k tokens
MAX_CHUNK_TEXT_LENGTH = 1024 * 28 # ~7k tokens
MIN_CHUNK_TEXT_LENGTH = 1024 * 1 # ~1k tokens
SUB_CHUNK_TABLES = False
INCLUDE_XML_TAGS = True
PARENT_HIERARCHY_LEVELS = 1
RETRIEVER_K = 10
PARENT_HIERARCHY_LEVELS = 4
RETRIEVER_K = 6

BATCH_SIZE = 16
13 changes: 10 additions & 3 deletions docugami_kg_rag/helpers/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,23 @@
All your answers must contain citations to help the user understand how you created the citation, specifically:
- If the given context contains the names of document(s), make sure you include that in your answer as
a citation, e.g. include "\\n\\nSOURCE(S): foo.pdf, bar.pdf" at the end of your answer.
- If the given context contains the names of document(s), make sure you include the document you got the
answer from as a citation, e.g. include "\\n\\nSOURCE(S): foo.pdf, bar.pdf" at the end of your answer.
- If the answer was generated via a SQL Query, make sure you include the SQL query in your answer as
a citation, e.g. include "\\n\\nSOURCE(S): SELECT AVG('square footage') from Leases". The SQL query should be
in the agent scratchpad provided.
in the agent scratchpad provided, if you are using an agent.
- Make sure there an actual answer if you show a SOURCE citation, i.e. make sure you don't show only
a bare citation with no actual answer.
"""

HUMAN_MESSAGE_TEMPLATE = """{context}
Using the context above, which can include text and tables, answer the following question.
Question: {question}
"""

CREATE_DIRECT_RETRIEVAL_TOOL_DESCRIPTION_PROMPT = """Here is a snippet from a sample document of type {docset_name}:
{document}
Expand Down
31 changes: 19 additions & 12 deletions docugami_kg_rag/helpers/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from langchain.agents.agent_toolkits import create_retriever_tool
from langchain.prompts import ChatPromptTemplate
from langchain.schema import Document, StrOutputParser
from langchain.schema import BaseRetriever, Document, StrOutputParser
from langchain.tools.base import BaseTool
from langchain.vectorstores import Chroma

Expand All @@ -25,6 +25,22 @@
)


def get_retriever_for_docset(docset_state: LocalIndexState) -> BaseRetriever:
    """
    Build a retriever for a single docset.

    Chunk embeddings are read from the persisted Chroma vector store, while the
    parent chunks and full-document summaries are served from the stores carried
    inside the given local index state.
    """
    # Vector store over chunk embeddings, persisted on local disk.
    vector_store = Chroma(
        persist_directory=CHROMA_DIRECTORY,
        embedding_function=EMBEDDINGS,
    )

    return FusedSummaryRetriever(
        vectorstore=vector_store,
        parent_doc_store=docset_state.chunks_by_id,
        full_doc_summary_store=docset_state.full_doc_summaries_by_id,
        search_type=SearchType.mmr,
        search_kwargs={"k": RETRIEVER_K},
    )


def docset_name_to_direct_retriever_tool_function_name(name: str) -> str:
"""
Converts a docset name to a direct retriever tool function name.
Expand Down Expand Up @@ -75,19 +91,10 @@ def chunks_to_direct_retriever_tool_description(name: str, chunks: List[Document

def get_retrieval_tool_for_docset(docset_state: LocalIndexState) -> Optional[BaseTool]:
"""
Chunks are in the vector store, and full documents are in the store inside the local state
Gets a retrieval tool for an agent.
"""

chunk_vectorstore = Chroma(persist_directory=CHROMA_DIRECTORY, embedding_function=EMBEDDINGS)

retriever = FusedSummaryRetriever(
vectorstore=chunk_vectorstore,
parent_doc_store=docset_state.chunks_by_id,
full_doc_summary_store=docset_state.full_doc_summaries_by_id,
search_kwargs={"k": RETRIEVER_K},
search_type=SearchType.mmr,
)

retriever = get_retriever_for_docset(docset_state=docset_state)
return create_retriever_tool(
retriever=retriever,
name=docset_state.retrieval_tool_function_name,
Expand Down
142 changes: 78 additions & 64 deletions evals/run-evals.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,25 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set up Eval Dataset"
"## Set up Eval"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from pathlib import Path\n",
"\n",
"DOCSET_NAME = \"Earnings Calls Evaluation\"\n",
"# Important: Create your OpenAI assistant via https://platform.openai.com/playground\n",
"# and put the assistant ID here. Make sure you upload the identical set of\n",
"# files listed below (these files will be uploaded automatically to Docugami)\n",
"OPENAI_ASSISTANT_ID = \"asst_g837jjwr6Ohgk2EWfQOKTcPg\"\n",
"\n",
"DOCSET_NAME = \"Earnings Calls Evaluation 12-06-2023\"\n",
"FILES_DIR = Path(os.getcwd()) / \"v1/docs\"\n",
"FILE_NAMES = [\n",
" \"Q1 2022 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf\",\n",
" \"Q1 2023 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf\",\n",
Expand All @@ -36,10 +42,12 @@
" \"Q4 2022 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf\",\n",
" \"Q3 FY23 Microsoft Corp Earnings Call.pdf\",\n",
"]\n",
"\n",
"FILES_DIR = Path(os.getcwd()) / \"v1/docs\"\n",
"GROUND_TRUTH_CSV = Path(os.getcwd()) / \"v1/ground-truth-earning_calls.csv\"\n",
"\n",
"# We will run each experiment multiple times and average, \n",
"# since results vary slightly over runs\n",
"PER_EXPERIMENT_RUN_COUNT = 5\n",
"\n",
"# Note: Please specify ~6 (or more!) similar files to process together as a document set\n",
"# This is currently a requirement for Docugami to automatically detect motifs\n",
"# across the document set to generate a semantic XML Knowledge Graph.\n",
Expand All @@ -48,7 +56,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -85,7 +93,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -102,7 +110,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -130,9 +138,29 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"{'Q1 2022 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf': '/tmp/tmp19ctz_zz',\n",
" 'Q1 2023 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf': '/tmp/tmp1fcb78wn',\n",
" 'Q2 2022 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf': '/tmp/tmpw3zclms2',\n",
" 'Q2 2023 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf': '/tmp/tmprqly0der',\n",
" 'Q3 2021 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf': '/tmp/tmp00rntcqk',\n",
" 'Q3 2022 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf': '/tmp/tmpboz8mq6c',\n",
" 'Q3 2023 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf': '/tmp/tmpseqe4ojt',\n",
" 'Q4 2020 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf': '/tmp/tmph9dhg7fi',\n",
" 'Q4 2022 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf': '/tmp/tmp9jufpgk0',\n",
" 'Q3 FY23 Microsoft Corp Earnings Call.pdf': '/tmp/tmp14olooto'}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Wait for files to finish processing (OCR, and zero-shot creation of XML knowledge graph)\n",
"\n",
Expand All @@ -142,9 +170,25 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Indexing Earnings Calls Evaluation 12-06-2023 (ID: l4ebpbn3ugk0)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Creating full document summaries in batches: 100%|██████████| 1/1 [03:10<00:00, 190.34s/it]\n",
"Creating chunk summaries in batches: 11%|█ | 4/38 [01:35<13:24, 23.65s/it]"
]
}
],
"source": [
"# Run indexing\n",
"from docugami_kg_rag.helpers.indexing import index_docset\n",
Expand Down Expand Up @@ -219,41 +263,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Upload files to OpenAI"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from openai import files\n",
"### Create OpenAI Agent\n",
"\n",
"existing_files = files.list().data\n",
"existing_file_names = [f.filename for f in existing_files]\n",
"uploaded_files = []\n",
"for file_name in FILE_NAMES:\n",
" if file_name in existing_file_names:\n",
" # file was previously uploaded\n",
" uploaded_files.append([f for f in existing_files if f.filename == file_name][0])\n",
" else:\n",
" # upload\n",
" file_path = FILES_DIR / file_name\n",
" file = files.create(file=file_path, purpose='assistants')\n",
" uploaded_files.append(file)\n",
"\n",
"file_ids=[f.id for f in uploaded_files]\n",
"\n",
"uploaded_files\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create OpenAI Assistant Agent"
"Please go to https://platform.openai.com/playground and create your agent. "
]
},
{
Expand All @@ -263,20 +275,11 @@
"outputs": [],
"source": [
"from langchain.agents.openai_assistant import OpenAIAssistantRunnable\n",
"openai_agent = OpenAIAssistantRunnable.create_assistant(\n",
" name=\"Earnings Call Assistant\",\n",
" instructions=\"An assistant that specializes in answering questions based only on the given knowledge base of earnings calls\",\n",
" tools=[{\"type\": \"retrieval\"}],\n",
" model=\"gpt-4-1106-preview\",\n",
")\n",
"openai_agent.as_agent = True\n",
"\n",
"def predict_openai_agent(input: dict) -> dict:\n",
"def predict_openai_agent(input: dict, config: dict = None) -> dict:\n",
" openai_agent = OpenAIAssistantRunnable(assistant_id=OPENAI_ASSISTANT_ID, as_agent=True).with_config(config)\n",
" question = input[\"question\"]\n",
" result = openai_agent.invoke({\n",
" \"content\": question,\n",
" \"file_ids\": file_ids,\n",
" })\n",
" result = openai_agent.invoke({\"content\": question})\n",
"\n",
" return result.return_values[\"output\"]"
]
Expand Down Expand Up @@ -307,9 +310,10 @@
"import uuid\n",
"from langsmith.client import Client\n",
"from langchain.smith import RunEvalConfig\n",
"from langchain.globals import set_llm_cache, get_llm_cache\n",
"\n",
"eval_config = RunEvalConfig(\n",
" evaluators=[\"cot_qa\"],\n",
" evaluators=[\"qa\"],\n",
")\n",
"\n",
"def run_eval(eval_func, eval_run_name):\n",
Expand All @@ -328,14 +332,24 @@
"\n",
"# Experiments\n",
"agent_map = {\n",
" # \"openai_assistant_retrieval\": predict_openai_agent,\n",
" \"docugami_kg_rag_zero_shot\": predict_docugami_agent,\n",
" \"openai_assistant_retrieval\": predict_openai_agent,\n",
"}\n",
"\n",
"for _ in range(10):\n",
" run_id = str(uuid.uuid4())\n",
" for project_name, agent in agent_map.items():\n",
" run_eval(agent, project_name + \"_\" + run_id)"
"try:\n",
" # Disable global cache setting to get fresh results every time for all experiments\n",
" # since no caching or temperature-0 is supported for the openai assistants API and\n",
" # we want to measure under similar conditions\n",
" cache = get_llm_cache()\n",
" set_llm_cache(None)\n",
"\n",
" for i in range(PER_EXPERIMENT_RUN_COUNT):\n",
" run_id = str(uuid.uuid4())\n",
" for project_name, agent in agent_map.items():\n",
" run_eval(agent, project_name + \"_\" + run_id)\n",
"finally:\n",
" # Revert cache setting to global default\n",
" set_llm_cache(cache)\n"
]
}
],
Expand Down

0 comments on commit 885692a

Please sign in to comment.