Skip to content

Commit

Permalink
Add stride to intercept functions
Browse files Browse the repository at this point in the history
  • Loading branch information
softwaredoug committed May 11, 2024
1 parent af5c587 commit 39d1a30
Showing 1 changed file with 37 additions and 51 deletions.
88 changes: 37 additions & 51 deletions searcharray/roaringish/snp_ops.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -201,13 +201,13 @@ def galloping_search(np.ndarray[DTYPE_t, ndim=1] array,

cdef DTYPE_t _gallop_intersect_drop(DTYPE_t* lhs,
DTYPE_t* rhs,
DTYPE_t lhs_stride,
DTYPE_t rhs_stride,
DTYPE_t lhs_len,
DTYPE_t rhs_len,
DTYPE_t* lhs_out,
DTYPE_t* rhs_out,
DTYPE_t mask=ALL_BITS,
DTYPE_t lhs_base=0,
DTYPE_t rhs_base=0) nogil:
DTYPE_t mask=ALL_BITS) nogil:
"""Two pointer approach to find the intersection of two sorted arrays."""
cdef DTYPE_t* lhs_ptr = &lhs[0]
cdef DTYPE_t* rhs_ptr = &rhs[0]
Expand All @@ -222,45 +222,47 @@ cdef DTYPE_t _gallop_intersect_drop(DTYPE_t* lhs,

# Gallop past the current element
while lhs_ptr < end_lhs_ptr and (lhs_ptr[0] & mask) < (rhs_ptr[0] & mask):
lhs_ptr+=delta
lhs_ptr+= (delta * lhs_stride)
delta *= 2
lhs_ptr -= (delta // 2)
lhs_ptr -= ((delta // 2) * lhs_stride)
delta = 1
while rhs_ptr < end_rhs_ptr and (rhs_ptr[0] & mask) < (lhs_ptr[0] & mask):
rhs_ptr+=delta
rhs_ptr+= (delta * rhs_stride)
delta *= 2
rhs_ptr -= (delta // 2)
rhs_ptr -= ((delta // 2) * rhs_stride)
delta = 1

# Now that we've reset, we just do the naive 2-ptr check
# Then next loop we pickup on exponential search
if (lhs_ptr[0] & mask) < (rhs_ptr[0] & mask):
lhs_ptr = lhs_ptr + 1
lhs_ptr = lhs_ptr + lhs_stride
elif (rhs_ptr[0] & mask) < (lhs_ptr[0] & mask):
rhs_ptr = rhs_ptr + 1
rhs_ptr = rhs_ptr + rhs_stride
else:
# If here values equal, collect
if (last & mask) != (lhs_ptr[0] & mask):
lhs_result_ptr[0] = (lhs_ptr - &lhs[0] + lhs_base)
rhs_result_ptr[0] = (rhs_ptr - &rhs[0] + rhs_base)
lhs_result_ptr[0] = (lhs_ptr - &lhs[0]) / lhs_stride
rhs_result_ptr[0] = (rhs_ptr - &rhs[0]) / rhs_stride
last = lhs_ptr[0]
lhs_result_ptr += 1
rhs_result_ptr += 1
lhs_ptr += 1
rhs_ptr += 1
lhs_ptr += lhs_stride
rhs_ptr += rhs_stride

return lhs_result_ptr - &lhs_out[0]


cdef void _gallop_intersect_keep(DTYPE_t* lhs,
DTYPE_t* rhs,
DTYPE_t lhs_stride,
DTYPE_t rhs_stride,
DTYPE_t lhs_len,
DTYPE_t rhs_len,
DTYPE_t* lhs_out,
DTYPE_t* rhs_out,
DTYPE_t* lhs_out_len,
DTYPE_t* rhs_out_len,
DTYPE_t mask=ALL_BITS) noexcept:
DTYPE_t mask=ALL_BITS) noexcept nogil:
"""Two pointer approach to find the intersection of two sorted arrays."""
cdef DTYPE_t* lhs_ptr = &lhs[0]
cdef DTYPE_t* rhs_ptr = &rhs[0]
Expand All @@ -273,47 +275,37 @@ cdef void _gallop_intersect_keep(DTYPE_t* lhs,
cdef DTYPE_t* lhs_result_ptr = &lhs_out[0]
cdef DTYPE_t* rhs_result_ptr = &rhs_out[0]

print(f"Input {lhs[0]}|{rhs[0]}")
print(f" {lhs[1]}|{rhs[1]}")
print(f" {lhs[2]}|{rhs[2]}")
print(f" {lhs[3]}|{rhs[3]}")

while lhs_ptr < end_lhs_ptr and rhs_ptr < end_rhs_ptr:
while lhs_ptr < end_lhs_ptr or rhs_ptr < end_rhs_ptr:
# Gallop past the current element
while lhs_ptr < end_lhs_ptr and (lhs_ptr[0] & mask) < (rhs_ptr[0] & mask):
lhs_ptr += delta
lhs_ptr += (delta * lhs_stride)
delta <<= 1
lhs_ptr -= (delta >> 1)
lhs_ptr -= ((delta >> 1) * lhs_stride)
delta = 1
while rhs_ptr < end_rhs_ptr and (rhs_ptr[0] & mask) < (lhs_ptr[0] & mask):
rhs_ptr += delta
rhs_ptr += (delta * rhs_stride)
delta <<= 1
rhs_ptr -= (delta >> 1)
rhs_ptr -= ((delta >> 1) * rhs_stride)
delta = 1
print(f" At: {lhs_ptr - &lhs[0]}|{rhs_ptr - &rhs[0]}")
print(f"Values: {lhs_ptr[0]} vs {rhs_ptr[0]}")

# Now that we've reset, we just do the naive 2-ptr check
# Then next loop we pickup on exponential search
if (lhs_ptr[0] & mask) < (rhs_ptr[0] & mask):
lhs_ptr += 1
lhs_ptr += lhs_stride
elif (rhs_ptr[0] & mask) < (lhs_ptr[0] & mask):
rhs_ptr += 1
rhs_ptr += rhs_stride
else:
target = lhs_ptr[0] & mask
print("---")
print(f"Collecting target: {target}")
# Store all LHS indices equal to RHS
while (lhs_ptr[0] & mask) == target and lhs_ptr < end_lhs_ptr:
lhs_result_ptr[0] = lhs_ptr - &lhs[0]; lhs_result_ptr += 1
print(f"lhs_ptr: {lhs_ptr[0]} target: {target} -- saving idx: {lhs_ptr - &lhs[0]}")
lhs_ptr += 1
lhs_result_ptr[0] = (lhs_ptr - &lhs[0]) / lhs_stride
lhs_result_ptr += 1
lhs_ptr += lhs_stride
# Store all RHS equal to LHS
while (rhs_ptr[0] & mask) == target and rhs_ptr < end_rhs_ptr:
rhs_result_ptr[0] = rhs_ptr - &rhs[0]; rhs_result_ptr += 1
print(f"rhs_ptr: {rhs_ptr[0]} target: {target} -- saving idx: {rhs_ptr - &rhs[0]}")
rhs_ptr += 1
print("---")
rhs_result_ptr[0] = (rhs_ptr - &rhs[0]) / rhs_stride
rhs_result_ptr += 1
rhs_ptr += rhs_stride

# If delta
# Either we read past the array, or
Expand Down Expand Up @@ -413,33 +405,27 @@ def intersect(np.ndarray[DTYPE_t, ndim=1] lhs,
lhs_out = np.empty(min(lhs.shape[0], rhs.shape[0]), dtype=np.uint64)
rhs_out = np.empty(min(lhs.shape[0], rhs.shape[0]), dtype=np.uint64)
amt_written = _gallop_intersect_drop(&lhs[0], &rhs[0],
lhs.shape[0], rhs.shape[0],
lhs.strides[0] / sizeof(DTYPE_t),
rhs.strides[0] / sizeof(DTYPE_t),
lhs.shape[0] * lhs.strides[0] / sizeof(DTYPE_t),
rhs.shape[0] * rhs.strides[0] / sizeof(DTYPE_t),
&lhs_out[0], &rhs_out[0],
mask)
return np.asarray(lhs_out)[:amt_written], np.asarray(rhs_out)[:amt_written]

else:
lhs_out = np.empty(max(lhs.shape[0], rhs.shape[0]), dtype=np.uint64)
rhs_out = np.empty(max(lhs.shape[0], rhs.shape[0]), dtype=np.uint64)
print(f"INPUT... Intersect KEEP {mask:0x} {lhs_out_len}|{rhs_out_len}")
print([lhs[i] for i in range(lhs.shape[0])])
print([rhs[i] for i in range(rhs.shape[0])])
print(f"Input {lhs[0]}|{rhs[0]}")
print(f" {lhs[1]}|{rhs[1]}")
print(f" {lhs[2]}|{rhs[2]}")
print(f" {lhs[3]}|{rhs[3]}")
print(f"Strides {lhs.strides[0]}|{rhs.strides[0]}")
print(" ----")

_gallop_intersect_keep(&lhs[0], &rhs[0],
lhs.shape[0], rhs.shape[0],
lhs.strides[0] / sizeof(DTYPE_t),
rhs.strides[0] / sizeof(DTYPE_t),
lhs.shape[0] * lhs.strides[0] / sizeof(DTYPE_t),
rhs.shape[0] * rhs.strides[0] / sizeof(DTYPE_t),
&lhs_out[0], &rhs_out[0],
&lhs_out_len, &rhs_out_len,
mask)
print(f"OUTPUT... Intersect KEEP {mask:0x} {lhs_out_len}|{rhs_out_len}")
lhs_out, rhs_out = np.asarray(lhs_out)[:lhs_out_len], np.asarray(rhs_out)[:rhs_out_len]
print([lhs[lhs_out[i]] for i in range(lhs_out_len)])
print([rhs[rhs_out[i]] for i in range(rhs_out_len)])
return lhs_out, rhs_out


Expand Down

0 comments on commit 39d1a30

Please sign in to comment.