Skip to content

Commit

Permalink
Fix adjacent merging when msb wraps
Browse files Browse the repository at this point in the history
  • Loading branch information
softwaredoug committed Dec 25, 2023
1 parent 0ef7cf0 commit de19f99
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 11 deletions.
29 changes: 21 additions & 8 deletions searcharray/phrase/middle_out.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,32 @@ def inner_bigram_freqs(lhs: np.ndarray, rhs: np.ndarray,
rhs_next: the next rhs array to continue matching
"""
lhs, rhs = encoder.intersect(lhs, rhs)
lhs_doc_ids = encoder.keys(lhs)
if len(lhs) != len(rhs):
lhs_int, rhs_int = encoder.intersect(lhs, rhs)
lhs_doc_ids = encoder.keys(lhs_int)
if len(lhs_int) != len(rhs_int):
raise ValueError("Encoding error, MSBs apparently are duplicated among your encoded posn arrays.")
rhs_next = (rhs & encoder.key_mask)
if len(lhs_int) == 0:
return phrase_freqs, rhs_int
same_term = (len(lhs_int) == len(rhs_int) and lhs_int[0] == rhs_int[0])
if same_term:
# Find adjacent matches
rhs_shift = rhs_int << 1
overlap = lhs_int & rhs_shift
overlap = encoder.payload_lsb(overlap)
adjacents = bit_count64(overlap).astype(np.int64)
adjacents -= -np.floor_divide(adjacents, -2) # ceiling divide
phrase_freqs[lhs_doc_ids] += adjacents
return phrase_freqs, rhs_int

rhs_next = (rhs_int & encoder.key_mask)
# With popcount soon to be in numpy, this could potentially
# be simply a left shift of the RHS LSB poppcount, and and a popcount
# to count the overlaps
for bit in range(0, encoder.payload_lsb_bits - 1):
lhs_mask = 1 << bit
rhs_mask = 1 << (bit + 1)
lhs_set = (lhs & lhs_mask) != 0
rhs_set = (rhs & rhs_mask) != 0
lhs_set = (lhs_int & lhs_mask) != 0
rhs_set = (rhs_int & rhs_mask) != 0

matches = lhs_set & rhs_set
rhs_next[matches] |= rhs_mask
Expand All @@ -75,7 +88,7 @@ def adjacent_bigram_freqs(lhs: np.ndarray, rhs: np.ndarray,
lhs_doc_ids = encoder.keys(lhs_int)
# lhs lsb set and rhs lsb's most significant bit set
upper_bit = 1 << (encoder.payload_lsb_bits - 1)
matches = ((lhs_int & upper_bit) & ((rhs_int & 1) != 0))
matches = ((lhs_int & upper_bit) != 0) & ((rhs_int & 1) != 0)
phrase_freqs[lhs_doc_ids[matches]] += 1
return phrase_freqs

Expand Down Expand Up @@ -133,7 +146,7 @@ def add_posns(self, doc_id: int, term_id: int, posns):
def ensure_capacity(self, doc_id):
self.max_doc_id = max(self.max_doc_id, doc_id)

def build(self, check=True):
def build(self, check=False):
encoded_term_posns = {}
for term_id, posns in self.term_posns.items():
if len(posns) == 0:
Expand Down
2 changes: 1 addition & 1 deletion searcharray/postings.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ def and_query(self, tokens: List[str]) -> np.ndarray:
return mask

def phrase_freq(self, tokens: List[str], slop=1) -> np.ndarray:
if slop == 1:
if slop == 1 and len(tokens) == len(set(tokens)):
phrase_freqs = np.zeros(len(self))
try:
term_ids = [self.term_dict.get_term_id(token) for token in tokens]
Expand Down
12 changes: 10 additions & 2 deletions searcharray/utils/roaringish.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,18 @@ def intersect(self, lhs: np.ndarray, rhs: np.ndarray, rshift=0) -> Tuple[np.ndar
rhs : np.ndarray of uint64 (encoded) values
rshift : int - right shift rhs by this many bits before intersecting (ie to find adjacent)
"""
rhs_int = rhs
if rshift >= 0:
rhs_shifted = (rhs_int >> self.payload_lsb_bits) + np.int64(rshift)
else:
rhs_int = rhs[self.payload_msb(rhs) >= np.abs(rshift)]
rhs_shifted = (rhs_int >> self.payload_lsb_bits) + np.int64(rshift).astype(np.uint64)

assert np.all(np.diff(rhs_shifted) >= 0), "not sorted"
_, (lhs_idx, rhs_idx) = snp.intersect(lhs >> self.payload_lsb_bits,
((rhs >> self.payload_lsb_bits) + np.int64(rshift)).astype(np.int64).astype(np.uint64),
rhs_shifted.astype(np.int64).astype(np.uint64),
indices=True)
return lhs[lhs_idx], rhs[rhs_idx]
return lhs[lhs_idx], rhs_int[rhs_idx]

def slice(self, encoded: np.ndarray, keys: np.ndarray) -> np.ndarray:
"""Get list of encoded that have values in keys."""
Expand Down

0 comments on commit de19f99

Please sign in to comment.