From de19f99a2878fad63738c496937e111b309c28a1 Mon Sep 17 00:00:00 2001 From: Doug Turnbull Date: Sun, 24 Dec 2023 23:09:35 -0500 Subject: [PATCH] Fix adjacent merging when msb wraps --- searcharray/phrase/middle_out.py | 29 +++++++++++++++++++++-------- searcharray/postings.py | 2 +- searcharray/utils/roaringish.py | 12 ++++++++++-- 3 files changed, 32 insertions(+), 11 deletions(-) diff --git a/searcharray/phrase/middle_out.py b/searcharray/phrase/middle_out.py index 9a0d546..9371308 100644 --- a/searcharray/phrase/middle_out.py +++ b/searcharray/phrase/middle_out.py @@ -41,19 +41,32 @@ def inner_bigram_freqs(lhs: np.ndarray, rhs: np.ndarray, rhs_next: the next rhs array to continue matching """ - lhs, rhs = encoder.intersect(lhs, rhs) - lhs_doc_ids = encoder.keys(lhs) - if len(lhs) != len(rhs): + lhs_int, rhs_int = encoder.intersect(lhs, rhs) + lhs_doc_ids = encoder.keys(lhs_int) + if len(lhs_int) != len(rhs_int): raise ValueError("Encoding error, MSBs apparently are duplicated among your encoded posn arrays.") - rhs_next = (rhs & encoder.key_mask) + if len(lhs_int) == 0: + return phrase_freqs, rhs_int + same_term = (len(lhs_int) == len(rhs_int) and lhs_int[0] == rhs_int[0]) + if same_term: + # Find adjacent matches + rhs_shift = rhs_int << 1 + overlap = lhs_int & rhs_shift + overlap = encoder.payload_lsb(overlap) + adjacents = bit_count64(overlap).astype(np.int64) + adjacents -= -np.floor_divide(adjacents, -2) # ceiling divide + phrase_freqs[lhs_doc_ids] += adjacents + return phrase_freqs, rhs_int + + rhs_next = (rhs_int & encoder.key_mask) # With popcount soon to be in numpy, this could potentially # be simply a left shift of the RHS LSB poppcount, and and a popcount # to count the overlaps for bit in range(0, encoder.payload_lsb_bits - 1): lhs_mask = 1 << bit rhs_mask = 1 << (bit + 1) - lhs_set = (lhs & lhs_mask) != 0 - rhs_set = (rhs & rhs_mask) != 0 + lhs_set = (lhs_int & lhs_mask) != 0 + rhs_set = (rhs_int & rhs_mask) != 0 matches = lhs_set & rhs_set rhs_next[matches] |= rhs_mask @@ -75,7 +88,7 @@ def adjacent_bigram_freqs(lhs: np.ndarray, rhs: np.ndarray, lhs_doc_ids = encoder.keys(lhs_int) # lhs lsb set and rhs lsb's most significant bit set upper_bit = 1 << (encoder.payload_lsb_bits - 1) - matches = ((lhs_int & upper_bit) & ((rhs_int & 1) != 0)) + matches = ((lhs_int & upper_bit) != 0) & ((rhs_int & 1) != 0) phrase_freqs[lhs_doc_ids[matches]] += 1 return phrase_freqs @@ -133,7 +146,7 @@ def add_posns(self, doc_id: int, term_id: int, posns): def ensure_capacity(self, doc_id): self.max_doc_id = max(self.max_doc_id, doc_id) - def build(self, check=True): + def build(self, check=False): encoded_term_posns = {} for term_id, posns in self.term_posns.items(): if len(posns) == 0: diff --git a/searcharray/postings.py b/searcharray/postings.py index 2194412..1d77c76 100644 --- a/searcharray/postings.py +++ b/searcharray/postings.py @@ -676,7 +676,7 @@ def and_query(self, tokens: List[str]) -> np.ndarray: return mask def phrase_freq(self, tokens: List[str], slop=1) -> np.ndarray: - if slop == 1: + if slop == 1 and len(tokens) == len(set(tokens)): phrase_freqs = np.zeros(len(self)) try: term_ids = [self.term_dict.get_term_id(token) for token in tokens] diff --git a/searcharray/utils/roaringish.py b/searcharray/utils/roaringish.py index 9d87d3c..d18a9c7 100644 --- a/searcharray/utils/roaringish.py +++ b/searcharray/utils/roaringish.py @@ -136,10 +136,18 @@ def intersect(self, lhs: np.ndarray, rhs: np.ndarray, rshift=0) -> Tuple[np.ndar rhs : np.ndarray of uint64 (encoded) values rshift : int - right shift rhs by this many bits before intersecting (ie to find adjacent) """ + rhs_int = rhs + if rshift >= 0: + rhs_shifted = (rhs_int >> self.payload_lsb_bits) + np.int64(rshift) + else: + rhs_int = rhs[self.payload_msb(rhs) >= np.abs(rshift)] + rhs_shifted = (rhs_int >> self.payload_lsb_bits) + np.int64(rshift).astype(np.uint64) + + assert np.all(np.diff(rhs_shifted) >= 0), "not sorted" _, (lhs_idx, rhs_idx) = snp.intersect(lhs >> self.payload_lsb_bits, - ((rhs >> self.payload_lsb_bits) + np.int64(rshift)).astype(np.int64).astype(np.uint64), + rhs_shifted.astype(np.int64).astype(np.uint64), indices=True) - return lhs[lhs_idx], rhs[rhs_idx] + return lhs[lhs_idx], rhs_int[rhs_idx] def slice(self, encoded: np.ndarray, keys: np.ndarray) -> np.ndarray: """Get list of encoded that have values in keys."""