Skip to content

Commit

Permalink
Handle more cases of empty string arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
softwaredoug committed Jul 14, 2024
1 parent e05dc82 commit feb43f3
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
3 changes: 2 additions & 1 deletion searcharray/phrase/middle_out.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,8 @@ def __init__(self, encoded_term_posns: Union[ArrayDict, FilteredPosns], max_doc_
self.termfreq_cache : Dict[int, Tuple[np.ndarray, np.ndarray]] = {}

def memmap(self, data_dir):
self.encoded_term_posns = MemoryMappedArrays(data_dir, self.encoded_term_posns)
if self.encoded_term_posns:
self.encoded_term_posns = MemoryMappedArrays(data_dir, self.encoded_term_posns)

def warm(self):
"""Warm tf / df cache of most common terms."""
Expand Down
19 changes: 19 additions & 0 deletions test/test_search.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
"""Test postings array search functionality."""
import numpy as np
import pandas as pd
import os
import shutil
import pytest
from searcharray.postings import SearchArray
from searcharray.similarity import bm25_similarity
from test_utils import w_scenarios


DATA_DIR = '/tmp/tmdb'


@pytest.fixture
def data():
"""Return a fixture of your data here that returns an instance of your ExtensionArray."""
Expand All @@ -33,6 +38,20 @@ def test_search_phrase_empty_str_batch_size():
assert data.score(["foo", "bar"]).sum() == 0


def test_search_phrase_empty_str_batch_size_memmap():
os.makedirs(DATA_DIR, exist_ok=True)
data = pd.DataFrame({"data": [""] * 10000})
data = SearchArray.index(data["data"],
batch_size=1000,
data_dir=DATA_DIR)
assert data.score(["foo", "bar"]).sum() == 0
pd.to_pickle(data, os.path.join(DATA_DIR, "data.pkl"))
reloaded = pd.read_pickle(os.path.join(DATA_DIR, "data.pkl"))
assert reloaded.score(["foo", "bar"]).sum() == 0

shutil.rmtree(DATA_DIR)


def test_match(data):
matches = data.termfreqs("foo") > 0
assert (matches == [True, False, False, False] * 25).all()
Expand Down

0 comments on commit feb43f3

Please sign in to comment.