Skip to content

Commit

Permalink
Add redis support, and other updates (#15)
Browse files Browse the repository at this point in the history
* Refactor config and prompts

* Simplify agent executor code

* refactored vector store abstraction

* Added redis support

* default back to chroma

---------

Co-authored-by: Taqi Jaffri <[email protected]>
  • Loading branch information
tjaffri and Taqi Jaffri authored Jan 23, 2024
1 parent a7405bf commit e9f032c
Show file tree
Hide file tree
Showing 12 changed files with 1,232 additions and 366 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
format:
poetry run ruff check . --fix
poetry run black .

lint:
Expand Down
21 changes: 18 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ This template contains a reference architecture for Retrieval Augmented Generati
You need to set some required environment variables before using your new app based on this template. These are used to index as well as run the application, and exceptions are raised if the following required environment variables are not set:

1. `OPENAI_API_KEY`: from the OpenAI platform.
2. `DOCUGAMI_API_KEY`: from the [Docugami Developer Playground](https://help.docugami.com/home/docugami-api)
1. `DOCUGAMI_API_KEY`: from the [Docugami Developer Playground](https://help.docugami.com/home/docugami-api)

```shell
export OPENAI_API_KEY=...
Expand All @@ -28,8 +28,8 @@ Finally, make sure that you run `poetry install --all-extras` (or select a speci
Before you use this template, you must have some documents already processed in Docugami. Here's what you need to get started:

1. Create a [Docugami workspace](https://app.docugami.com/) (free trials available)
2. Create an access token via the Developer Playground for your workspace. [Detailed instructions](https://help.docugami.com/home/docugami-api).
3. Add your documents to Docugami for processing. There are two ways to do this:
1. Create an access token via the Developer Playground for your workspace. [Detailed instructions](https://help.docugami.com/home/docugami-api).
1. Add your documents to Docugami for processing. There are two ways to do this:
- Upload via the simple Docugami web experience. [Detailed instructions](https://help.docugami.com/home/adding-documents).
- Upload via the Docugami API, specifically the [documents](https://api-docs.docugami.com/#tag/documents/operation/upload-document) endpoint. Code samples are available for python and JavaScript or you can use the [docugami](https://pypi.org/project/docugami/) python library.

Expand Down Expand Up @@ -105,3 +105,18 @@ from langserve.client import RemoteRunnable

runnable = RemoteRunnable("http://localhost:8000/docugami-kg-rag")
```

# Advanced Configuration

## Using Local GPU
Optionally, if using local embeddings or llms in `config.py`, make sure your local CUDA runtime is updated. You can run `torch.cuda.is_available()` in a python REPL to make sure, and if you need to install a specific version for your local CUDA driver you can run something like `poetry run pip3 install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117` to update it.

## Using Redis

Under `config.py` you can configure the vector store to use Redis. See documentation here:

One of the things you need to specify is the REDIS_URL. You may have an instance already running that you can point to, or for development you may want to deploy Redis locally:

`docker run -d -p 6379:6379 -p 8001:8001 redis/redis-stack:latest`

See documentation [here](https://python.langchain.com/docs/integrations/vectorstores/redis#redis-connection-url-examples) for how to configure the REDIS_URL
58 changes: 10 additions & 48 deletions docugami_kg_rag/chain.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,23 @@
import sys
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Tuple

from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_to_openai_functions
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.prompts import (
ChatPromptTemplate,
MessagesPlaceholder,
)
from langchain.pydantic_v1 import BaseModel, Field
from langchain.schema.messages import AIMessage, HumanMessage
from langchain.schema.runnable import Runnable, RunnableLambda, RunnableParallel
from langchain.schema.runnable import Runnable, RunnableLambda
from langchain.tools.base import BaseTool
from langchain.tools.render import format_tool_to_openai_function

from docugami_kg_rag.config import LARGE_CONTEXT_LLM
from docugami_kg_rag.helpers.state import read_all_local_index_state
from docugami_kg_rag.config import LARGE_CONTEXT_LLM, USE_REPORTS
from docugami_kg_rag.helpers.indexing import read_all_local_index_state
from docugami_kg_rag.helpers.prompts import ASSISTANT_SYSTEM_MESSAGE
from docugami_kg_rag.helpers.reports import get_retrieval_tool_for_report
from docugami_kg_rag.helpers.retrieval import get_retrieval_tool_for_docset


def _get_tools(use_reports=True) -> List[BaseTool]:
def _get_tools(use_reports=USE_REPORTS) -> List[BaseTool]:
"""
Build retrieval tools.
"""
Expand All @@ -32,7 +28,7 @@ def _get_tools(use_reports=True) -> List[BaseTool]:
report_retrieval_tools: List[BaseTool] = []
for docset_id in local_state:
docset_state = local_state[docset_id]
direct_retrieval_tool = get_retrieval_tool_for_docset(docset_state)
direct_retrieval_tool = get_retrieval_tool_for_docset(docset_id, docset_state)
if direct_retrieval_tool:
# Direct retrieval tool for each indexed docset (direct KG-RAG against semantic XML)
docset_retrieval_tools.append(direct_retrieval_tool)
Expand Down Expand Up @@ -73,50 +69,16 @@ def _format_chat_history(chat_history: List[Tuple[str, str]]):
)


agent = (
RunnableParallel(
{
"input": lambda x: x["input"], # type: ignore
"chat_history": lambda x: _format_chat_history(x["chat_history"]), # type: ignore
"agent_scratchpad": lambda x: format_to_openai_functions(x["intermediate_steps"]), # type: ignore
"functions": lambda x: [format_tool_to_openai_function(tool) for tool in _get_tools("use_reports" in x and x["use_reports"] is True)], # type: ignore
}
)
| {
"input": prompt,
"functions": lambda x: x["functions"],
}
| _llm_with_tools
| OpenAIFunctionsAgentOutputParser()
)


class AgentInput(BaseModel):
input: str = ""
use_reports: bool = Field(
default=False,
extra={"widget": {"type": "bool", "input": "input", "output": "output"}},
)
chat_history: Optional[List[Tuple[str, str]]] = Field(
..., extra={"widget": {"type": "chat", "input": "input", "output": "output"}}
)


chain = AgentExecutor(
agent=agent, # type: ignore
tools=_get_tools(),
).with_types(
input_type=AgentInput, # type: ignore
)
agent = create_openai_tools_agent(LARGE_CONTEXT_LLM, _get_tools(), prompt)
chain = AgentExecutor(agent=agent, tools=_get_tools()) # type: ignore

if __name__ == "__main__":
if sys.gettrace():
# This code will only run if a debugger is attached

chain.invoke(
{
"input": "What was the question from Barclays in the Q2 2023 earnings call?",
"use_reports": False,
"input": "What happened in Yelm, Washington?",
"chat_history": [],
}
)
129 changes: 111 additions & 18 deletions docugami_kg_rag/config.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,11 @@
import os
from pathlib import Path

from langchain_core.vectorstores import VectorStore

from langchain.cache import SQLiteCache
from langchain.globals import set_llm_cache


##### <LLMs and Embeddings>
# OpenAI models and Embeddings
# Reference: https://platform.openai.com/docs/models
from langchain_community.chat_models.openai import ChatOpenAI
from langchain_community.embeddings.openai import OpenAIEmbeddings

LARGE_CONTEXT_LLM = ChatOpenAI(temperature=0, model="gpt-4-1106-preview") # 128k tokens
SMALL_CONTEXT_LLM = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-1106") # 16k tokens
EMBEDDINGS = OpenAIEmbeddings(model="text-embedding-ada-002")
##### </LLMs and Embeddings>

##### <Vector Store>
CHROMA_DIRECTORY = "/tmp/docugami/chroma_db"
os.makedirs(Path(CHROMA_DIRECTORY).parent, exist_ok=True)
##### </Vector Store>

DOCUGAMI_API_KEY = os.environ.get("DOCUGAMI_API_KEY")
if not DOCUGAMI_API_KEY:
raise Exception("Please set the DOCUGAMI_API_KEY environment variable")
Expand All @@ -36,13 +21,121 @@
os.makedirs(Path(LOCAL_LLM_CACHE_DB_FILE).parent, exist_ok=True)
set_llm_cache(SQLiteCache(database_path=LOCAL_LLM_CACHE_DB_FILE))

USE_REPORTS = True

##### <LLMs and Embeddings>
# OpenAI models and Embeddings
# Reference: https://platform.openai.com/docs/models
from langchain_openai import ChatOpenAI
from langchain_openai import OpenAIEmbeddings

if "OPENAI_API_KEY" not in os.environ:
raise Exception("OPENAI_API_KEY environment variable not set")

LARGE_CONTEXT_LLM = ChatOpenAI(temperature=0, model="gpt-4-1106-preview", cache=True) # 128k tokens
SMALL_CONTEXT_LLM = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-1106", cache=True) # 16k tokens
EMBEDDINGS = OpenAIEmbeddings(model="text-embedding-ada-002")

# Lengths for the Docugami loader are in terms of characters, 1 token ~= 4 chars in English
# Reference: https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
MIN_LENGTH_TO_SUMMARIZE = 2048 # chunks and docs below this length are embedded as-is
MAX_FULL_DOCUMENT_TEXT_LENGTH = 1024 * 56 # ~14k tokens
MAX_CHUNK_TEXT_LENGTH = 1024 * 26 # ~6.5k tokens
MIN_CHUNK_TEXT_LENGTH = 1024 * 6 # ~1.5k tokens
SUB_CHUNK_TABLES = False
INCLUDE_XML_TAGS = True
INCLUDE_XML_TAGS = False
PARENT_HIERARCHY_LEVELS = 2
RETRIEVER_K = 8

BATCH_SIZE = 16

# FireworksAI models and local embeddings
# Reference: https://fireworks.ai/models
# Reference: https://huggingface.co/models
# import torch
# from langchain.chat_models.fireworks import ChatFireworks
# from langchain_community.embeddings import HuggingFaceEmbeddings

# if "FIREWORKS_API_KEY" not in os.environ:
# raise Exception("FIREWORKS_API_KEY environment variable not set")
# LARGE_CONTEXT_LLM = ChatFireworks(
# model="accounts/fireworks/models/mixtral-8x7b-instruct",
# model_kwargs={"temperature": 0, "max_tokens": 1024},
# cache=True,
# ) # 128k tokens
# SMALL_CONTEXT_LLM = LARGE_CONTEXT_LLM # Use the same model for large and small context tasks
# device = "cpu"
# if torch.cuda.is_available():
# device = "cuda"

# EMBEDDINGS = HuggingFaceEmbeddings(
# model_name="sentence-transformers/all-mpnet-base-v2",
# model_kwargs={"device": device},
# )

# # Lengths for the Docugami loader are in terms of characters, 1 token ~= 4 chars in English
# # Reference: https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
# MIN_LENGTH_TO_SUMMARIZE = 2048 # chunks and docs below this length are embedded as-is
# MAX_FULL_DOCUMENT_TEXT_LENGTH = 1024 * 56 # ~14k tokens
# MAX_CHUNK_TEXT_LENGTH = 1024 * 26 # ~6.5k tokens
# MIN_CHUNK_TEXT_LENGTH = 1024 * 6 # ~1.5k tokens
# SUB_CHUNK_TABLES = False
# INCLUDE_XML_TAGS = False
# PARENT_HIERARCHY_LEVELS = 2
# RETRIEVER_K = 8

# BATCH_SIZE = 16
##### </LLMs and Embeddings>

##### <Vector Store>
# ChromaDB
# Reference: https://python.langchain.com/docs/integrations/vectorstores/chroma
from langchain_community.vectorstores.chroma import Chroma
import chromadb

CHROMA_DIRECTORY = Path("/tmp/docugami/chroma_db")


def get_vector_store_index(docset_id: str) -> VectorStore:
return Chroma(
collection_name=docset_id,
persist_directory=str(CHROMA_DIRECTORY.absolute()),
embedding_function=EMBEDDINGS,
)


def vector_store_index_exists(docset_id: str) -> bool:
persistent_client = chromadb.PersistentClient(path=str(CHROMA_DIRECTORY.absolute()))
collections = persistent_client.list_collections()
for c in collections:
if c.name == docset_id:
return True

return False


def del_vector_store_index(docset_id: str):
persistent_client = chromadb.PersistentClient(path=str(CHROMA_DIRECTORY.absolute()))
persistent_client.delete_collection(docset_id)


# Redis
# Reference: https://python.langchain.com/docs/integrations/vectorstores/redis
# from langchain_community.vectorstores.redis.base import Redis, check_index_exists

# REDIS_URL = "redis://localhost:6379"


# def get_vector_store_index(docset_id: str) -> VectorStore:
# return Redis(redis_url=REDIS_URL, index_name=docset_id, embedding=EMBEDDINGS)


# def vector_store_index_exists(docset_id: str) -> bool:
# index = get_vector_store_index(docset_id)
# check_index_exists(index, docset_id)

# def del_vector_store_index(docset_id: str):
# Redis.drop_index(docset_id, True, redis_url=REDIS_URL)


##### </Vector Store>
19 changes: 14 additions & 5 deletions docugami_kg_rag/helpers/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,26 @@
BATCH_SIZE,
INCLUDE_XML_TAGS,
MAX_CHUNK_TEXT_LENGTH,
MAX_FULL_DOCUMENT_TEXT_LENGTH,
MIN_LENGTH_TO_SUMMARIZE,
SMALL_CONTEXT_LLM,
LARGE_CONTEXT_LLM,
)
from docugami_kg_rag.helpers.prompts import (
ASSISTANT_SYSTEM_MESSAGE,
CREATE_CHUNK_SUMMARY_SYSTEM_MESSAGE,
CREATE_FULL_DOCUMENT_SUMMARY_PROMPT,
CREATE_CHUNK_SUMMARY_PROMPT,
CREATE_FULL_DOCUMENT_SUMMARY_SYSTEM_MESSAGE,
)


def _build_summary_mappings(
docs_by_id: Dict[str, Document],
system_message: str,
prompt_template: str,
llm: BaseChatModel = SMALL_CONTEXT_LLM,
summarize_length_threshold=2048,
min_length_to_summarize=MIN_LENGTH_TO_SUMMARIZE,
max_length_cutoff=MAX_CHUNK_TEXT_LENGTH,
label="summaries",
) -> Dict[str, Document]:
"""
Expand All @@ -47,15 +52,15 @@ def _build_summary_mappings(
batch_input = [
{
"format": format,
"document": doc.page_content[:MAX_CHUNK_TEXT_LENGTH],
"document": doc.page_content[:max_length_cutoff],
}
for _, doc in batch
]

summarize_chain = (
ChatPromptTemplate.from_messages(
[
("system", ASSISTANT_SYSTEM_MESSAGE),
("system", system_message),
("human", prompt_template),
]
)
Expand All @@ -66,7 +71,7 @@ def _build_summary_mappings(

# Build meta chain that only summarizes inputs larger than threshold
chain = RunnableBranch(
(lambda x: len(x["document"]) > summarize_length_threshold, summarize_chain), # type: ignore
(lambda x: len(x["document"]) > min_length_to_summarize, summarize_chain), # type: ignore
noop_chain,
)

Expand Down Expand Up @@ -94,8 +99,10 @@ def build_full_doc_summary_mappings(docs_by_id: Dict[str, Document]) -> Dict[str

return _build_summary_mappings(
docs_by_id=docs_by_id,
system_message=CREATE_FULL_DOCUMENT_SUMMARY_SYSTEM_MESSAGE,
prompt_template=CREATE_FULL_DOCUMENT_SUMMARY_PROMPT,
llm=LARGE_CONTEXT_LLM,
max_length_cutoff=MAX_FULL_DOCUMENT_TEXT_LENGTH,
label="full document summaries",
)

Expand All @@ -107,7 +114,9 @@ def build_chunk_summary_mappings(docs_by_id: Dict[str, Document]) -> Dict[str, D

return _build_summary_mappings(
docs_by_id=docs_by_id,
system_message=CREATE_CHUNK_SUMMARY_SYSTEM_MESSAGE,
prompt_template=CREATE_CHUNK_SUMMARY_PROMPT,
llm=SMALL_CONTEXT_LLM,
max_length_cutoff=MAX_CHUNK_TEXT_LENGTH,
label="chunk summaries",
)
Loading

0 comments on commit e9f032c

Please sign in to comment.