From 885692a030f27865f734c89ad4a220831e845ab0 Mon Sep 17 00:00:00 2001 From: Taqi Jaffri Date: Wed, 6 Dec 2023 14:11:38 -0800 Subject: [PATCH] Eval related tweaks --- docugami_kg_rag/config.py | 8 +- docugami_kg_rag/helpers/prompts.py | 13 ++- docugami_kg_rag/helpers/retrieval.py | 31 +++--- evals/run-evals.ipynb | 142 +++++++++++++++------------ 4 files changed, 111 insertions(+), 83 deletions(-) diff --git a/docugami_kg_rag/config.py b/docugami_kg_rag/config.py index ff6ffdd..151ccc7 100644 --- a/docugami_kg_rag/config.py +++ b/docugami_kg_rag/config.py @@ -72,11 +72,11 @@ class LocalIndexState: # Lengths for the loader are in terms of characters, 1 token ~= 4 chars in English # Reference: https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them -MAX_CHUNK_TEXT_LENGTH = 1024 * 20 # ~5k tokens -MIN_CHUNK_TEXT_LENGTH = 1024 * 8 # ~1k tokens +MAX_CHUNK_TEXT_LENGTH = 1024 * 28 # ~7k tokens +MIN_CHUNK_TEXT_LENGTH = 1024 * 1 # ~0.25k tokens SUB_CHUNK_TABLES = False INCLUDE_XML_TAGS = True -PARENT_HIERARCHY_LEVELS = 1 -RETRIEVER_K = 10 +PARENT_HIERARCHY_LEVELS = 4 +RETRIEVER_K = 6 BATCH_SIZE = 16 diff --git a/docugami_kg_rag/helpers/prompts.py b/docugami_kg_rag/helpers/prompts.py index 8ed7513..7c501df 100644 --- a/docugami_kg_rag/helpers/prompts.py +++ b/docugami_kg_rag/helpers/prompts.py @@ -9,16 +9,23 @@ All your answers must contain citations to help the user understand how you created the citation, specifically: -- If the given context contains the names of document(s), make sure you include that in your answer as - a citation, e.g. include "\\n\\nSOURCE(S): foo.pdf, bar.pdf" at the end of your answer. +- If the given context contains the names of document(s), make sure you include the document you got the + answer from as a citation, e.g. include "\\n\\nSOURCE(S): foo.pdf, bar.pdf" at the end of your answer. - If the answer was generated via a SQL Query, make sure you include the SQL query in your answer as a citation, e.g. 
include "\\n\\nSOURCE(S): SELECT AVG('square footage') from Leases". The SQL query should be - in the agent scratchpad provided. + in the agent scratchpad provided, if you are using an agent. - Make sure there an actual answer if you show a SOURCE citation, i.e. make sure you don't show only a bare citation with no actual answer. """ +HUMAN_MESSAGE_TEMPLATE = """{context} + +Using the context above, which can include text and tables, answer the following question. + +Question: {question} +""" + CREATE_DIRECT_RETRIEVAL_TOOL_DESCRIPTION_PROMPT = """Here is a snippet from a sample document of type {docset_name}: {document} diff --git a/docugami_kg_rag/helpers/retrieval.py b/docugami_kg_rag/helpers/retrieval.py index 3dfcff9..281c281 100644 --- a/docugami_kg_rag/helpers/retrieval.py +++ b/docugami_kg_rag/helpers/retrieval.py @@ -3,7 +3,7 @@ from langchain.agents.agent_toolkits import create_retriever_tool from langchain.prompts import ChatPromptTemplate -from langchain.schema import Document, StrOutputParser +from langchain.schema import BaseRetriever, Document, StrOutputParser from langchain.tools.base import BaseTool from langchain.vectorstores import Chroma @@ -25,6 +25,22 @@ ) +def get_retriever_for_docset(docset_state: LocalIndexState) -> BaseRetriever: + """ + Gets a retriever for a docset. Chunks are in the vector store, and full documents + are in the store inside the local state. + """ + chunk_vectorstore = Chroma(persist_directory=CHROMA_DIRECTORY, embedding_function=EMBEDDINGS) + + return FusedSummaryRetriever( + vectorstore=chunk_vectorstore, + parent_doc_store=docset_state.chunks_by_id, + full_doc_summary_store=docset_state.full_doc_summaries_by_id, + search_kwargs={"k": RETRIEVER_K}, + search_type=SearchType.mmr, + ) + + def docset_name_to_direct_retriever_tool_function_name(name: str) -> str: """ Converts a docset name to a direct retriever tool function name. 
@@ -75,19 +91,10 @@ def chunks_to_direct_retriever_tool_description(name: str, chunks: List[Document def get_retrieval_tool_for_docset(docset_state: LocalIndexState) -> Optional[BaseTool]: """ - Chunks are in the vector store, and full documents are in the store inside the local state + Gets a retrieval tool for an agent. """ - chunk_vectorstore = Chroma(persist_directory=CHROMA_DIRECTORY, embedding_function=EMBEDDINGS) - - retriever = FusedSummaryRetriever( - vectorstore=chunk_vectorstore, - parent_doc_store=docset_state.chunks_by_id, - full_doc_summary_store=docset_state.full_doc_summaries_by_id, - search_kwargs={"k": RETRIEVER_K}, - search_type=SearchType.mmr, - ) - + retriever = get_retriever_for_docset(docset_state=docset_state) return create_retriever_tool( retriever=retriever, name=docset_state.retrieval_tool_function_name, diff --git a/evals/run-evals.ipynb b/evals/run-evals.ipynb index 1f6fa0a..cd9fa37 100644 --- a/evals/run-evals.ipynb +++ b/evals/run-evals.ipynb @@ -11,19 +11,25 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Set up Eval Dataset" + "## Set up Eval" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "from pathlib import Path\n", "\n", - "DOCSET_NAME = \"Earnings Calls Evaluation\"\n", + "# Important: Create your OpenAI assistant via https://platform.openai.com/playground\n", + "# and put the assistant ID here. Make sure you upload the identical set of\n", + "# files listed below (these files will be uploaded automatically to Docugami)\n", + "OPENAI_ASSISTANT_ID = \"asst_g837jjwr6Ohgk2EWfQOKTcPg\"\n", + "\n", + "DOCSET_NAME = \"Earnings Calls Evaluation 12-06-2023\"\n", + "FILES_DIR = Path(os.getcwd()) / \"v1/docs\"\n", "FILE_NAMES = [\n", " \"Q1 2022 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf\",\n", " \"Q1 2023 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf\",\n", @@ -36,10 +42,12 @@ " \"Q4 2022 Snowflake Inc. 
Earnings Call - Snowflake Inc - BamSEC.pdf\",\n", " \"Q3 FY23 Microsoft Corp Earnings Call.pdf\",\n", "]\n", - "\n", - "FILES_DIR = Path(os.getcwd()) / \"v1/docs\"\n", "GROUND_TRUTH_CSV = Path(os.getcwd()) / \"v1/ground-truth-earning_calls.csv\"\n", "\n", + "# We will run each experiment multiple times and average, \n", + "# since results vary slightly over runs\n", + "PER_EXPERIMENT_RUN_COUNT = 5\n", + "\n", "# Note: Please specify ~6 (or more!) similar files to process together as a document set\n", "# This is currently a requirement for Docugami to automatically detect motifs\n", "# across the document set to generate a semantic XML Knowledge Graph.\n", @@ -48,7 +56,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -85,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -102,7 +110,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -130,9 +138,29 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'Q1 2022 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf': '/tmp/tmp19ctz_zz',\n", + " 'Q1 2023 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf': '/tmp/tmp1fcb78wn',\n", + " 'Q2 2022 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf': '/tmp/tmpw3zclms2',\n", + " 'Q2 2023 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf': '/tmp/tmprqly0der',\n", + " 'Q3 2021 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf': '/tmp/tmp00rntcqk',\n", + " 'Q3 2022 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf': '/tmp/tmpboz8mq6c',\n", + " 'Q3 2023 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf': '/tmp/tmpseqe4ojt',\n", + " 'Q4 2020 Snowflake Inc. 
Earnings Call - Snowflake Inc - BamSEC.pdf': '/tmp/tmph9dhg7fi',\n", + " 'Q4 2022 Snowflake Inc. Earnings Call - Snowflake Inc - BamSEC.pdf': '/tmp/tmp9jufpgk0',\n", + " 'Q3 FY23 Microsoft Corp Earnings Call.pdf': '/tmp/tmp14olooto'}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Wait for files to finish processing (OCR, and zero-shot creation of XML knowledge graph)\n", "\n", @@ -142,9 +170,25 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Indexing Earnings Calls Evaluation 12-06-2023 (ID: l4ebpbn3ugk0)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Creating full document summaries in batches: 100%|██████████| 1/1 [03:10<00:00, 190.34s/it]\n", + "Creating chunk summaries in batches: 11%|█ | 4/38 [01:35<13:24, 23.65s/it]" + ] + } + ], "source": [ "# Run indexing\n", "from docugami_kg_rag.helpers.indexing import index_docset\n", @@ -219,41 +263,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Upload files to OpenAI" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from openai import files\n", + "### Create OpenAI Agent\n", "\n", - "existing_files = files.list().data\n", - "existing_file_names = [f.filename for f in existing_files]\n", - "uploaded_files = []\n", - "for file_name in FILE_NAMES:\n", - " if file_name in existing_file_names:\n", - " # file was previously uploaded\n", - " uploaded_files.append([f for f in existing_files if f.filename == file_name][0])\n", - " else:\n", - " # upload\n", - " file_path = FILES_DIR / file_name\n", - " file = files.create(file=file_path, purpose='assistants')\n", - " uploaded_files.append(file)\n", - "\n", - "file_ids=[f.id for f in uploaded_files]\n", - "\n", - "uploaded_files\n", - "\n" - ] - }, - { - 
"cell_type": "markdown", - "metadata": {}, - "source": [ - "### Create OpenAI Assistant Agent" + "Please go to https://platform.openai.com/playground and create your agent. " ] }, { @@ -263,20 +275,11 @@ "outputs": [], "source": [ "from langchain.agents.openai_assistant import OpenAIAssistantRunnable\n", - "openai_agent = OpenAIAssistantRunnable.create_assistant(\n", - " name=\"Earnings Call Assistant\",\n", - " instructions=\"An assistant that specializes in answering questions based only on the given knowledge base of earnings calls\",\n", - " tools=[{\"type\": \"retrieval\"}],\n", - " model=\"gpt-4-1106-preview\",\n", - ")\n", - "openai_agent.as_agent = True\n", "\n", - "def predict_openai_agent(input: dict) -> dict:\n", + "def predict_openai_agent(input: dict, config: dict = None) -> dict:\n", + " openai_agent = OpenAIAssistantRunnable(assistant_id=OPENAI_ASSISTANT_ID, as_agent=True).with_config(config)\n", " question = input[\"question\"]\n", - " result = openai_agent.invoke({\n", - " \"content\": question,\n", - " \"file_ids\": file_ids,\n", - " })\n", + " result = openai_agent.invoke({\"content\": question})\n", "\n", " return result.return_values[\"output\"]" ] @@ -307,9 +310,10 @@ "import uuid\n", "from langsmith.client import Client\n", "from langchain.smith import RunEvalConfig\n", + "from langchain.globals import set_llm_cache, get_llm_cache\n", "\n", "eval_config = RunEvalConfig(\n", - " evaluators=[\"cot_qa\"],\n", + " evaluators=[\"qa\"],\n", ")\n", "\n", "def run_eval(eval_func, eval_run_name):\n", @@ -328,14 +332,24 @@ "\n", "# Experiments\n", "agent_map = {\n", + " # \"openai_assistant_retrieval\": predict_openai_agent,\n", " \"docugami_kg_rag_zero_shot\": predict_docugami_agent,\n", - " \"openai_assistant_retrieval\": predict_openai_agent,\n", "}\n", "\n", - "for _ in range(10):\n", - " run_id = str(uuid.uuid4())\n", - " for project_name, agent in agent_map.items():\n", - " run_eval(agent, project_name + \"_\" + run_id)" + "try:\n", + " # Disable 
global cache setting to get fresh results every time for all experiments\n", + " # since no caching or temperature-0 is supported for the openai assistants API and\n", + " # we want to measure under similar conditions\n", + " cache = get_llm_cache()\n", + " set_llm_cache(None)\n", + "\n", + " for i in range(PER_EXPERIMENT_RUN_COUNT):\n", + " run_id = str(uuid.uuid4())\n", + " for project_name, agent in agent_map.items():\n", + " run_eval(agent, project_name + \"_\" + run_id)\n", + "finally:\n", + " # Revert cache setting to global default\n", + " set_llm_cache(cache)\n" ] } ],