diff --git a/searcharray/roaringish/spans.pyx b/searcharray/roaringish/spans.pyx
index 42e1f13..871b034 100644
--- a/searcharray/roaringish/spans.pyx
+++ b/searcharray/roaringish/spans.pyx
@@ -67,6 +67,44 @@ cdef DTYPE_t _consume_lsb(DTYPE_t* term):
 cdef DTYPE_t _posn_mask(DTYPE_t set_idx, DTYPE_t payload_base):
     return 1 << ((set_idx + payload_base) % 64)
 
+
+cdef DTYPE_t _num_terms(ActiveSpans* spans, DTYPE_t span_idx):
+    """Count the bits set in the term mask of the span at span_idx."""
+    return __builtin_popcountll(spans.terms[span_idx])
+
+
+cdef DTYPE_t _num_posns(ActiveSpans* spans, DTYPE_t span_idx):
+    """Count the bits set in the position mask of the span at span_idx."""
+    return __builtin_popcountll(spans.posns[span_idx])
+
+
+cdef bint _do_spans_overlap(ActiveSpans* spans, DTYPE_t span_idx_lhs, DTYPE_t span_idx_rhs):
+    """Return True when the inclusive [beg, end] ranges of the two spans intersect."""
+    return (spans.beg[span_idx_lhs] <= spans.end[span_idx_rhs]) and (spans.end[span_idx_lhs] >= spans.beg[span_idx_rhs])
+
+
+cdef bint _is_span_complete(ActiveSpans* spans, DTYPE_t span_idx, DTYPE_t num_terms):
+    """Return True when the span has num_terms term bits and num_terms position bits set."""
+    cdef DTYPE_t num_terms_visited = _num_terms(spans, span_idx)
+    cdef DTYPE_t num_posns_visited = _num_posns(spans, span_idx)
+    # Both conditions must hold: the matching loop in _span_freqs skips spans
+    # whose position bits are full while term bits are still missing, so such
+    # a span must not be reported as complete.
+    return (num_terms_visited == num_terms) and (num_posns_visited == num_terms)
+
+
+cdef bint _overlaps_earlier_complete_span(ActiveSpans* spans, DTYPE_t span_idx, DTYPE_t num_terms):
+    """Return True when a complete span with a lower index overlaps the span at span_idx."""
+    # NOTE(review): a span is also rejected when the earlier span it overlaps was
+    # itself rejected; confirm this is the intended tie-break.
+    cdef DTYPE_t other_span_idx
+    for other_span_idx in range(span_idx):
+        if _is_span_complete(spans, other_span_idx, num_terms) and _do_spans_overlap(spans, span_idx, other_span_idx):
+            return True
+    return False
+
+
+
 cdef _span_freqs(DTYPE_t[:] posns,  # Flattened all terms in one array
                  DTYPE_t[:] lengths,
                  double[:] phrase_freqs,
@@ -106,8 +144,8 @@ cdef _span_freqs(DTYPE_t[:] posns,  # Flattened all terms in one array
                 term = posns[curr_idx[term_ord]] & payload_mask
                 curr_term_mask = 0x1 << term_ord
 
+                # Consume every position into every possible span
                 while term != 0:
-                    # Consume into span
                     set_idx = _consume_lsb(&term)
                     posn_mask = _posn_mask(set_idx, payload_base)
 
@@ -119,18 +157,17 @@ cdef _span_freqs(DTYPE_t[:] posns,  # Flattened all terms in one array
                     # Remove spans that are too long
                     for span_idx in range(spans.cursor):
                         # Continue active spans
-                        num_terms_visited = __builtin_popcountll(spans.terms[span_idx])
-                        num_posns_visited = __builtin_popcountll(spans.posns[span_idx])
+                        num_terms_visited = _num_terms(&spans, span_idx)
+                        num_posns_visited = _num_posns(&spans, span_idx)
                         if num_terms_visited < num_terms and num_posns_visited == num_terms:
                             continue
                         spans.terms[span_idx] |= curr_term_mask
-                        num_terms_visited_now = __builtin_popcountll(spans.terms[span_idx])
+                        num_terms_visited_now = _num_terms(&spans, span_idx)
                         if num_terms_visited_now > num_terms_visited:
                             # Add position for new unique term
-                            num_unique_posns = __builtin_popcountll(spans.posns[span_idx])
                             spans.posns[span_idx] |= posn_mask
-                            new_unique_posns = __builtin_popcountll(spans.posns[span_idx])
-                            if num_unique_posns == new_unique_posns:
+                            new_unique_posns = _num_posns(&spans, span_idx)
+                            if num_posns_visited == new_unique_posns:
                                 # Clear curr_term_mask and cancel this position, we've seen it before
                                 spans.terms[span_idx] &= ~curr_term_mask
                                 continue
@@ -156,10 +193,11 @@ cdef _span_freqs(DTYPE_t[:] posns,  # Flattened all terms in one array
 
         # Count phrase freqs
         for span_idx in range(spans.cursor):
-            num_terms_visited = __builtin_popcountll(spans.terms[span_idx])
-            num_posns_visited = __builtin_popcountll(spans.posns[span_idx])
-            if num_terms_visited < num_terms or num_posns_visited < num_terms:
+            if not _is_span_complete(&spans, span_idx, num_terms):
                 continue
+            # Skip a span that overlaps an earlier complete span so overlapping matches are not counted twice
+            if _overlaps_earlier_complete_span(&spans, span_idx, num_terms):
+                continue
             assert last_key < phrase_freqs.shape[0]
             phrase_freqs[last_key] += 1
 