Skip to content

Commit

Permalink
Clean up phrase matching
Browse files Browse the repository at this point in the history
  • Loading branch information
softwaredoug committed Nov 14, 2023
1 parent 8e20db0 commit 5664c05
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 36 deletions.
45 changes: 39 additions & 6 deletions searcharray/postings.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ def and_query(self, tokenized_terms):
mask = mask & curr_mask
return mask

def phrase_match(self, tokenized_terms):
def phrase_match(self, tokenized_terms, slop=1):
"""Return a boolean numpy array indicating which elements contain the given phrase."""
# Has both terms
mask = self.and_query(tokenized_terms)
Expand All @@ -641,14 +641,47 @@ def pad_arrays(arrays, pad_value=99999999999):
# Compute positional differences
prior_term = term_posns[0]
for term in term_posns[1:]:
# doc, term, posn diff matrix
# Each row of posn_diffs is a term posn diff matrix
# Where columns are prior_term posns, rows are term posns
# This shows every possible term diff
#
# Example:
# prior_term = array([[0, 4],[0, 4])
# term = array([[1, 2, 3],[1, 2, 3]])
#
#
# posn_diffs =
#
# array([[ term[0] - prior_term[0], term[0] - prior_term[1] ],
# [ term[1] - prior_term[0], ...
# [ term[2] - prior_term[0], ...
#
# or in our example
#
# array([[ 1, -3],
# [ 2, -2],
# [ 3, -1]])
#
# We care about all locations where posn == slop (or perhaps <= slop)
# that is term is slop away from prior_term. Usually slop == 1 (ie 1 posn away)
# for normal phrase matching
#
posn_diffs = term[:, :, np.newaxis] - prior_term[:, np.newaxis, :]
# Count how many times the row term is 1 away from the col term
slop = 1
per_doc_diffs = np.sum(posn_diffs == slop, axis=1)

# For > 2 terms, we need to connect a third term by making prior_term = term
# and repeating
#
# BUT
# we only want those parts of term that are adjacent to prior_term
# before continuing, so we don't accidentally get a partial phrase
# so we need to make sure to
# Pad out any rows in 'term' where posn diff != slop
term = np.where(per_doc_diffs == 1, term, 99999999999)
# so they're not considered on subsequent iterations
term_mask = np.any(posn_diffs == 1, axis=2)
term[~term_mask] = 99999999999

# Count how many times the row term is 1 away from the col term
per_doc_diffs = np.sum(posn_diffs == slop, axis=1)

# Doc-wise sum to get a 'term freq'
bigram_freqs = np.sum(per_doc_diffs == slop, axis=1)
Expand Down
92 changes: 62 additions & 30 deletions test/test_phrase_matches.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,63 @@
from searcharray.postings import PostingsArray


def test_phrase_match():
    """A two-term phrase matches only docs where the terms appear adjacently."""
    corpus = PostingsArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"] * 25)
    expected = [True, False, False, False] * 25
    assert (corpus.phrase_match(["foo", "bar"]) == expected).all()


def test_phrase_match_three_terms():
    """Three consecutive terms form a phrase that matches only the final doc."""
    corpus = PostingsArray.index(["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"] * 25)
    hits = corpus.phrase_match(["bunny", "funny", "wunny"])
    want = [False, False, False, True] * 25
    assert (hits == want).all()


def test_phrase_match_three_terms_spread_out_doesnt_match():
    """Terms that are all present but never adjacent must not count as a phrase."""
    corpus = PostingsArray.index(["foo bar EEK foo URG bar baz", "data2", "data3 bar", "bunny funny wunny"] * 25)
    hits = corpus.phrase_match(["foo", "bar", "baz"])
    assert (hits == [False, False, False, False] * 25).all()


def test_phrase_match_same_term_matches():
    """A phrase built from one repeated term matches a doc repeating that term."""
    corpus = PostingsArray.index(["foo foo foo", "data2", "data3 bar", "bunny funny wunny"] * 25)
    want = [True, False, False, False] * 25
    assert (corpus.phrase_match(["foo", "foo", "foo"]) == want).all()


def test_phrase_match_duplicate_phrases():
    """A doc containing the phrase more than once still yields a single True for that doc."""
    corpus = PostingsArray.index(["foo bar foo bar", "data2", "data3 bar", "bunny funny wunny"] * 25)
    hits = corpus.phrase_match(["foo", "bar"])
    assert (hits == [True, False, False, False] * 25).all()
from test_utils import w_scenarios
from time import perf_counter


def _scenario(doc_batch, phrase, expected_batch, n=25):
    """Build one w_scenarios entry: index n repetitions of a 4-doc batch
    alongside the phrase to search and the expected per-doc match mask."""
    return {
        "docs": PostingsArray.index(doc_batch * n),
        "phrase": phrase,
        "expected": expected_batch * n,
    }


# Each scenario: an indexed corpus, the phrase to match, and the expected mask.
scenarios = {
    "base": _scenario(
        ["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"],
        ["foo", "bar"], [True, False, False, False]),
    "multi_term_one_doc": _scenario(
        ["foo bar bar bar foo", "data2", "data3 bar", "bunny funny wunny"],
        ["foo", "bar"], [True, False, False, False]),
    "three_terms_match": _scenario(
        ["foo bar baz baz", "data2", "data3 bar", "bunny funny wunny"],
        ["foo", "bar", "baz"], [True, False, False, False]),
    "three_terms_no_match": _scenario(
        ["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"],
        ["foo", "bar", "baz"], [False, False, False, False]),
    "many_docs": _scenario(
        ["foo bar bar baz", "data2", "data3 bar", "bunny funny wunny"],
        ["foo", "bar"], [True, False, False, False], n=100000),
    "three_terms_spread_out": _scenario(
        ["foo bar EEK foo URG bar baz", "data2", "data3 bar", "bunny funny wunny"],
        ["foo", "bar", "baz"], [False, False, False, False]),
    "same_term_matches": _scenario(
        ["foo foo foo", "data2", "data3 bar", "bunny funny wunny"],
        ["foo", "foo"], [True, False, False, False]),
    "same_term_matches_3": _scenario(
        ["foo foo foo", "data2", "data3 bar", "bunny funny wunny"],
        ["foo", "foo", "foo"], [True, False, False, False]),
    "duplicate_phrases": _scenario(
        ["foo bar foo bar", "data2", "data3 bar", "bunny funny wunny"],
        ["foo", "bar"], [True, False, False, False]),
}


@w_scenarios(scenarios)
def test_phrase(docs, phrase, expected):
    """Phrase matching returns the expected mask and leaves the index untouched."""
    t0 = perf_counter()
    snapshot = docs.copy()  # compared afterwards to prove phrase_match is read-only
    result = docs.phrase_match(phrase)
    print(f"phrase_match took {perf_counter() - t0} seconds | {len(docs)} docs")
    assert (result == expected).all()
    # Skip the (expensive) immutability check on the very large corpora.
    if len(docs) < 1000:
        assert (docs == snapshot).all(), "The phrase_match method should not modify the original array"

0 comments on commit 5664c05

Please sign in to comment.