Skip to content

Commit

Permalink
Add min should match
Browse files Browse the repository at this point in the history
  • Loading branch information
softwaredoug committed Dec 26, 2023
1 parent d309484 commit 3f8cc38
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 5 deletions.
8 changes: 3 additions & 5 deletions searcharray/postings.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,20 +668,18 @@ def positions(self, token: str, key=None) -> List[np.ndarray]:
posns = self.posns.positions(term_id, key=key)
return posns

def and_query(self, tokens: List[str]) -> np.ndarray:
def and_query(self, tokens: List[str] | List[List[str]]) -> np.ndarray:
"""Return a mask on the postings array indicating which elements contain all terms."""
masks = [self.match(term) for term in tokens]
mask = np.ones(len(self), dtype=bool)
for curr_mask in masks:
mask = mask & curr_mask
return mask

def or_query(self, tokens: List[str]) -> np.ndarray:
def or_query(self, tokens: List[str] | List[List[str]], min_should_match: int = 1) -> np.ndarray:
"""Return a mask on the postings array indicating which elements contain all terms."""
masks = [self.match(term) for term in tokens]
mask = np.ones(len(self), dtype=bool)
for curr_mask in masks:
mask = mask & curr_mask
mask = np.sum(masks, axis=0) >= min_should_match
return mask

def phrase_freq(self, tokens: List[str], slop=1) -> np.ndarray:
Expand Down
121 changes: 121 additions & 0 deletions test/test_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""Test postings array search functionality."""
import pytest
import numpy as np
from searcharray.postings import PostingsArray
from test_utils import w_scenarios


@pytest.fixture
def data():
"""Return a fixture of your data here that returns an instance of your ExtensionArray."""
return PostingsArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"] * 25)


def test_match(data):
matches = data.match("foo")
assert (matches == [True, False, False, False] * 25).all()


def test_match_missing_term(data):
matches = data.match("not_present")
assert (matches == [False, False, False, False] * 25).all()


def test_term_freqs(data):
matches = data.term_freq("bar")
assert (matches == [2, 0, 1, 0] * 25).all()


def test_doc_freq(data):
doc_freq = data.doc_freq("bar")
assert doc_freq == (2 * 25)
doc_freq = data.doc_freq("foo")
assert doc_freq == 25


def test_doc_lengths(data):
doc_lengths = data.doc_lengths()
assert doc_lengths.shape == (100,)
assert (doc_lengths == [4, 1, 2, 3] * 25).all()
assert data.avg_doc_length == 2.5


def test_bm25_matches_lucene(data):
bm25_idf = data.bm25_idf("bar")
assert bm25_idf > 0.0
bm25 = data.bm25("bar")
assert bm25.shape == (100,)
assert np.isclose(bm25, [0.37066694, 0., 0.34314217, 0.] * 25).all()


and_scenarios = {
"base": {
"docs": lambda: PostingsArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"] * 25),
"keywords": ["foo", "bar"],
"expected": [True, False, False, False] * 25,
},
"no_match": {
"docs": lambda: PostingsArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"] * 25),
"keywords": ["foo", "data2"],
"expected": [False, False, False, False] * 25,
},
"and_with_phrase": {
"docs": lambda: PostingsArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"] * 25),
"keywords": [["foo", "bar"], "baz"],
"expected": [True, False, False, False] * 25,
}
}


@w_scenarios(and_scenarios)
def test_and_query(data, docs, keywords, expected):
docs = docs()
matches = data.and_query(keywords)
assert (expected == matches).all()


or_scenarios = {
"base": {
"docs": lambda: PostingsArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"] * 25),
"keywords": ["foo", "bar"],
"expected": [True, False, True, False] * 25,
"min_should_match": 1,
},
"mm_2": {
"docs": lambda: PostingsArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"] * 25),
"keywords": ["foo", "bar"],
"expected": [True, False, False, False] * 25,
"min_should_match": 2,
},
"one_term_match": {
"docs": lambda: PostingsArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"] * 25),
"keywords": ["foo", "data2"],
"expected": [True, True, False, False] * 25,
"min_should_match": 1,
},
"one_term_match_mm2": {
"docs": lambda: PostingsArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"] * 25),
"keywords": ["foo", "data2"],
"expected": [False, False, False, False] * 25,
"min_should_match": 2,
},
"or_with_phrase": {
"docs": lambda: PostingsArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"] * 25),
"keywords": [["foo", "bar"], "baz"],
"expected": [True, False, False, False] * 25,
"min_should_match": 1,
},
"or_with_phrase_mm2": {
"docs": lambda: PostingsArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"] * 25),
"keywords": [["foo", "bar"], ["bar", "baz"]],
"expected": [True, False, False, False] * 25,
"min_should_match": 2,
}
}


@w_scenarios(or_scenarios)
def test_or_query(data, docs, keywords, expected, min_should_match):
docs = docs()
matches = data.or_query(keywords, min_should_match=min_should_match)
assert (expected == matches).all()

0 comments on commit 3f8cc38

Please sign in to comment.