Skip to content

Commit

Permalink
notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
Taqi Jaffri committed Mar 28, 2024
1 parent 578b800 commit fcb108c
Show file tree
Hide file tree
Showing 5 changed files with 344 additions and 194 deletions.
16 changes: 9 additions & 7 deletions docugami_kg_rag/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,15 @@ class AgentInput(BaseModel):
)


def agent_output_to_string(streaming_state: AgentState) -> str:
if streaming_state:
react_output = streaming_state.get("generate_re_act")
if react_output:
cited_answer = react_output.get("cited_answer")
if cited_answer and cited_answer.is_final:
return cited_answer.answer
def agent_output_to_string(state: AgentState) -> str:
    """Return the final answer text held in an agent state, or ``""``.

    If ``state`` has a truthy ``"generate_re_act"`` entry, that nested mapping
    is inspected instead of ``state`` itself (presumably the shape produced
    when the agent is streamed -- confirm against the caller).

    Args:
        state: Mapping that may contain a ``"cited_answer"`` entry, either
            directly or nested under ``"generate_re_act"``.

    Returns:
        ``cited_answer.answer`` when a truthy ``"cited_answer"`` is present and
        its ``is_final`` attribute is truthy; otherwise an empty string.
    """
    if not state:
        return ""

    # Prefer the nested mapping when present; otherwise read the state directly.
    payload = state.get("generate_re_act") or state

    candidate = payload.get("cited_answer")
    if not candidate or not candidate.is_final:
        # No answer yet, or the answer is still intermediate.
        return ""

    return candidate.answer

Expand Down
83 changes: 58 additions & 25 deletions evals/sec-10-q.ipynb → notebooks/eval-sec-10-q.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -43,7 +43,7 @@
"\n",
"DOCSET_NAME = \"SEC 10Q Filings\"\n",
"EVAL_NAME = DOCSET_NAME + \" \" + datetime.now().strftime(\"%Y-%m-%d\")\n",
"FILES_DIR = Path(os.getcwd()) / \"temp/sec-10-q/docs\"\n",
"FILES_DIR = Path(os.getcwd()) / \"temp/sec-10-q/data/v1/docs\"\n",
"FILE_NAMES = [\n",
" \"2022 Q3 AAPL.pdf\",\n",
" \"2022 Q3 AMZN.pdf\",\n",
Expand Down Expand Up @@ -82,7 +82,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -126,7 +126,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -154,9 +154,39 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"{'2022 Q3 AAPL.pdf': '/tmp/tmpwj7qvnpc',\n",
" '2022 Q3 AMZN.pdf': '/tmp/tmp_b1v53gw',\n",
" '2022 Q3 INTC.pdf': '/tmp/tmp7rbwg3oo',\n",
" '2022 Q3 MSFT.pdf': '/tmp/tmplb2qcrzz',\n",
" '2022 Q3 NVDA.pdf': '/tmp/tmpudk140xq',\n",
" '2023 Q1 AAPL.pdf': '/tmp/tmpy49blnkv',\n",
" '2023 Q1 AMZN.pdf': '/tmp/tmp9z7c6swf',\n",
" '2023 Q1 INTC.pdf': '/tmp/tmpfi4rli61',\n",
" '2023 Q1 MSFT.pdf': '/tmp/tmprbk948a0',\n",
" '2023 Q1 NVDA.pdf': '/tmp/tmp779afiom',\n",
" '2023 Q2 AAPL.pdf': '/tmp/tmpn22kjw46',\n",
" '2023 Q2 AMZN.pdf': '/tmp/tmp3fadq9kp',\n",
" '2023 Q2 INTC.pdf': '/tmp/tmpe0gim1ke',\n",
" '2023 Q2 MSFT.pdf': '/tmp/tmpb5mb3x0a',\n",
" '2023 Q2 NVDA.pdf': '/tmp/tmpgc20mcnv',\n",
" '2023 Q3 AAPL.pdf': '/tmp/tmprnrsxhhs',\n",
" '2023 Q3 AMZN.pdf': '/tmp/tmpuatxzleg',\n",
" '2023 Q3 INTC.pdf': '/tmp/tmpgnyhizig',\n",
" '2023 Q3 MSFT.pdf': '/tmp/tmp5dg44pcy',\n",
" '2023 Q3 NVDA.pdf': '/tmp/tmp9dogliao'}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Wait for files to finish processing (OCR, and zero-shot creation of XML knowledge graph)\n",
"\n",
Expand All @@ -166,9 +196,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Torch not installed...\n",
"Loading default rankgpt3 model for language en\n",
"Loading RankGPTRanker model gpt-3.5-turbo\n",
"Indexing SEC 10Q Filings (ID: s85dxu9aie2h)\n"
]
}
],
"source": [
"# Run indexing\n",
"from docugami_kg_rag.indexing import index_docset\n",
Expand All @@ -193,23 +234,15 @@
"metadata": {},
"outputs": [],
"source": [
"from docugami_kg_rag.agent import agent as docugami_agent, _get_tools, AgentInput\n",
"\n",
"def predict_docugami_agent(input: dict) -> dict:\n",
" question = input[\"question\"]\n",
" chain = AgentExecutor(\n",
" agent=docugami_agent,\n",
" tools=_get_tools(),\n",
" ).with_types(\n",
" input_type=AgentInput,\n",
" )\n",
" result = chain.invoke({\n",
" \"input\": question,\n",
" \"use_reports\": False,\n",
" \"chat_history\": [],\n",
" })\n",
"from docugami_kg_rag.agent import agent as docugami_agent\n",
"from langchain_core.messages import HumanMessage\n",
"\n",
" return result[\"output\"]"
"def predict_docugami_agent(question: str) -> str:\n",
" return docugami_agent.invoke(\n",
" {\n",
" \"messages\": [HumanMessage(content=question)],\n",
" }\n",
" )"
]
},
{
Expand All @@ -219,7 +252,7 @@
"outputs": [],
"source": [
"# Test the agent to make sure it is working\n",
"predict_docugami_agent({\"question\": \"How much did Microsoft spend for opex in the latest quarter?\"})"
"predict_docugami_agent(\"How much did Microsoft spend for opex in the latest quarter?\")"
]
},
{
Expand Down
46 changes: 27 additions & 19 deletions evals/eval-csv.ipynb → notebooks/run-csv.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,28 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!rm -rf temp\n",
"!git clone https://github.com/docugami/KG-RAG-datasets.git temp"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from datetime import datetime\n",
"import os\n",
"import pandas as pd\n",
"from pathlib import Path\n",
"\n",
"INPUT_CSV_PATH = Path(os.getcwd()) / \"temp/questions.csv\"\n",
"OUTPUT_CSV_PATH = INPUT_CSV_PATH.with_name(INPUT_CSV_PATH.stem + '_answers' + INPUT_CSV_PATH.suffix)"
"INPUT_CSV_PATH = Path(os.getcwd()) / \"temp/sec-10-q/data/raw_questions/questions_mini.csv\"\n",
"EVAL_NAME = INPUT_CSV_PATH.stem + \"_\" + datetime.now().strftime(\"%Y-%m-%d\")\n",
"OUTPUT_CSV_PATH = INPUT_CSV_PATH.with_name(EVAL_NAME + \"_answers\" + INPUT_CSV_PATH.suffix)"
]
},
{
Expand All @@ -32,52 +44,48 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from docugami_langchain.agents import ReActAgent\n",
"from docugami_kg_rag.agent import agent as docugami_agent\n",
"from langchain_core.messages import HumanMessage\n",
"\n",
"\n",
"def predict_docugami_agent(question: str) -> str:\n",
" result = docugami_agent.invoke(\n",
" return docugami_agent.invoke(\n",
" {\n",
" \"question\": question,\n",
" \"chat_history\": [],\n",
" \"agent_outcome\": None,\n",
" \"intermediate_steps\": [],\n",
" \"messages\": [HumanMessage(content=question)],\n",
" }\n",
" )\n",
"\n",
" return ReActAgent.to_human_readable(result)"
" )"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Output CSV created at: /root/Source/github/langchain-template-docugami-kg-rag/evals/temp/questions_answers.csv\n"
"Output CSV created at: /root/Source/github/langchain-template-docugami-kg-rag/notebooks/temp/sec-10-q/data/raw_questions/questions_mini_2024-03-28_answers.csv\n"
]
}
],
"source": [
"os.environ['LANGCHAIN_PROJECT'] = 'adp_eval_3_11'\n",
"os.environ['LANGCHAIN_PROJECT'] = EVAL_NAME\n",
"\n",
"# Eval the CSV\n",
"df = pd.read_csv(INPUT_CSV_PATH)\n",
"if \"question\" in df.columns:\n",
"if \"Question\" in df.columns:\n",
" # Apply the predict function to each question and create a new column for the answers\n",
" df[\"answer\"] = df[\"question\"].apply(predict_docugami_agent)\n",
" df[\"Answer\"] = df[\"Question\"].apply(predict_docugami_agent)\n",
" # Write the dataframe with questions and answers to the output CSV\n",
" df.to_csv(OUTPUT_CSV_PATH, index=False)\n",
" print(f\"Output CSV created at: {OUTPUT_CSV_PATH}\")\n",
"else:\n",
" print(\"Error: The 'question' column does not exist in the input CSV.\")"
" print(\"Error: The 'Question' column does not exist in the input CSV.\")"
]
}
],
Expand Down
Loading

0 comments on commit fcb108c

Please sign in to comment.