Skip to content

Commit

Permalink
Add an intersection method that intersects keys first
Browse files Browse the repository at this point in the history
  • Loading branch information
softwaredoug committed Dec 7, 2024
1 parent ad39e54 commit 526dbbc
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 17 deletions.
14 changes: 14 additions & 0 deletions searcharray/roaringish/intersect.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,17 @@ def intersect_with_adjacents(lhs: NDArray[np.uint64],
mask: np.uint64 = ALL_BITS) -> Tuple[np.ndarray, np.ndarray,
np.ndarray, np.ndarray]:
...


def int_w_index(lhs: NDArray[np.uint64],
rhs: NDArray[np.uint64],
lhs_index: NDArray[np.uint64],
rhs_index: NDArray[np.uint64],
index_mask: np.uint64 = ALL_BITS,
mask: np.uint64 = ALL_BITS) -> Tuple[np.ndarray, np.ndarray]:
...


def build_intersect_index(arr: NDArray[np.uint64],
mask: np.uint64) -> NDArray[np.uint64]:
...
140 changes: 126 additions & 14 deletions searcharray/roaringish/intersect.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ cdef DTYPE_t _gallop_int_and_adj_drop(intersect_args_t args,
DTYPE_t* adj_lhs_out,
DTYPE_t* adj_rhs_out,
DTYPE_t* adj_out_len) nogil:
"""Two pointer approach to find the intersection of two sorted arrays."""
"""Galloping approach to find the intersection w/ adjacents of two sorted arrays."""
cdef DTYPE_t* lhs_ptr = &args.lhs[0]
cdef DTYPE_t* rhs_ptr = &args.rhs[0]
cdef DTYPE_t* end_lhs_ptr = &args.lhs[args.lhs_len]
Expand All @@ -238,18 +238,11 @@ cdef DTYPE_t _gallop_int_and_adj_drop(intersect_args_t args,
cdef DTYPE_t* lhs_adj_result_ptr = &adj_lhs_out[0]
cdef DTYPE_t* rhs_adj_result_ptr = &adj_rhs_out[0]

cdef uint64_t gallop_time = 0
cdef uint64_t gallop_start = 0

cdef uint64_t collect_time = 0
cdef uint64_t collect_start = 0

while lhs_ptr < end_lhs_ptr and rhs_ptr < end_rhs_ptr:

# Gallop to adjacent or equal value
# if value_lhs < value_rhs - delta:
# Gallop past the current element
gallop_start = timestamp()
if (lhs_ptr[0] & args.mask) != (rhs_ptr[0] & args.mask):
while lhs_ptr < end_lhs_ptr and ((lhs_ptr[0] & args.mask) + delta) < (rhs_ptr[0] & args.mask):
lhs_ptr += (gallop * args.lhs_stride)
Expand All @@ -263,8 +256,6 @@ cdef DTYPE_t _gallop_int_and_adj_drop(intersect_args_t args,
gallop = 1
# Now lhs is at or before RHS - delta
# RHS is 4, LHS is at most 3
gallop_time += (timestamp() - gallop_start)
collect_start = timestamp()
# Collect adjacent avalues
if ((lhs_ptr[0] & args.mask) + delta) == ((rhs_ptr[0] & args.mask)):
if (last_adj & args.mask) != (lhs_ptr[0] & args.mask):
Expand All @@ -289,10 +280,6 @@ cdef DTYPE_t _gallop_int_and_adj_drop(intersect_args_t args,
lhs_result_ptr += 1
rhs_result_ptr += 1
rhs_ptr += args.rhs_stride
collect_time += (timestamp() - collect_start)

print_elapsed(gallop_time, "Gallop ")
print_elapsed(collect_time, "Collect")

adj_out_len[0] = lhs_adj_result_ptr - &adj_lhs_out[0]
return lhs_result_ptr - &args.lhs_out[0]
Expand Down Expand Up @@ -412,3 +399,128 @@ def intersect_with_adjacents(np.ndarray[DTYPE_t, ndim=1] lhs,
&adj_out_len)
return (np.asarray(lhs_out)[:amt_written], np.asarray(rhs_out)[:amt_written],
np.asarray(adj_lhs_out)[:adj_out_len], np.asarray(adj_rhs_out)[:adj_out_len])



cdef DTYPE_t _int_w_index(intersect_args_t args,
DTYPE_t index_mask,
DTYPE_t* lhs_index,
DTYPE_t lhs_index_len,
DTYPE_t* rhs_index,
DTYPE_t rhs_index_len) nogil:
"""Two pointer intersect the index first THEN intersect the lhs / rhs within those indices."""
cdef DTYPE_t* lhs_ptr = &args.lhs[0]
cdef DTYPE_t* rhs_ptr = &args.rhs[0]
cdef DTYPE_t* lhs_index_ptr = &lhs_index[0]
cdef DTYPE_t* rhs_index_ptr = &rhs_index[0]
cdef DTYPE_t* end_lhs_ptr = &args.lhs[args.lhs_len]
cdef DTYPE_t* end_rhs_ptr = &args.rhs[args.rhs_len]
cdef DTYPE_t* end_lhs_index_ptr = &lhs_index[lhs_index_len]
cdef DTYPE_t* end_rhs_index_ptr = &rhs_index[rhs_index_len]
cdef DTYPE_t* lhs_result_ptr = &args.lhs_out[0]
cdef DTYPE_t* rhs_result_ptr = &args.rhs_out[0]
cdef DTYPE_t lhs_idx = 0
cdef DTYPE_t rhs_idx = 0
cdef DTYPE_t lhs_curr_end = 0
cdef DTYPE_t rhs_curr_end = 0
cdef DTYPE_t index_lsb_mask = ~index_mask

while lhs_index_ptr < end_lhs_index_ptr and rhs_index_ptr < end_rhs_index_ptr:
if (lhs_index_ptr[0] & index_mask) < (rhs_index_ptr[0] & index_mask):
lhs_index_ptr += 1
elif (rhs_index_ptr[0] & index_mask) < (lhs_index_ptr[0] & index_mask):
rhs_index_ptr += 1
else:
# Now two pointer intersect within lhs_index_ptr -> lhs_index_ptr[1]
lhs_idx = lhs_index_ptr[0] & index_lsb_mask
rhs_idx = rhs_index_ptr[0] & index_lsb_mask
lhs_curr_end = (end_lhs_ptr - &args.lhs[0])
rhs_curr_end = (end_rhs_ptr - &args.rhs[0])
if lhs_index_ptr + 1 < end_lhs_index_ptr:
lhs_curr_end = ((lhs_index_ptr + 1)[0] & index_lsb_mask)
if rhs_index_ptr + 1 < end_rhs_index_ptr:
rhs_curr_end = ((rhs_index_ptr + 1)[0] & index_lsb_mask)

# Two pointer intesect between lhs_index_ptr and lhs_index_end w/ rhs_index_ptr and rhs_index_end
while lhs_idx < lhs_curr_end and rhs_idx < rhs_curr_end:
if (args.lhs[lhs_idx] & args.mask) < (args.rhs[rhs_idx] & args.mask):
lhs_idx += 1
elif (args.rhs[rhs_idx] & args.mask) < (args.lhs[lhs_idx] & args.mask):
rhs_idx += 1
else:
lhs_result_ptr[0] = lhs_idx
rhs_result_ptr[0] = rhs_idx
lhs_result_ptr += 1
rhs_result_ptr += 1
lhs_idx += 1
rhs_idx += 1

lhs_index_ptr += 1
rhs_index_ptr += 1

return lhs_result_ptr - &args.lhs_out[0]


cdef DTYPE_t _build_intersect_index(DTYPE_t* arr,
DTYPE_t arr_len,
DTYPE_t mask,
DTYPE_t* idx_out) nogil:
cdef DTYPE_t i = 0
cdef DTYPE_t headerVal = 0xFFFFFFFFFFFFFFFF
cdef DTYPE_t lastHeaderVal = 0
cdef DTYPE_t* currIdxOut = &idx_out[0]
for i in range(arr_len):
headerVal = arr[i] & mask
if headerVal != lastHeaderVal:
currIdxOut[0] = (headerVal | i)
currIdxOut += 1
lastHeaderVal = headerVal
return currIdxOut - &idx_out[0]



def int_w_index(np.ndarray[DTYPE_t, ndim=1] lhs,
np.ndarray[DTYPE_t, ndim=1] rhs,
np.ndarray[DTYPE_t, ndim=1] lhs_index,
np.ndarray[DTYPE_t, ndim=1] rhs_index,
DTYPE_t index_mask=ALL_BITS,
DTYPE_t mask=ALL_BITS):
cdef np.uint64_t[:] lhs_out
cdef np.uint64_t[:] rhs_out
cdef intersect_args_t args
cdef DTYPE_t adj_out_len = 0
cdef DTYPE_t* lhs_index_ptr = &lhs_index[0]
cdef DTYPE_t* rhs_index_ptr = &rhs_index[0]
cdef DTYPE_t lhs_index_len = lhs_index.shape[0]
cdef DTYPE_t rhs_index_len = rhs_index.shape[0]

lhs_out = np.empty(min(lhs.shape[0], rhs.shape[0]), dtype=np.uint64)
rhs_out = np.empty(min(lhs.shape[0], rhs.shape[0]), dtype=np.uint64)

args.mask = mask
args.lhs = &lhs[0]
args.rhs = &rhs[0]
args.lhs_len = lhs.shape[0] * lhs.strides[0] / sizeof(DTYPE_t)
args.rhs_len = rhs.shape[0] * rhs.strides[0] / sizeof(DTYPE_t)
args.lhs_out = &lhs_out[0]
args.rhs_out = &rhs_out[0]

with nogil:
amt_written = _int_w_index(args,
index_mask,
lhs_index_ptr, lhs_index_len,
rhs_index_ptr, rhs_index_len)

return (np.asarray(lhs_out)[:amt_written], np.asarray(rhs_out)[:amt_written])




def build_intersect_index(np.ndarray[DTYPE_t, ndim=1] arr,
DTYPE_t mask=ALL_BITS):
cdef np.uint64_t[:] idx_out = np.empty(arr.shape[0], dtype=np.uint64)
cdef DTYPE_t amt_written = 0
cdef DTYPE_t* arr_ptr = &arr[0]
with nogil:
amt_written = _build_intersect_index(arr_ptr, arr.shape[0], mask, &idx_out[0])
return np.asarray(idx_out[:amt_written])
30 changes: 27 additions & 3 deletions test/test_snp_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from searcharray.roaringish.search import binary_search, galloping_search, count_odds
from searcharray.roaringish.unique import unique
from searcharray.roaringish.merge import merge
from searcharray.roaringish.intersect import intersect, adjacent, intersect_with_adjacents
from searcharray.roaringish.intersect import intersect, adjacent, intersect_with_adjacents, build_intersect_index, int_w_index
from test_utils import w_scenarios
from test_utils import Profiler, profile_enabled

Expand Down Expand Up @@ -263,6 +263,23 @@ def intersect_many():
profiler.run(intersect_many)


@pytest.mark.parametrize("suffix", [128, 185, 24179, 27685, 44358, 45907, 90596])
def test_indexed_intersect(suffix):
print(f"Running with {suffix}")
lhs = np.load(f"fixtures/lhs_{suffix}.npy")
rhs = np.load(f"fixtures/rhs_{suffix}.npy")
mask = np.load(f"fixtures/mask_{suffix}.npy")
# Set 28 bits to 1
key_mask = np.uint64(0xFFFFFFF000000000)
lhs_index = build_intersect_index(lhs, key_mask)
rhs_index = build_intersect_index(rhs, key_mask)

lhs_out_windex, rhs_out_windex = int_w_index(lhs, rhs, lhs_index, rhs_index, key_mask, mask)
lhs_out_int, rhs_out_out = intersect(lhs, rhs, mask)
assert np.all(lhs_out_windex == lhs_out_int)
assert np.all(rhs_out_windex == rhs_out_out)


@pytest.mark.skipif(not profile_enabled, reason="Profiling disabled")
def test_profile_masked_intersect_sparse_sparse(benchmark):
profiler = Profiler(benchmark)
Expand Down Expand Up @@ -329,10 +346,16 @@ def test_profile_masked_saved(suffix, benchmark):
lhs = np.load(f"fixtures/lhs_{suffix}.npy")
rhs = np.load(f"fixtures/rhs_{suffix}.npy")
mask = np.load(f"fixtures/mask_{suffix}.npy")
print(lhs.shape, rhs.shape)
# Set 28 bits to 1
key_mask = np.uint64(0xFFFFFFF000000000)
lhs_index = build_intersect_index(lhs, key_mask)
rhs_index = build_intersect_index(rhs, key_mask)

def with_indexed():
lhs_out, rhs_out = int_w_index(lhs, rhs, lhs_index, rhs_index, key_mask, mask)

def with_snp_ops():
intersect(lhs, rhs, mask)
lhs_out, rhs_out = intersect(lhs, rhs, mask)

def with_snp():
snp.intersect(lhs >> 18, rhs >> 18, indices=True, duplicates=snp.DROP)
Expand All @@ -345,6 +368,7 @@ def intersect_many():
with_snp_ops()
with_snp()
baseline()
with_indexed()

profiler.run(intersect_many)

Expand Down

0 comments on commit 526dbbc

Please sign in to comment.