diff --git a/docugami_kg_rag/config/__init__.py b/docugami_kg_rag/config/__init__.py
index f69d492..5cbb1a8 100644
--- a/docugami_kg_rag/config/__init__.py
+++ b/docugami_kg_rag/config/__init__.py
@@ -1,13 +1,19 @@
+import hashlib
 import os
 from pathlib import Path
 
-from langchain_community.cache import SQLiteCache
+from gptcache import Cache
+from gptcache.manager.factory import manager_factory
+from gptcache.processor.pre import get_prompt
+from langchain.cache import GPTCache
 from langchain_core.globals import set_llm_cache
 
-# from docugami_kg_rag.config.fireworksai_llama3 import *
-# from docugami_kg_rag.config.huggingface import *
 from docugami_kg_rag.config.chromadb import *
-from docugami_kg_rag.config.openai import *
+from docugami_kg_rag.config.fireworksai_llama3 import *
+
+# from docugami_kg_rag.config.fireworksai_mixtral import *
+# from docugami_kg_rag.config.openai import *
+from docugami_kg_rag.config.huggingface import *
 
 # from docugami_kg_rag.config.redis import *
 
@@ -23,11 +29,26 @@
 INDEXING_LOCAL_REPORT_DBS_ROOT = os.environ.get("INDEXING_LOCAL_REPORT_DBS_ROOT", "/tmp/docugami/report_dbs")
 os.makedirs(Path(INDEXING_LOCAL_REPORT_DBS_ROOT).parent, exist_ok=True)
 
-LOCAL_LLM_CACHE_DB_FILE = os.environ.get("LOCAL_LLM_CACHE", "/tmp/docugami/.langchain.db")
-os.makedirs(Path(LOCAL_LLM_CACHE_DB_FILE).parent, exist_ok=True)
-set_llm_cache(SQLiteCache(database_path=LOCAL_LLM_CACHE_DB_FILE))
+LOCAL_LLM_CACHE_DIR = os.environ.get("LOCAL_LLM_CACHE", "/tmp/docugami/langchain_cache")
+os.makedirs(Path(LOCAL_LLM_CACHE_DIR).parent, exist_ok=True)
+
+
+def get_hashed_name(name: str) -> str:
+    return hashlib.sha256(name.encode()).hexdigest()
+
+
+def init_gptcache(cache_obj: Cache, llm: str) -> None:
+    hashed_llm = get_hashed_name(llm)
+    hashed_llm_dir = Path(LOCAL_LLM_CACHE_DIR) / hashed_llm
+    cache_obj.init(
+        pre_embedding_func=get_prompt,
+        data_manager=manager_factory(manager="map", data_dir=str(hashed_llm_dir.absolute())),
+    )
+
+
+set_llm_cache(GPTCache(init_gptcache))
 
 EXAMPLES_PATH = Path(__file__).parent.parent / "green_examples"
 
 DEFAULT_USE_REPORTS = True
-DEFAULT_USE_CONVERSATIONAL_TOOLS = False
+DEFAULT_USE_CONVERSATIONAL_TOOLS = True
diff --git a/docugami_kg_rag/config/fireworksai_llama3.py b/docugami_kg_rag/config/fireworksai_llama3.py
index d224e1e..196ef87 100644
--- a/docugami_kg_rag/config/fireworksai_llama3.py
+++ b/docugami_kg_rag/config/fireworksai_llama3.py
@@ -16,13 +16,14 @@
 )
 SMALL_CONTEXT_INSTRUCT_LLM = LARGE_CONTEXT_INSTRUCT_LLM  # Use the same model for large and small context tasks
 SQL_GEN_LLM = LARGE_CONTEXT_INSTRUCT_LLM  # Use the same model for sql gen
+LLM_BATCH_SIZE = 2
 
 # Lengths for the Docugami loader are in terms of characters, 1 token ~= 4 chars in English
 # Reference: https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
 MIN_LENGTH_TO_SUMMARIZE = 2048  # chunks and docs below this length are embedded as-is
 MAX_FULL_DOCUMENT_TEXT_LENGTH = int(1024 * 4 * 4.5)  # ~4.5k tokens
-MAX_CHUNK_TEXT_LENGTH = int(1024 * 4 * 4.5)  # ~4.5k tokens
-MIN_CHUNK_TEXT_LENGTH = int(1024 * 4 * 1.5)  # ~1.5k tokens
+MAX_CHUNK_TEXT_LENGTH = int(1024 * 4 * 1)  # ~1k tokens
+MIN_CHUNK_TEXT_LENGTH = int(1024 * 4 * 0.5)  # ~0.5k tokens
 SUB_CHUNK_TABLES = False
 INCLUDE_XML_TAGS = False
 PARENT_HIERARCHY_LEVELS = 2
diff --git a/docugami_kg_rag/config/fireworksai_mixtral.py b/docugami_kg_rag/config/fireworksai_mixtral.py
index 42ebe9e..fbe9029 100644
--- a/docugami_kg_rag/config/fireworksai_mixtral.py
+++ b/docugami_kg_rag/config/fireworksai_mixtral.py
@@ -16,6 +16,7 @@
 )
 SMALL_CONTEXT_INSTRUCT_LLM = LARGE_CONTEXT_INSTRUCT_LLM  # Use the same model for large and small context tasks
 SQL_GEN_LLM = LARGE_CONTEXT_INSTRUCT_LLM  # Use the same model for sql gen
+LLM_BATCH_SIZE = 32
 
 # Lengths for the Docugami loader are in terms of characters, 1 token ~= 4 chars in English
 # Reference: https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
diff --git a/docugami_kg_rag/config/openai.py b/docugami_kg_rag/config/openai.py
index 10718dd..73a801b 100644
--- a/docugami_kg_rag/config/openai.py
+++ b/docugami_kg_rag/config/openai.py
@@ -21,6 +21,7 @@
     max_tokens=2 * 1024,  # only output tokens
 )
 SQL_GEN_LLM = SMALL_CONTEXT_INSTRUCT_LLM  # Use the same model for sql gen
+LLM_BATCH_SIZE = 256
 
 EMBEDDINGS = OpenAIEmbeddings(model="text-embedding-ada-002")
 
@@ -36,3 +37,4 @@
 INCLUDE_XML_TAGS = True
 PARENT_HIERARCHY_LEVELS = 2
 RETRIEVER_K = 8
+
diff --git a/docugami_kg_rag/indexing.py b/docugami_kg_rag/indexing.py
index ff47e34..5a65ec7 100644
--- a/docugami_kg_rag/indexing.py
+++ b/docugami_kg_rag/indexing.py
@@ -34,6 +34,7 @@
     INDEXING_LOCAL_REPORT_DBS_ROOT,
     INDEXING_LOCAL_STATE_PATH,
     LARGE_CONTEXT_INSTRUCT_LLM,
+    LLM_BATCH_SIZE,
     MAX_CHUNK_TEXT_LENGTH,
     MAX_FULL_DOCUMENT_TEXT_LENGTH,
     MIN_CHUNK_TEXT_LENGTH,
@@ -143,6 +144,7 @@ def index_docset(docset_id: str, name: str) -> None:
         max_length_cutoff=MAX_FULL_DOCUMENT_TEXT_LENGTH,
         include_xml_tags=INCLUDE_XML_TAGS,
         summarize_document_examples_file=EXAMPLES_PATH / "summarize_document_examples.yaml",
+        batch_size=LLM_BATCH_SIZE,
     )
     chunk_summaries_by_id = build_chunk_summary_mappings(
         docs_by_id=parent_chunks_by_id,
@@ -152,6 +154,7 @@ def index_docset(docset_id: str, name: str) -> None:
         max_length_cutoff=MAX_FULL_DOCUMENT_TEXT_LENGTH,
         include_xml_tags=INCLUDE_XML_TAGS,
         summarize_chunk_examples_file=EXAMPLES_PATH / "summarize_chunk_examples.yaml",
+        batch_size=LLM_BATCH_SIZE,
     )
 
     direct_tool_function_name = docset_name_to_direct_retrieval_tool_function_name(name)
diff --git a/poetry.lock b/poetry.lock
index 1bd3179..eff3f7f 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -345,7 +345,7 @@ virtualenv = ["virtualenv (>=20.0.35)"]
 name = "cachetools"
 version = "5.3.3"
 description = "Extensible memoizing collections and decorators"
-optional = true
+optional = false
 python-versions = ">=3.7"
 files = [
     {file = "cachetools-5.3.3-py3-none-any.whl", hash = "sha256:0abad1021d3f8325b2fc1d2e9c8b9c9d57b04c3932657a72465447332c24d945"},
@@ -813,13 +813,13 @@ typing-extensions = ">=4.5,<5"
 
 [[package]]
 name = "docugami-langchain"
-version = "0.0.9rc2"
+version = "0.0.9rc4"
 description = "An integration package connecting Docugami and LangChain"
 optional = false
 python-versions = "<4.0,>=3.9"
 files = [
-    {file = "docugami_langchain-0.0.9rc2-py3-none-any.whl", hash = "sha256:565e17c5a81381a51dcb9101899d2baef754b0a9fc3308184e01a71720438c6e"},
-    {file = "docugami_langchain-0.0.9rc2.tar.gz", hash = "sha256:3ac30ec8c813706ad36b4a3e42cfac2e5e9c6e892cc5239ff83f10bd26918132"},
+    {file = "docugami_langchain-0.0.9rc4-py3-none-any.whl", hash = "sha256:4ee4814437e1628b7555195b52c5a045ba8ca38a63177f69777afe18922817b2"},
+    {file = "docugami_langchain-0.0.9rc4.tar.gz", hash = "sha256:ba4d12118a0dc2c37cbb1f0b8dd3cbf4c6e50aeae570b46d4dfc249bf08d68f9"},
 ]
 
 [package.dependencies]
@@ -1188,6 +1188,22 @@ protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4
 [package.extras]
 grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"]
 
+[[package]]
+name = "gptcache"
+version = "0.1.43"
+description = "GPTCache, a powerful caching library that can be used to speed up and lower the cost of chat applications that rely on the LLM service. GPTCache works as a memcache for AIGC applications, similar to how Redis works for traditional applications."
+optional = false
+python-versions = ">=3.8.1"
+files = [
+    {file = "gptcache-0.1.43-py3-none-any.whl", hash = "sha256:9c557ec9cc14428942a0ebf1c838520dc6d2be801d67bb6964807043fc2feaf5"},
+    {file = "gptcache-0.1.43.tar.gz", hash = "sha256:cebe7ec5e32a3347bf839e933a34e67c7fcae620deaa7cb8c6d7d276c8686f1a"},
+]
+
+[package.dependencies]
+cachetools = "*"
+numpy = "*"
+requests = "*"
+
 [[package]]
 name = "greenlet"
 version = "3.0.3"
@@ -1921,13 +1937,13 @@ extended-testing = ["lxml (>=5.1.0,<6.0.0)"]
 
 [[package]]
 name = "langgraph"
-version = "0.0.45"
+version = "0.0.46"
 description = "langgraph"
 optional = false
 python-versions = "<4.0,>=3.9.0"
 files = [
-    {file = "langgraph-0.0.45-py3-none-any.whl", hash = "sha256:382df948be7c4de449cd00ff5e4ba7ee4e1523ca5107ba79cae7153ba9635ac9"},
-    {file = "langgraph-0.0.45.tar.gz", hash = "sha256:596c380ea0d0f6a8fa889aa59739f76aa5ff80f49e7be35f10dd50bd44a06934"},
+    {file = "langgraph-0.0.46-py3-none-any.whl", hash = "sha256:69ff85bbc18594e4607eb68f593c3f4dfc4a7b010390424437267e7c0604d8c7"},
+    {file = "langgraph-0.0.46.tar.gz", hash = "sha256:5092c549bc996387ae284de2a5dbe8dee38a6be9bb551fa762c650e378777895"},
 ]
 
 [package.dependencies]
@@ -4628,13 +4644,13 @@ sqlcipher = ["sqlcipher3_binary"]
 
 [[package]]
 name = "sqlglot"
-version = "23.13.7"
+version = "23.14.0"
 description = "An easily customizable SQL parser and transpiler"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "sqlglot-23.13.7-py3-none-any.whl", hash = "sha256:58c39dd236e243eedd512c7e84f3d6e9da559e2559f33c425633f813b93dcbd7"},
-    {file = "sqlglot-23.13.7.tar.gz", hash = "sha256:e78b2fe345fd5187b7e895d01f51965a561d234804a0769a5a9ba6745015694b"},
+    {file = "sqlglot-23.14.0-py3-none-any.whl", hash = "sha256:e413e761a74760918ff04548189aaf05224ba4699f16b17b3b8f86f3dca6f063"},
+    {file = "sqlglot-23.14.0.tar.gz", hash = "sha256:d5e92fd7f45a5783e7fed0075e3c4d13d59338a69c11ffe0f4e99729352b3973"},
 ]
 
 [package.extras]
@@ -4725,17 +4741,18 @@ widechars = ["wcwidth"]
 
 [[package]]
 name = "tenacity"
-version = "8.2.3"
+version = "8.3.0"
 description = "Retry code until it succeeds"
 optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
 files = [
-    {file = "tenacity-8.2.3-py3-none-any.whl", hash = "sha256:ce510e327a630c9e1beaf17d42e6ffacc88185044ad85cf74c0a8887c6a0f88c"},
-    {file = "tenacity-8.2.3.tar.gz", hash = "sha256:5398ef0d78e63f40007c1fb4c0bff96e1911394d2fa8d194f77619c05ff6cc8a"},
+    {file = "tenacity-8.3.0-py3-none-any.whl", hash = "sha256:3649f6443dbc0d9b01b9d8020a9c4ec7a1ff5f6f3c6c8a036ef371f573fe9185"},
+    {file = "tenacity-8.3.0.tar.gz", hash = "sha256:953d4e6ad24357bceffbc9707bc74349aca9d245f68eb65419cf0c249a1949a2"},
 ]
 
 [package.extras]
-doc = ["reno", "sphinx", "tornado (>=4.5)"]
+doc = ["reno", "sphinx"]
+test = ["pytest", "tornado (>=4.5)", "typeguard"]
 
 [[package]]
 name = "threadpoolctl"
@@ -5718,4 +5735,4 @@ redis = ["redis", "redisvl"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<4.0"
-content-hash = "e80bc40bb0f9b7aabbb5cdf477fdfbe93a8e7677ccf5bd90a7323a7ddaad69f5"
+content-hash = "38c472ba2c491f690e7fa9937b90a22d49480bba8829729787d11ea241054fdf"
diff --git a/pyproject.toml b/pyproject.toml
index 1d4231c..d009e38 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,7 +11,7 @@ readme = "README.md"
 python = ">=3.9,<4.0"
 langchain = ">=0.1.17"
 langchain-community = ">=0.0.36"
-docugami-langchain = "^0.0.9rc2"
+docugami-langchain = "^0.0.9rc4"
 # docugami-langchain = {git = "https://github.com/docugami/docugami-langchain.git"}
 # docugami-langchain = {git = "https://github.com/docugami/docugami-langchain.git", rev = "tjaffri/rel"}
 typer = ">=0.9.0"
@@ -21,6 +21,7 @@ langsmith = ">=0.1.52"
 openpyxl = ">=3.1.2"
 faiss-cpu = ">=1.8.0"
 torch = "2.2.0"
+gptcache = "^0.1.43"
 rerankers = { extras = ["all"], version = ">=0.2.0" }
 langchain-openai = { version = ">=0.1.6", optional = true }
 langchain-fireworks = { version = ">=0.1.2", optional = true }
@@ -65,7 +66,7 @@ exclude = "/temp/"
 
 [[tool.mypy.overrides]]
 module = [
-    "torch", "langchain_openai.*", "langchain_fireworks.*", "docugami_langchain.*", "rerankers.*",
+    "torch", "langchain_openai.*", "langchain_fireworks.*", "docugami_langchain.*", "rerankers.*", "gptcache.*"
 ]
 ignore_missing_imports = true
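
Note on the cache change above: a minimal sketch of how the new GPTCache-backed cache in config/__init__.py is expected to behave, assuming only the gptcache and langchain packages this patch pins. FakeListLLM, the cache path, and the demo prompt are illustrative stand-ins, not code from this repo.

import hashlib
from pathlib import Path

from gptcache import Cache
from gptcache.manager.factory import manager_factory
from gptcache.processor.pre import get_prompt
from langchain.cache import GPTCache
from langchain_community.llms.fake import FakeListLLM
from langchain_core.globals import set_llm_cache

CACHE_DIR = "/tmp/docugami/langchain_cache"  # mirrors LOCAL_LLM_CACHE_DIR above


def init_gptcache(cache_obj: Cache, llm: str) -> None:
    # One subdirectory per LLM string, so different models never share cache entries.
    hashed_llm = hashlib.sha256(llm.encode()).hexdigest()
    cache_obj.init(
        pre_embedding_func=get_prompt,  # key entries by the exact prompt text
        data_manager=manager_factory(manager="map", data_dir=str(Path(CACHE_DIR) / hashed_llm)),
    )


set_llm_cache(GPTCache(init_gptcache))

llm = FakeListLLM(responses=["first answer", "second answer"])
print(llm.invoke("What is Docugami?"))  # cache miss: computed, then stored
print(llm.invoke("What is Docugami?"))  # cache hit: same answer served from the cache dir

The manager="map" data manager does exact-match lookups persisted under data_dir, which keeps the behavior closest in spirit to the SQLiteCache this patch replaces.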
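On LLM_BATCH_SIZE (2 for Fireworks Llama 3, 32 for Mixtral, 256 for OpenAI): the value is threaded into build_full_doc_summary_mappings and build_chunk_summary_mappings, whose internals live in docugami-langchain and are not shown in this diff. The sketch below is one plausible way such a batch_size can be honored via LangChain's Runnable.batch; summarize_in_batches is a hypothetical helper, not a repo or library function. For scale, the tightened Llama 3 chunk bounds follow the "1 token ~= 4 chars" rule in the comments: int(1024 * 4 * 1) = 4096 characters, roughly 1k tokens.

from typing import Any, List

from langchain_core.language_models import BaseLanguageModel


def summarize_in_batches(llm: BaseLanguageModel, prompts: List[str], batch_size: int) -> List[Any]:
    # Outputs are strings for completion LLMs and AIMessages for chat models.
    results: List[Any] = []
    for start in range(0, len(prompts), batch_size):
        batch = prompts[start : start + batch_size]
        # Runnable.batch issues these calls concurrently, capped at max_concurrency,
        # so a small LLM_BATCH_SIZE doubles as a rate limiter for stricter providers.
        results.extend(llm.batch(batch, config={"max_concurrency": batch_size}))
    return results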