Skip to content

Commit

Permalink
Batch and cache updates (#22)
Browse files Browse the repository at this point in the history
* specify batch size for each config

* Switch to gptcache and turn on conversational tools by default

* lint

* lint fixes

---------

Co-authored-by: Taqi Jaffri <[email protected]>
  • Loading branch information
tjaffri and Taqi Jaffri authored May 7, 2024
1 parent b2c6bbe commit 6547623
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 28 deletions.
37 changes: 29 additions & 8 deletions docugami_kg_rag/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import hashlib
import os
from pathlib import Path

from langchain_community.cache import SQLiteCache
from gptcache import Cache
from gptcache.manager.factory import manager_factory
from gptcache.processor.pre import get_prompt
from langchain.cache import GPTCache
from langchain_core.globals import set_llm_cache

# from docugami_kg_rag.config.fireworksai_llama3 import *
# from docugami_kg_rag.config.huggingface import *
from docugami_kg_rag.config.chromadb import *
from docugami_kg_rag.config.openai import *
from docugami_kg_rag.config.fireworksai_llama3 import *

# from docugami_kg_rag.config.fireworksai_mixtral import *
# from docugami_kg_rag.config.openai import *
from docugami_kg_rag.config.huggingface import *

# from docugami_kg_rag.config.redis import *

Expand All @@ -23,11 +29,26 @@
INDEXING_LOCAL_REPORT_DBS_ROOT = os.environ.get("INDEXING_LOCAL_REPORT_DBS_ROOT", "/tmp/docugami/report_dbs")
os.makedirs(Path(INDEXING_LOCAL_REPORT_DBS_ROOT).parent, exist_ok=True)

LOCAL_LLM_CACHE_DB_FILE = os.environ.get("LOCAL_LLM_CACHE", "/tmp/docugami/.langchain.db")
os.makedirs(Path(LOCAL_LLM_CACHE_DB_FILE).parent, exist_ok=True)
set_llm_cache(SQLiteCache(database_path=LOCAL_LLM_CACHE_DB_FILE))
LOCAL_LLM_CACHE_DIR = os.environ.get("LOCAL_LLM_CACHE", "/tmp/docugami/langchain_cache")
os.makedirs(Path(LOCAL_LLM_CACHE_DIR).parent, exist_ok=True)


def get_hashed_name(name: str) -> str:
    """Return the hexadecimal SHA-256 digest of *name* (UTF-8 encoded).

    Used to derive a filesystem-safe, collision-resistant directory name
    from an arbitrary LLM identifier string.
    """
    digest = hashlib.sha256()
    digest.update(name.encode())
    return digest.hexdigest()


def init_gptcache(cache_obj: Cache, llm: str) -> None:
    """Initialize a GPTCache instance for one LLM, stored under a per-model directory.

    The directory name is the SHA-256 hash of the LLM identifier, so each
    model gets its own isolated on-disk "map" cache under LOCAL_LLM_CACHE_DIR.
    """
    cache_dir = Path(LOCAL_LLM_CACHE_DIR) / get_hashed_name(llm)
    data_manager = manager_factory(manager="map", data_dir=str(cache_dir.absolute()))
    cache_obj.init(pre_embedding_func=get_prompt, data_manager=data_manager)


set_llm_cache(GPTCache(init_gptcache))

EXAMPLES_PATH = Path(__file__).parent.parent / "green_examples"

DEFAULT_USE_REPORTS = True
DEFAULT_USE_CONVERSATIONAL_TOOLS = False
DEFAULT_USE_CONVERSATIONAL_TOOLS = True
5 changes: 3 additions & 2 deletions docugami_kg_rag/config/fireworksai_llama3.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
)
SMALL_CONTEXT_INSTRUCT_LLM = LARGE_CONTEXT_INSTRUCT_LLM # Use the same model for large and small context tasks
SQL_GEN_LLM = LARGE_CONTEXT_INSTRUCT_LLM # Use the same model for sql gen
LLM_BATCH_SIZE = 2

# Lengths for the Docugami loader are in terms of characters, 1 token ~= 4 chars in English
# Reference: https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
MIN_LENGTH_TO_SUMMARIZE = 2048 # chunks and docs below this length are embedded as-is
MAX_FULL_DOCUMENT_TEXT_LENGTH = int(1024 * 4 * 4.5) # ~4.5k tokens
MAX_CHUNK_TEXT_LENGTH = int(1024 * 4 * 4.5) # ~4.5k tokens
MIN_CHUNK_TEXT_LENGTH = int(1024 * 4 * 1.5) # ~1.5k tokens
MAX_CHUNK_TEXT_LENGTH = int(1024 * 4 * 1) # ~1k tokens
MIN_CHUNK_TEXT_LENGTH = int(1024 * 4 * 0.5) # ~0.5k tokens
SUB_CHUNK_TABLES = False
INCLUDE_XML_TAGS = False
PARENT_HIERARCHY_LEVELS = 2
Expand Down
1 change: 1 addition & 0 deletions docugami_kg_rag/config/fireworksai_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
SMALL_CONTEXT_INSTRUCT_LLM = LARGE_CONTEXT_INSTRUCT_LLM # Use the same model for large and small context tasks
SQL_GEN_LLM = LARGE_CONTEXT_INSTRUCT_LLM # Use the same model for sql gen
LLM_BATCH_SIZE = 32

# Lengths for the Docugami loader are in terms of characters, 1 token ~= 4 chars in English
# Reference: https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
Expand Down
2 changes: 2 additions & 0 deletions docugami_kg_rag/config/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
max_tokens=2 * 1024, # only output tokens
)
SQL_GEN_LLM = SMALL_CONTEXT_INSTRUCT_LLM # Use the same model for sql gen
LLM_BATCH_SIZE = 256

EMBEDDINGS = OpenAIEmbeddings(model="text-embedding-ada-002")

Expand All @@ -36,3 +37,4 @@
INCLUDE_XML_TAGS = True
PARENT_HIERARCHY_LEVELS = 2
RETRIEVER_K = 8

3 changes: 3 additions & 0 deletions docugami_kg_rag/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
INDEXING_LOCAL_REPORT_DBS_ROOT,
INDEXING_LOCAL_STATE_PATH,
LARGE_CONTEXT_INSTRUCT_LLM,
LLM_BATCH_SIZE,
MAX_CHUNK_TEXT_LENGTH,
MAX_FULL_DOCUMENT_TEXT_LENGTH,
MIN_CHUNK_TEXT_LENGTH,
Expand Down Expand Up @@ -143,6 +144,7 @@ def index_docset(docset_id: str, name: str) -> None:
max_length_cutoff=MAX_FULL_DOCUMENT_TEXT_LENGTH,
include_xml_tags=INCLUDE_XML_TAGS,
summarize_document_examples_file=EXAMPLES_PATH / "summarize_document_examples.yaml",
batch_size=LLM_BATCH_SIZE,
)
chunk_summaries_by_id = build_chunk_summary_mappings(
docs_by_id=parent_chunks_by_id,
Expand All @@ -152,6 +154,7 @@ def index_docset(docset_id: str, name: str) -> None:
max_length_cutoff=MAX_FULL_DOCUMENT_TEXT_LENGTH,
include_xml_tags=INCLUDE_XML_TAGS,
summarize_chunk_examples_file=EXAMPLES_PATH / "summarize_chunk_examples.yaml",
batch_size=LLM_BATCH_SIZE,
)

direct_tool_function_name = docset_name_to_direct_retrieval_tool_function_name(name)
Expand Down
49 changes: 33 additions & 16 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ readme = "README.md"
python = ">=3.9,<4.0"
langchain = ">=0.1.17"
langchain-community = ">=0.0.36"
docugami-langchain = "^0.0.9rc2"
docugami-langchain = "^0.0.9rc4"
# docugami-langchain = {git = "https://github.com/docugami/docugami-langchain.git"}
# docugami-langchain = {git = "https://github.com/docugami/docugami-langchain.git", rev = "tjaffri/rel"}
typer = ">=0.9.0"
Expand All @@ -21,6 +21,7 @@ langsmith = ">=0.1.52"
openpyxl = ">=3.1.2"
faiss-cpu = ">=1.8.0"
torch = "2.2.0"
gptcache = "^0.1.43"
rerankers = { extras = ["all"], version = ">=0.2.0" }
langchain-openai = { version = ">=0.1.6", optional = true }
langchain-fireworks = { version = ">=0.1.2", optional = true }
Expand Down Expand Up @@ -65,7 +66,7 @@ exclude = "/temp/"

[[tool.mypy.overrides]]
module = [
"torch", "langchain_openai.*", "langchain_fireworks.*", "docugami_langchain.*", "rerankers.*",
"torch", "langchain_openai.*", "langchain_fireworks.*", "docugami_langchain.*", "rerankers.*", "gptcache.*"
]
ignore_missing_imports = true

Expand Down

0 comments on commit 6547623

Please sign in to comment.