From 39d1a301d6aad402b0aa93aefc69b300aacddfae Mon Sep 17 00:00:00 2001 From: Doug Turnbull Date: Fri, 10 May 2024 21:24:36 -0400 Subject: [PATCH] Add stride to intercept functions --- searcharray/roaringish/snp_ops.pyx | 88 +++++++++++++----------------- 1 file changed, 37 insertions(+), 51 deletions(-) diff --git a/searcharray/roaringish/snp_ops.pyx b/searcharray/roaringish/snp_ops.pyx index c1b0eb7..ded8f5a 100644 --- a/searcharray/roaringish/snp_ops.pyx +++ b/searcharray/roaringish/snp_ops.pyx @@ -201,13 +201,13 @@ def galloping_search(np.ndarray[DTYPE_t, ndim=1] array, cdef DTYPE_t _gallop_intersect_drop(DTYPE_t* lhs, DTYPE_t* rhs, + DTYPE_t lhs_stride, + DTYPE_t rhs_stride, DTYPE_t lhs_len, DTYPE_t rhs_len, DTYPE_t* lhs_out, DTYPE_t* rhs_out, - DTYPE_t mask=ALL_BITS, - DTYPE_t lhs_base=0, - DTYPE_t rhs_base=0) nogil: + DTYPE_t mask=ALL_BITS) nogil: """Two pointer approach to find the intersection of two sorted arrays.""" cdef DTYPE_t* lhs_ptr = &lhs[0] cdef DTYPE_t* rhs_ptr = &rhs[0] @@ -222,45 +222,47 @@ cdef DTYPE_t _gallop_intersect_drop(DTYPE_t* lhs, # Gallop past the current element while lhs_ptr < end_lhs_ptr and (lhs_ptr[0] & mask) < (rhs_ptr[0] & mask): - lhs_ptr+=delta + lhs_ptr+= (delta * lhs_stride) delta *= 2 - lhs_ptr -= (delta // 2) + lhs_ptr -= ((delta // 2) * lhs_stride) delta = 1 while rhs_ptr < end_rhs_ptr and (rhs_ptr[0] & mask) < (lhs_ptr[0] & mask): - rhs_ptr+=delta + rhs_ptr+= (delta * rhs_stride) delta *= 2 - rhs_ptr -= (delta // 2) + rhs_ptr -= ((delta // 2) * rhs_stride) delta = 1 # Now that we've reset, we just do the naive 2-ptr check # Then next loop we pickup on exponential search if (lhs_ptr[0] & mask) < (rhs_ptr[0] & mask): - lhs_ptr = lhs_ptr + 1 + lhs_ptr = lhs_ptr + lhs_stride elif (rhs_ptr[0] & mask) < (lhs_ptr[0] & mask): - rhs_ptr = rhs_ptr + 1 + rhs_ptr = rhs_ptr + rhs_stride else: # If here values equal, collect if (last & mask) != (lhs_ptr[0] & mask): - lhs_result_ptr[0] = (lhs_ptr - &lhs[0] + lhs_base) - rhs_result_ptr[0] = (rhs_ptr - &rhs[0] + rhs_base) + lhs_result_ptr[0] = (lhs_ptr - &lhs[0]) / lhs_stride + rhs_result_ptr[0] = (rhs_ptr - &rhs[0]) / rhs_stride last = lhs_ptr[0] lhs_result_ptr += 1 rhs_result_ptr += 1 - lhs_ptr += 1 - rhs_ptr += 1 + lhs_ptr += lhs_stride + rhs_ptr += rhs_stride return lhs_result_ptr - &lhs_out[0] cdef void _gallop_intersect_keep(DTYPE_t* lhs, DTYPE_t* rhs, + DTYPE_t lhs_stride, + DTYPE_t rhs_stride, DTYPE_t lhs_len, DTYPE_t rhs_len, DTYPE_t* lhs_out, DTYPE_t* rhs_out, DTYPE_t* lhs_out_len, DTYPE_t* rhs_out_len, - DTYPE_t mask=ALL_BITS) noexcept: + DTYPE_t mask=ALL_BITS) noexcept nogil: """Two pointer approach to find the intersection of two sorted arrays.""" cdef DTYPE_t* lhs_ptr = &lhs[0] cdef DTYPE_t* rhs_ptr = &rhs[0] @@ -273,47 +275,37 @@ cdef void _gallop_intersect_keep(DTYPE_t* lhs, cdef DTYPE_t* lhs_result_ptr = &lhs_out[0] cdef DTYPE_t* rhs_result_ptr = &rhs_out[0] - print(f"Input {lhs[0]}|{rhs[0]}") - print(f" {lhs[1]}|{rhs[1]}") - print(f" {lhs[2]}|{rhs[2]}") - print(f" {lhs[3]}|{rhs[3]}") - - while lhs_ptr < end_lhs_ptr and rhs_ptr < end_rhs_ptr: + while lhs_ptr < end_lhs_ptr or rhs_ptr < end_rhs_ptr: # Gallop past the current element while lhs_ptr < end_lhs_ptr and (lhs_ptr[0] & mask) < (rhs_ptr[0] & mask): - lhs_ptr += delta + lhs_ptr += (delta * lhs_stride) delta <<= 1 - lhs_ptr -= (delta >> 1) + lhs_ptr -= ((delta >> 1) * lhs_stride) delta = 1 while rhs_ptr < end_rhs_ptr and (rhs_ptr[0] & mask) < (lhs_ptr[0] & mask): - rhs_ptr += delta + rhs_ptr += (delta * rhs_stride) delta <<= 1 - rhs_ptr -= (delta >> 1) + rhs_ptr -= ((delta >> 1) * rhs_stride) delta = 1 - print(f" At: {lhs_ptr - &lhs[0]}|{rhs_ptr - &rhs[0]}") - print(f"Values: {lhs_ptr[0]} vs {rhs_ptr[0]}") # Now that we've reset, we just do the naive 2-ptr check # Then next loop we pickup on exponential search if (lhs_ptr[0] & mask) < (rhs_ptr[0] & mask): - lhs_ptr += 1 + lhs_ptr += lhs_stride elif (rhs_ptr[0] & mask) < (lhs_ptr[0] & mask): - rhs_ptr += 1 + rhs_ptr += rhs_stride else: target = lhs_ptr[0] & mask - print("---") - print(f"Collecting target: {target}") # Store all LHS indices equal to RHS while (lhs_ptr[0] & mask) == target and lhs_ptr < end_lhs_ptr: - lhs_result_ptr[0] = lhs_ptr - &lhs[0]; lhs_result_ptr += 1 - print(f"lhs_ptr: {lhs_ptr[0]} target: {target} -- saving idx: {lhs_ptr - &lhs[0]}") - lhs_ptr += 1 + lhs_result_ptr[0] = (lhs_ptr - &lhs[0]) / lhs_stride + lhs_result_ptr += 1 + lhs_ptr += lhs_stride # Store all RHS equal to LHS while (rhs_ptr[0] & mask) == target and rhs_ptr < end_rhs_ptr: - rhs_result_ptr[0] = rhs_ptr - &rhs[0]; rhs_result_ptr += 1 - print(f"rhs_ptr: {rhs_ptr[0]} target: {target} -- saving idx: {rhs_ptr - &rhs[0]}") - rhs_ptr += 1 - print("---") + rhs_result_ptr[0] = (rhs_ptr - &rhs[0]) / rhs_stride + rhs_result_ptr += 1 + rhs_ptr += rhs_stride # If delta # Either we read past the array, or @@ -413,7 +405,10 @@ def intersect(np.ndarray[DTYPE_t, ndim=1] lhs, lhs_out = np.empty(min(lhs.shape[0], rhs.shape[0]), dtype=np.uint64) rhs_out = np.empty(min(lhs.shape[0], rhs.shape[0]), dtype=np.uint64) amt_written = _gallop_intersect_drop(&lhs[0], &rhs[0], - lhs.shape[0], rhs.shape[0], + lhs.strides[0] / sizeof(DTYPE_t), + rhs.strides[0] / sizeof(DTYPE_t), + lhs.shape[0] * lhs.strides[0] / sizeof(DTYPE_t), + rhs.shape[0] * rhs.strides[0] / sizeof(DTYPE_t), &lhs_out[0], &rhs_out[0], mask) return np.asarray(lhs_out)[:amt_written], np.asarray(rhs_out)[:amt_written] @@ -421,25 +416,16 @@ def intersect(np.ndarray[DTYPE_t, ndim=1] lhs, else: lhs_out = np.empty(max(lhs.shape[0], rhs.shape[0]), dtype=np.uint64) rhs_out = np.empty(max(lhs.shape[0], rhs.shape[0]), dtype=np.uint64) - print(f"INPUT... Intersect KEEP {mask:0x} {lhs_out_len}|{rhs_out_len}") - print([lhs[i] for i in range(lhs.shape[0])]) - print([rhs[i] for i in range(rhs.shape[0])]) - print(f"Input {lhs[0]}|{rhs[0]}") - print(f" {lhs[1]}|{rhs[1]}") - print(f" {lhs[2]}|{rhs[2]}") - print(f" {lhs[3]}|{rhs[3]}") - print(f"Strides {lhs.strides[0]}|{rhs.strides[0]}") - print(" ----") _gallop_intersect_keep(&lhs[0], &rhs[0], - lhs.shape[0], rhs.shape[0], + lhs.strides[0] / sizeof(DTYPE_t), + rhs.strides[0] / sizeof(DTYPE_t), + lhs.shape[0] * lhs.strides[0] / sizeof(DTYPE_t), + rhs.shape[0] * rhs.strides[0] / sizeof(DTYPE_t), &lhs_out[0], &rhs_out[0], &lhs_out_len, &rhs_out_len, mask) - print(f"OUTPUT... Intersect KEEP {mask:0x} {lhs_out_len}|{rhs_out_len}") lhs_out, rhs_out = np.asarray(lhs_out)[:lhs_out_len], np.asarray(rhs_out)[:rhs_out_len] - print([lhs[lhs_out[i]] for i in range(lhs_out_len)]) - print([rhs[rhs_out[i]] for i in range(rhs_out_len)]) return lhs_out, rhs_out