Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Caching opened tantivy.Indexes in the package #627

Merged
merged 8 commits into from
Oct 23, 2024
78 changes: 62 additions & 16 deletions paperqa/agents/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import zlib
from collections.abc import Awaitable, Callable, Collection, Sequence
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, ClassVar, cast
from typing import TYPE_CHECKING, Any, ClassVar
from uuid import UUID

import anyio
Expand Down Expand Up @@ -104,6 +104,25 @@ def read_from_string(self, data: str | bytes) -> BaseModel | SupportsPickle:
return pickle.loads(data) # type: ignore[arg-type] # noqa: S301


ENV_VAR_MATCH: Collection[str] = {"1", "true"}

# Cache keys are a two-tuple of index name and absolute index directory
# Cache values are a two-tuple of an opened Index instance and the count
# of SearchIndex instances currently referencing that Index
_OPENED_INDEX_CACHE: dict[tuple[str, str], tuple[Index, int]] = {}
DONT_USE_OPENED_INDEX_CACHE = (
os.environ.get("PQA_INDEX_DONT_CACHE_INDEXES", "").lower() in ENV_VAR_MATCH
)


def reap_opened_index_cache() -> None:
    """Delete any unreferenced Index instances from the Index cache.

    Iterate over a snapshot (``list(...)``) of the cache so that entries can
    be evicted during the scan -- popping from the dict while iterating it
    directly raises ``RuntimeError: dictionary changed size during iteration``.
    """
    # NOTE: keys are (index name, absolute index directory) two-tuples
    for key, (index, count) in list(_OPENED_INDEX_CACHE.items()):
        if count == 0:  # No SearchIndex instance references this Index
            _OPENED_INDEX_CACHE.pop(key)
            del index  # Drop our last local reference so the Index can be GC'd


class SearchIndex:
"""Wrapper around a tantivy.Index exposing higher-level behaviors for documents."""

Expand All @@ -127,21 +146,25 @@ def __init__(
)
self.index_name = index_name
self._index_directory = index_directory
self._schema = None
self._index = None
self._searcher = None
self._schema: Schema | None = None
self._index: Index | None = None
self._searcher: Searcher | None = None
self._index_files: dict[str, str] = {}
self.changed = False
self.storage = storage

@property
async def index_directory(self) -> anyio.Path:
async def index_directory( # TODO: rename to index_root_directory
self,
) -> anyio.Path:
directory = anyio.Path(self._index_directory).joinpath(self.index_name)
await directory.mkdir(parents=True, exist_ok=True)
return directory

@property
async def index_filename(self) -> anyio.Path: # TODO: rename to index_directory
async def index_filename( # TODO: rename to index_meta_directory
self,
) -> anyio.Path:
"""Directory to store files used to house index internals."""
index_dir = (await self.index_directory) / "index"
await index_dir.mkdir(exist_ok=True)
Expand All @@ -165,27 +188,51 @@ def schema(self) -> Schema:
schema_builder = SchemaBuilder()
for field in self.fields:
schema_builder.add_text_field(field, stored=True)
self._schema = schema_builder.build() # type: ignore[assignment]
return cast(Schema, self._schema)
self._schema = schema_builder.build()
return self._schema

@property
async def index(self) -> Index:
if not self._index:
index_path = await self.index_filename
if await (index_path / "meta.json").exists():
self._index = Index.open(path=str(index_path)) # type: ignore[assignment]
index_meta_directory = await self.index_filename
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can hit a race condition here (because the exists check is awaitable). We should probably acquire a lock before doing the check, or do the check synchronously.

Otherwise I think a "gather" call could result in several False responses to this exists call.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, we talked about this: because we are only instantiating an Index and not opening one, this likely isn't an issue. @jamesbraza is going to add a comment for future reference.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep thanks for the discussion! I added a comment to the code documenting some of our talking points

if await (index_meta_directory / "meta.json").exists():
if DONT_USE_OPENED_INDEX_CACHE:
self._index = Index.open(path=str(index_meta_directory))
else:
key = self.index_name, str(await index_meta_directory.absolute())
# NOTE: now we know we're using the cache and have created the cache
# key. And we know we're in asyncio.gather race condition risk land.
# All of the following operations are *synchronous* so we are not
# giving the opportunity for an await to switch to another parallel
# version of this code. Otherwise, we risk counts being incorrect
# due to race conditions
if key not in _OPENED_INDEX_CACHE: # open a new Index
self._index = Index.open(path=str(index_meta_directory))
prev_count: int = 0
else: # reuse Index
self._index, prev_count = _OPENED_INDEX_CACHE[key]
_OPENED_INDEX_CACHE[key] = self._index, prev_count + 1
else:
# NOTE: this creates the above meta.json file
self._index = Index(self.schema, path=str(index_path)) # type: ignore[assignment]
return cast(Index, self._index)
self._index = Index(self.schema, path=str(index_meta_directory))
return self._index

def __del__(self) -> None:
    """Decrement this instance's reference count on the shared cached Index.

    Counterpart to the increment performed when the ``index`` property pulls
    an Index from ``_OPENED_INDEX_CACHE``.
    """
    # getattr guard: __del__ can run on a half-constructed instance if
    # __init__ raised before assigning _index. And if this instance never
    # opened an Index, it never incremented the cache's count, so it must
    # not decrement it either.
    if getattr(self, "_index", None) is None:
        return
    index_meta_directory = (
        pathlib.Path(self._index_directory) / self.index_name / "index"
    )
    key = self.index_name, str(index_meta_directory.absolute())
    # NOTE(review): an Index created fresh (no meta.json at open time) is
    # never registered in the cache, yet its key may be present if another
    # instance later opened the same directory -- such an instance would
    # still decrement here. Confirm whether that interleaving can occur.
    if key in _OPENED_INDEX_CACHE:
        index, count = _OPENED_INDEX_CACHE[key]
        _OPENED_INDEX_CACHE[key] = index, count - 1

@property
async def searcher(self) -> Searcher:
if not self._searcher:
index = await self.index
index.reload()
self._searcher = index.searcher() # type: ignore[assignment]
return cast(Searcher, self._searcher)
self._searcher = index.searcher()
return self._searcher

@property
async def count(self) -> int:
Expand Down Expand Up @@ -484,7 +531,6 @@ async def process_file(


WARN_IF_INDEXING_MORE_THAN = 999
ENV_VAR_MATCH: Collection[str] = {"1", "true"}


def _make_progress_bar_update(
Expand Down
43 changes: 34 additions & 9 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ldp.llms import EmbeddingModel, MultipleCompletionLLMModel
from pydantic import ValidationError
from pytest_subtests import SubTests
from tantivy import Index

from paperqa.agents import SearchIndex, agent_query
from paperqa.agents.env import settings_to_tools
Expand Down Expand Up @@ -240,7 +241,13 @@ async def test_agent_types(
" accept the answer for now, as we're in debug mode."
)
request = QueryRequest(query=question, settings=agent_test_settings)
response = await agent_query(request, agent_type=agent_type)
with patch.object(
Index, "open", side_effect=Index.open, autospec=True
) as mock_open:
response = await agent_query(request, agent_type=agent_type)
assert (
mock_open.call_count <= 1
), "Expected one Index.open call, or possibly zero if multiprocessing tests"
assert response.answer.answer, "Answer not generated"
assert response.answer.answer != "I cannot answer", "Answer not generated"
assert response.answer.context, "No contexts were found"
Expand Down Expand Up @@ -463,20 +470,38 @@ async def test_agent_sharing_state(
search_tool = PaperSearch(
settings=agent_test_settings, embedding_model=embedding_model
)
with patch.object(
SearchIndex, "save_index", autospec=True, wraps=SearchIndex.save_index
) as mock_save_index:
with (
patch.object(
SearchIndex, "save_index", wraps=SearchIndex.save_index, autospec=True
) as mock_save_index,
patch.object(
Index, "open", side_effect=Index.open, autospec=True
) as mock_open,
):
await search_tool.paper_search(
"XAI self explanatory model",
min_year=None,
max_year=None,
state=env_state,
)
assert env_state.docs.docs, "Search did not add any papers"
mock_save_index.assert_not_awaited(), "Search shouldn't try to update the index"
assert all(
isinstance(d, Doc) for d in env_state.docs.docs.values()
), "Document type or DOI propagation failure"
assert env_state.docs.docs, "Search did not add any papers"
assert (
mock_open.call_count <= 1
), "Expected one Index.open call, or possibly zero if multiprocessing tests"
assert all(
isinstance(d, Doc) for d in env_state.docs.docs.values()
), "Document type or DOI propagation failure"

await search_tool.paper_search(
"XAI for chemical property prediction",
min_year=2018,
max_year=2024,
state=env_state,
)
assert (
mock_open.call_count <= 1
), "Expected one Index.open call, or possibly zero if multiprocessing tests"
mock_save_index.assert_not_awaited()

with subtests.test(msg=GatherEvidence.__name__):
assert not answer.contexts, "No contexts is required for a later assertion"
Expand Down