diff --git a/searcharray/indexing.py b/searcharray/indexing.py index 5042927..1f78c35 100644 --- a/searcharray/indexing.py +++ b/searcharray/indexing.py @@ -46,7 +46,7 @@ def _compute_doc_lens(posns: np.ndarray, doc_ids: np.ndarray, num_docs: int) -> non_empty_doc_ids = doc_ids[non_empty_idxs] non_empty_doc_lens = non_empty_doc_lens[non_empty_idxs] doc_lens[non_empty_doc_ids] = non_empty_doc_lens - if doc_ids[-1] not in non_empty_doc_ids: + if len(doc_ids) > 0 and doc_ids[-1] not in non_empty_doc_ids: doc_lens[doc_ids[-1]] = posns[-1] + 1 return doc_lens diff --git a/searcharray/phrase/middle_out.py b/searcharray/phrase/middle_out.py index 5ea6fd0..434575b 100644 --- a/searcharray/phrase/middle_out.py +++ b/searcharray/phrase/middle_out.py @@ -170,6 +170,8 @@ def build(self): encoded, enc_term_boundaries = encoder.encode(keys=self.flat_array[1].view(np.uint64), boundaries=term_boundaries[:-1], payload=self.flat_array[2].view(np.uint64)) + if len(encoded) == 0: + return PosnBitArray({}, self.max_doc_id) term_ids = self.flat_array[0][term_boundaries[:-1]] encoded_term_posns = ArrayDict.from_array_with_boundaries(encoded, diff --git a/searcharray/similarity.py b/searcharray/similarity.py index 7f85b5d..d096df9 100644 --- a/searcharray/similarity.py +++ b/searcharray/similarity.py @@ -43,7 +43,10 @@ def compute_idf(num_docs, dfs): def compute_adj_doc_lens(doc_lens, avg_doc_lens, k1, b): - adj_doc_lens = doc_lens / avg_doc_lens + if avg_doc_lens == 0: + adj_doc_lens = np.zeros_like(doc_lens, dtype=np.float32) + else: + adj_doc_lens = doc_lens / avg_doc_lens adj_doc_lens *= b adj_doc_lens += 1 - b adj_doc_lens *= k1 diff --git a/test/test_search.py b/test/test_search.py index c1340a3..6786729 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -1,5 +1,6 @@ """Test postings array search functionality.""" import numpy as np +import pandas as pd import pytest from searcharray.postings import SearchArray from searcharray.similarity import bm25_similarity @@ -12,6 +13,16 @@ def data(): return SearchArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"] * 25) +@pytest.fixture +def all_empty_str(): + return pd.DataFrame({"data": [""] * 100}) + + +def test_search_empty_str(all_empty_str): + data = SearchArray.index(all_empty_str["data"]) + assert data.score("foo").sum() == 0 + + def test_match(data): matches = data.termfreqs("foo") > 0 assert (matches == [True, False, False, False] * 25).all()