Skip to content

Commit

Permalink
Add more scattered posns tests
Browse files Browse the repository at this point in the history
  • Loading branch information
softwaredoug committed May 11, 2024
1 parent 0e47c3f commit 009dc16
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions test/test_phrase_matches.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,60 @@ def test_phrase_scattered_posns(posn_offset):
assert (expected == phrase_matches).all()


@pytest.mark.parametrize("posn_offset", range(100))
def test_phrase_scattered_posns_sliced(posn_offset):
scattered = "foo bar " + " ".join(["dummy"] * posn_offset) + " foo bar baz"
docs = SearchArray.index([scattered,
"not match"] * 1000)
docs = docs[::2]
phrase = ["foo", "bar"]
expected = [2] * 1000
phrase_matches = docs.phrase_freq(phrase)
assert (expected == phrase_matches).all()


@pytest.mark.parametrize("posn_offset", range(100))
def test_phrase_scattered_posns_sliced_one_term_rpt(posn_offset):
scattered = "foo bar " + " ".join(["foo"] * posn_offset) + " foo bar baz"
docs = SearchArray.index([scattered,
"not match"] * 1000)
docs = docs[::2]
phrase = ["foo", "bar"]
expected = [2] * 1000
phrase_matches = docs.phrase_freq(phrase)
assert (expected == phrase_matches).all()


@pytest.mark.parametrize("posn_offset", range(100))
def test_phrase_scattered_posns_sliced_frequent(posn_offset):
scattered = "foo bar " + " ".join(["foo"] * posn_offset) + " foo bar baz"
idx = [scattered,
"foo",
"foo"] * 1000
docs = SearchArray.index(idx)
docs = docs[::2]
idx = np.array(idx)[::2]
phrase = ["foo", "bar"]
expected = [2 if "foo bar" in doc else 0 for doc in idx]
phrase_matches = docs.phrase_freq(phrase)
assert (expected == phrase_matches).all()


@pytest.mark.parametrize("posn_offset", range(100))
def test_phrase_scattered_posns_sliced_frequent_long(posn_offset):
scattered = "foo bar baz " + " ".join(["foo"] * posn_offset) + " foo bar baz"
idx = [scattered,
"foo baz",
"foo"] * 1000
docs = SearchArray.index(idx)
docs = docs[::2]
idx = np.array(idx)[::2]
phrase = ["foo", "bar", "baz"]
expected = [2 if " ".join(phrase) in doc else 0 for doc in idx]
phrase_matches = docs.phrase_freq(phrase)
assert (expected == phrase_matches).all()


@pytest.mark.parametrize("posn_offset", range(100))
def test_phrase_scattered_posns3(posn_offset):
scattered = "foo bar baz " + " ".join(["dummy"] * posn_offset) + " foo bar baz"
Expand Down

0 comments on commit 009dc16

Please sign in to comment.