Add strided intersect tests, etc
softwaredoug committed May 11, 2024
1 parent 1ebb01b commit af5c587
Showing 6 changed files with 96 additions and 17 deletions.
7 changes: 6 additions & 1 deletion searcharray/phrase/middle_out.py
@@ -340,11 +340,15 @@ def doc_encoded_posns(self, term_id: int, doc_id: int) -> np.ndarray:
                                       keys=np.asarray([doc_id], dtype=np.uint64))
         return term_posns
 
-    def phrase_freqs(self, term_ids: List[int], phrase_freqs: np.ndarray,
+    def empty_buffer(self):
+        return np.zeros(self.max_doc_id + 1, dtype=np.float64)
+
+    def phrase_freqs(self, term_ids: List[int],
                      slop: int = 0,
                      doc_ids: Optional[np.ndarray] = None,
                      min_posn: Optional[int] = None,
                      max_posn: Optional[int] = None) -> np.ndarray:
+        phrase_freqs = self.empty_buffer()
         if len(term_ids) < 2:
             raise ValueError("Must have at least two terms")
         if phrase_freqs.shape[0] == self.max_doc_id + 1 and min_posn is None and max_posn is None and doc_ids is None:
@@ -358,6 +362,7 @@ def phrase_freqs(self, term_ids: List[int], phrase_freqs: np.ndarray,
                                     keys=keys,
                                     min_payload=min_posn,
                                     max_payload=max_posn) for term_id in term_ids]
+        import pdb; pdb.set_trace()
 
         if slop == 0:
             return compute_phrase_freqs(enc_term_posns, phrase_freqs, max_doc_id=np.uint64(self.max_doc_id))
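Note on the API change above: callers no longer pass in a pre-allocated phrase_freqs buffer; phrase_freqs() now allocates one internally via the new empty_buffer() helper, sized to max_doc_id + 1. A minimal sketch of the new calling convention (illustrative only; posns stands for an already-built positions index from this module, and the term ids are made up):

    # Hypothetical usage -- no caller-side buffer allocation anymore
    term_ids = [7, 12]                            # made-up ids for a two-term phrase
    freqs = posns.phrase_freqs(term_ids, slop=0)
    # Dense result: one float64 slot per doc id, 0..max_doc_id
    assert freqs.shape[0] == posns.max_doc_id + 1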
20 changes: 12 additions & 8 deletions searcharray/postings.py
@@ -681,7 +680,6 @@ def phrase_freq(self, tokens: List[str],
                     active_docs: Optional[np.ndarray] = None,
                     min_posn: Optional[int] = None,
                     max_posn: Optional[int] = None) -> np.ndarray:
-        phrase_freqs = np.zeros(len(self))
         try:
             # Decide how/if we need to filter doc ids
             doc_ids = None
@@ -691,14 +690,19 @@
                 doc_ids = self.term_mat.rows
 
             term_ids = [self.term_dict.get_term_id(token) for token in tokens]
-            return self.posns.phrase_freqs(term_ids,
-                                           doc_ids=doc_ids,
-                                           phrase_freqs=phrase_freqs,
-                                           slop=slop,
-                                           min_posn=min_posn,
-                                           max_posn=max_posn)
-        except TermMissingError:
-            import pdb; pdb.set_trace()
+            phrase_freqs = self.posns.phrase_freqs(term_ids,
+                                                   doc_ids=doc_ids,
+                                                   slop=slop,
+                                                   min_posn=min_posn,
+                                                   max_posn=max_posn)
+            if doc_ids is not None:
+                return phrase_freqs[doc_ids]
+            return phrase_freqs
+        except TermMissingError:
+            if doc_ids is not None:
+                return np.zeros(len(doc_ids))
+            return self.posns.empty_buffer()
 
     def phrase_freq_scan(self, tokens: List[str], mask=None, slop=0) -> np.ndarray:
         if mask is None:
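Since phrase_freqs() now always returns a dense buffer (one slot per doc id), narrowing to a doc subset is the plain fancy-indexing step phrase_freqs[doc_ids] seen above. A small self-contained NumPy illustration with made-up values:

    import numpy as np

    dense = np.array([0.0, 2.0, 0.0, 1.0, 0.0])  # one slot per doc id 0..4
    doc_ids = np.array([1, 3], dtype=np.uint64)  # docs this array actually holds
    print(dense[doc_ids])                        # [2. 1.] -- aligned with doc_ids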
30 changes: 28 additions & 2 deletions searcharray/roaringish/snp_ops.pyx
@@ -260,7 +260,7 @@ cdef void _gallop_intersect_keep(DTYPE_t* lhs,
                                  DTYPE_t* rhs_out,
                                  DTYPE_t* lhs_out_len,
                                  DTYPE_t* rhs_out_len,
-                                 DTYPE_t mask=ALL_BITS) nogil:
+                                 DTYPE_t mask=ALL_BITS) noexcept:
     """Two pointer approach to find the intersection of two sorted arrays."""
     cdef DTYPE_t* lhs_ptr = &lhs[0]
     cdef DTYPE_t* rhs_ptr = &rhs[0]
@@ -273,6 +273,11 @@
     cdef DTYPE_t* lhs_result_ptr = &lhs_out[0]
     cdef DTYPE_t* rhs_result_ptr = &rhs_out[0]
 
+    print(f"Input {lhs[0]}|{rhs[0]}")
+    print(f"      {lhs[1]}|{rhs[1]}")
+    print(f"      {lhs[2]}|{rhs[2]}")
+    print(f"      {lhs[3]}|{rhs[3]}")
+
     while lhs_ptr < end_lhs_ptr and rhs_ptr < end_rhs_ptr:
         # Gallop past the current element
         while lhs_ptr < end_lhs_ptr and (lhs_ptr[0] & mask) < (rhs_ptr[0] & mask):
@@ -285,6 +290,8 @@
             delta <<= 1
         rhs_ptr -= (delta >> 1)
         delta = 1
+        print(f"    At: {lhs_ptr - &lhs[0]}|{rhs_ptr - &rhs[0]}")
+        print(f"Values: {lhs_ptr[0]} vs {rhs_ptr[0]}")
 
         # Now that we've reset, we just do the naive 2-ptr check
         # Then next loop we pickup on exponential search
@@ -294,14 +301,19 @@
             rhs_ptr += 1
         else:
             target = lhs_ptr[0] & mask
+            print("---")
+            print(f"Collecting target: {target}")
             # Store all LHS indices equal to RHS
             while (lhs_ptr[0] & mask) == target and lhs_ptr < end_lhs_ptr:
                 lhs_result_ptr[0] = lhs_ptr - &lhs[0]; lhs_result_ptr += 1
+                print(f"lhs_ptr: {lhs_ptr[0]} target: {target} -- saving idx: {lhs_ptr - &lhs[0]}")
                 lhs_ptr += 1
             # Store all RHS equal to LHS
             while (rhs_ptr[0] & mask) == target and rhs_ptr < end_rhs_ptr:
                 rhs_result_ptr[0] = rhs_ptr - &rhs[0]; rhs_result_ptr += 1
+                print(f"rhs_ptr: {rhs_ptr[0]} target: {target} -- saving idx: {rhs_ptr - &rhs[0]}")
                 rhs_ptr += 1
+            print("---")
 
     # If delta
     # Either we read past the array, or
@@ -409,12 +421,26 @@ def intersect(np.ndarray[DTYPE_t, ndim=1] lhs,
     else:
         lhs_out = np.empty(max(lhs.shape[0], rhs.shape[0]), dtype=np.uint64)
         rhs_out = np.empty(max(lhs.shape[0], rhs.shape[0]), dtype=np.uint64)
+        print(f"INPUT... Intersect KEEP {mask:0x} {lhs_out_len}|{rhs_out_len}")
+        print([lhs[i] for i in range(lhs.shape[0])])
+        print([rhs[i] for i in range(rhs.shape[0])])
+        print(f"Input {lhs[0]}|{rhs[0]}")
+        print(f"      {lhs[1]}|{rhs[1]}")
+        print(f"      {lhs[2]}|{rhs[2]}")
+        print(f"      {lhs[3]}|{rhs[3]}")
+        print(f"Strides {lhs.strides[0]}|{rhs.strides[0]}")
+        print("  ----")
+
         _gallop_intersect_keep(&lhs[0], &rhs[0],
                                lhs.shape[0], rhs.shape[0],
                                &lhs_out[0], &rhs_out[0],
                                &lhs_out_len, &rhs_out_len,
                                mask)
-        return np.asarray(lhs_out)[:lhs_out_len], np.asarray(rhs_out)[:rhs_out_len]
+        print(f"OUTPUT... Intersect KEEP {mask:0x} {lhs_out_len}|{rhs_out_len}")
+        lhs_out, rhs_out = np.asarray(lhs_out)[:lhs_out_len], np.asarray(rhs_out)[:rhs_out_len]
+        print([lhs[lhs_out[i]] for i in range(lhs_out_len)])
+        print([rhs[rhs_out[i]] for i in range(rhs_out_len)])
+        return lhs_out, rhs_out
 
 
 def adjacent(np.ndarray[DTYPE_t, ndim=1] lhs,
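For reference, a pure-Python sketch of the galloping two-pointer intersection that _gallop_intersect_keep implements -- simplified (no mask, no duplicate collection, list output), a rough model rather than the Cython kernel itself:

    def gallop_intersect(lhs, rhs):
        """Return (lhs_idx, rhs_idx) pairs where two sorted sequences match."""
        i = j = 0
        out = []
        while i < len(lhs) and j < len(rhs):
            delta = 1
            # Gallop lhs forward in doubling steps past smaller values...
            while i < len(lhs) and lhs[i] < rhs[j]:
                i += delta
                delta <<= 1
            i -= delta >> 1   # ...then undo the last, possibly overshooting, jump
            delta = 1
            while j < len(rhs) and rhs[j] < lhs[i]:
                j += delta
                delta <<= 1
            j -= delta >> 1
            # Naive two-pointer step after the reset; the next pass gallops again
            if lhs[i] < rhs[j]:
                i += 1
            elif rhs[j] < lhs[i]:
                j += 1
            else:
                out.append((i, j))
                i += 1
                j += 1
        return out

    # gallop_intersect([1, 3, 5, 7], [3, 7]) -> [(1, 0), (3, 1)]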
12 changes: 6 additions & 6 deletions searcharray/solr.py
@@ -255,8 +255,8 @@ def pf_phase(curr_scores):
             if len(terms) < 2:
                 continue
 
+            arr = arr[curr_scores > 0]
             field_phrase_score = arr.score(terms, similarity=similarity[field],
-                                           active_docs=curr_scores,
                                            ) * (1 if boost is None else boost)
             boost_exp = f"{boost}" if boost is not None else "1"
             explain += f" ({field}:\"{' '.join(terms)}\")^{boost_exp}"
@@ -270,13 +270,13 @@ def pf2_phase(curr_scores):
         explain = ""
         for field, boost in bigram_fields.items():
             arr = get_field(frame, field)
+            arr = arr[curr_scores > 0]
             terms = search_terms[field]
             if len(terms) < 2:
                 continue
             # For each bigram
             for term, next_term in zip(terms, terms[1:]):
                 field_bigram_score = arr.score([term, next_term], similarity=similarity[field],
-                                               active_docs=curr_scores,
                                                ) * (1 if boost is None else boost)
                 boost_exp = f"{boost}" if boost is not None else "1"
                 explain += f" ({field}:\"{term} {next_term}\")^{boost_exp}"
@@ -296,9 +296,9 @@ def pf3_phase(curr_scores):
             if len(terms) < 3:
                 continue
             # For each trigram
+            arr = arr[curr_scores > 0]
             for term, next_term, next_next_term in zip(terms, terms[1:], terms[2:]):
                 field_trigram_score = arr.score([term, next_term, next_next_term],
-                                                active_docs=curr_scores,
                                                 similarity=similarity[field]) * (1 if boost is None else boost)
                 boost_exp = f"{boost}" if boost is not None else "1"
                 explain += f" ({field}:\"{term} {next_term} {next_next_term}\")^{boost_exp}"
@@ -312,18 +312,18 @@ def pf3_phase(curr_scores):
         phrase_scores = np.sum(phrase_scores, axis=0)
         # Add where term_scores > 0
         term_match_idx = np.where(qf_scores)[0]
-        qf_scores[term_match_idx] += phrase_scores[term_match_idx]
+        qf_scores[term_match_idx] += phrase_scores
 
     if len(bigram_scores) > 0:
         bigram_scores = np.sum(bigram_scores, axis=0)
         # Add where term_scores > 0
         term_match_idx = np.where(qf_scores)[0]
-        qf_scores[term_match_idx] += bigram_scores[term_match_idx]
+        qf_scores[term_match_idx] += bigram_scores
 
     if len(trigram_scores) > 0:
         trigram_scores = np.sum(trigram_scores, axis=0)
         # Add where term_scores > 0
         term_match_idx = np.where(qf_scores)[0]
-        qf_scores[term_match_idx] += trigram_scores[term_match_idx]
+        qf_scores[term_match_idx] += trigram_scores
 
     return qf_scores, explain
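The pf/pf2/pf3 rework narrows each field array to docs that matched the term phase (arr = arr[curr_scores > 0]) before phrase scoring, so the returned phrase score vectors are already aligned with term_match_idx and are added without re-indexing. A tiny NumPy sketch of that accumulation (made-up numbers):

    import numpy as np

    qf_scores = np.array([0.0, 1.2, 0.0, 0.7])  # term-phase scores per doc
    phrase_scores = np.array([0.5, 0.1])        # computed only for matching docs
    term_match_idx = np.where(qf_scores)[0]     # array([1, 3])
    qf_scores[term_match_idx] += phrase_scores  # boost just docs 1 and 3
    print(qf_scores)                            # [0.  1.7 0.  0.8]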
13 changes: 13 additions & 0 deletions test/test_phrase_matches.py
@@ -217,6 +217,13 @@ def test_phrase(docs, phrase, expected, algorithm):
     assert (expected == phrase_matches2).all()
 
 
+@w_scenarios(scenarios)
+def test_phrase_on_slice(docs, phrase, expected):
+    docs = docs()
+    # Get odd docs
+    docs = docs[1::2]
+    term_freqs = docs.termfreqs(phrase)
+    assert len(term_freqs) == len(docs)
+    expected = np.array(expected)
+    assert (term_freqs == expected[1::2]).all()
+    # All other 0
+    assert term_freqs[0::2].sum() == 0
+
+
 @w_scenarios(scenarios)
 def test_phrase_active_docs(docs, phrase, expected):
     docs = docs()
31 changes: 31 additions & 0 deletions test/test_snp_ops.py
@@ -134,9 +134,28 @@ def test_search_masks(mask, algorithm, array: np.ndarray, target: np.uint64, exp
         "mask": None,
         "expected": u64([32, 42])
     },
+    "trouble_scen": {
+        "lhs": u64([1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63, 65, 67, 69, 71, 73, 75, 77, 79, 81, 83, 85, 87, 89, 91, 93, 95, 97, 99, 101, 103, 105, 107, 109, 111, 113, 115, 117, 119, 121, 123]),
+
+        "rhs": u64([0, 4, 4, 5, 9, 9, 10, 14, 14, 15, 19, 19, 20, 24, 24, 25, 29, 29, 30, 34, 34, 35, 39, 39, 40, 44, 44, 45, 49, 49, 50, 54, 54, 55, 59, 59, 60, 64, 64, 65, 69, 69, 70, 74, 74, 75, 79, 79, 80, 84, 84, 85, 89, 89, 90, 94, 94, 95, 99, 99, 100, 104, 104, 105, 109, 109, 110, 114, 114, 115, 119, 119, 120, 124, 124]),
+        "mask": None,
+        "expected": u64([5, 9, 15, 19, 25, 29, 35, 39, 45, 49, 55, 59, 65, 69, 75, 79, 85, 89, 95, 99, 105, 109, 115, 119])
+    }
 }
 
 
+@w_scenarios(intersect_scenarios)
+def test_intersect_strided(lhs, rhs, mask, expected):
+    if mask is None:
+        mask = np.uint64(0xFFFFFFFFFFFFFFFF)
+    lhs = lhs[::2]
+    rhs = rhs[::2]
+    expected = np.intersect1d(lhs, rhs)
+    lhs_idx, rhs_idx = intersect(lhs, rhs, mask=mask)
+    result = lhs[lhs_idx] & mask
+    assert np.all(result == expected)
+
+
 @w_scenarios(intersect_scenarios)
 def test_intersect(lhs, rhs, mask, expected):
     if mask is None:
@@ -157,6 +176,18 @@ def test_intersect_keep_both(lhs, rhs, mask, expected):
     assert np.all(rhs_idx == expected_rhs)
 
 
+@w_scenarios(intersect_scenarios)
+def test_intersect_keep_both_strided(lhs, rhs, mask, expected):
+    if mask is None:
+        mask = np.uint64(0xFFFFFFFFFFFFFFFF)
+    lhs = lhs[::2]
+    rhs = rhs[::2]
+    expected = np.intersect1d(lhs, rhs)
+    lhs_idx, rhs_idx = intersect(lhs, rhs, mask=mask, drop_duplicates=False)
+    result = lhs[lhs_idx] & mask
+    assert np.all(result == expected)
+
+
 @pytest.mark.parametrize("seed", [0, 1, 2, 3, 4])
 def test_same_as_numpy(seed):
     np.random.seed(seed)
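The new *_strided tests feed non-contiguous views into intersect: slicing a uint64 array with [::2] keeps the same underlying buffer but doubles the stride, which is what the Strides debug print in snp_ops.pyx inspects. A quick NumPy illustration:

    import numpy as np

    arr = np.arange(8, dtype=np.uint64)
    view = arr[::2]                   # every other element, no copy
    print(arr.strides, view.strides)  # (8,) (16,) -- 16 bytes between elements
    print(view)                       # [0 2 4 6]

A raw pointer walk like &lhs[0] with one-element increments assumes 8-byte strides, so on a [::2] view it would also visit the skipped elements -- presumably the behavior these tests pin down.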
