From 526dbbcd67117f1145acbf80528f85f75686060a Mon Sep 17 00:00:00 2001 From: Doug Turnbull Date: Sat, 7 Dec 2024 16:22:01 -0500 Subject: [PATCH] Add an intersection method that intersects keys first --- searcharray/roaringish/intersect.pyi | 14 +++ searcharray/roaringish/intersect.pyx | 140 ++++++++++++++++++++++++--- test/test_snp_ops.py | 30 +++++- 3 files changed, 167 insertions(+), 17 deletions(-) diff --git a/searcharray/roaringish/intersect.pyi b/searcharray/roaringish/intersect.pyi index 55f623a..eaddf22 100644 --- a/searcharray/roaringish/intersect.pyi +++ b/searcharray/roaringish/intersect.pyi @@ -23,3 +23,17 @@ def intersect_with_adjacents(lhs: NDArray[np.uint64], mask: np.uint64 = ALL_BITS) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ... + + +def int_w_index(lhs: NDArray[np.uint64], + rhs: NDArray[np.uint64], + lhs_index: NDArray[np.uint64], + rhs_index: NDArray[np.uint64], + index_mask: np.uint64 = ALL_BITS, + mask: np.uint64 = ALL_BITS) -> Tuple[np.ndarray, np.ndarray]: + ... + + +def build_intersect_index(arr: NDArray[np.uint64], + mask: np.uint64) -> NDArray[np.uint64]: + ... diff --git a/searcharray/roaringish/intersect.pyx b/searcharray/roaringish/intersect.pyx index dda76aa..9827fd2 100644 --- a/searcharray/roaringish/intersect.pyx +++ b/searcharray/roaringish/intersect.pyx @@ -225,7 +225,7 @@ cdef DTYPE_t _gallop_int_and_adj_drop(intersect_args_t args, DTYPE_t* adj_lhs_out, DTYPE_t* adj_rhs_out, DTYPE_t* adj_out_len) nogil: - """Two pointer approach to find the intersection of two sorted arrays.""" + """Galloping approach to find the intersection w/ adjacents of two sorted arrays.""" cdef DTYPE_t* lhs_ptr = &args.lhs[0] cdef DTYPE_t* rhs_ptr = &args.rhs[0] cdef DTYPE_t* end_lhs_ptr = &args.lhs[args.lhs_len] @@ -238,18 +238,11 @@ cdef DTYPE_t _gallop_int_and_adj_drop(intersect_args_t args, cdef DTYPE_t* lhs_adj_result_ptr = &adj_lhs_out[0] cdef DTYPE_t* rhs_adj_result_ptr = &adj_rhs_out[0] - cdef uint64_t gallop_time = 0 - cdef uint64_t gallop_start = 0 - - cdef uint64_t collect_time = 0 - cdef uint64_t collect_start = 0 - while lhs_ptr < end_lhs_ptr and rhs_ptr < end_rhs_ptr: # Gallop to adjacent or equal value # if value_lhs < value_rhs - delta: # Gallop past the current element - gallop_start = timestamp() if (lhs_ptr[0] & args.mask) != (rhs_ptr[0] & args.mask): while lhs_ptr < end_lhs_ptr and ((lhs_ptr[0] & args.mask) + delta) < (rhs_ptr[0] & args.mask): lhs_ptr += (gallop * args.lhs_stride) @@ -263,8 +256,6 @@ cdef DTYPE_t _gallop_int_and_adj_drop(intersect_args_t args, gallop = 1 # Now lhs is at or before RHS - delta # RHS is 4, LHS is at most 3 - gallop_time += (timestamp() - gallop_start) - collect_start = timestamp() # Collect adjacent avalues if ((lhs_ptr[0] & args.mask) + delta) == ((rhs_ptr[0] & args.mask)): if (last_adj & args.mask) != (lhs_ptr[0] & args.mask): @@ -289,10 +280,6 @@ cdef DTYPE_t _gallop_int_and_adj_drop(intersect_args_t args, lhs_result_ptr += 1 rhs_result_ptr += 1 rhs_ptr += args.rhs_stride - collect_time += (timestamp() - collect_start) - - print_elapsed(gallop_time, "Gallop ") - print_elapsed(collect_time, "Collect") adj_out_len[0] = lhs_adj_result_ptr - &adj_lhs_out[0] return lhs_result_ptr - &args.lhs_out[0] @@ -412,3 +399,128 @@ def intersect_with_adjacents(np.ndarray[DTYPE_t, ndim=1] lhs, &adj_out_len) return (np.asarray(lhs_out)[:amt_written], np.asarray(rhs_out)[:amt_written], np.asarray(adj_lhs_out)[:adj_out_len], np.asarray(adj_rhs_out)[:adj_out_len]) + + + +cdef DTYPE_t _int_w_index(intersect_args_t args, + DTYPE_t index_mask, + DTYPE_t* lhs_index, + DTYPE_t lhs_index_len, + DTYPE_t* rhs_index, + DTYPE_t rhs_index_len) nogil: + """Two pointer intersect the index first THEN intersect the lhs / rhs within those indices.""" + cdef DTYPE_t* lhs_ptr = &args.lhs[0] + cdef DTYPE_t* rhs_ptr = &args.rhs[0] + cdef DTYPE_t* lhs_index_ptr = &lhs_index[0] + cdef DTYPE_t* rhs_index_ptr = &rhs_index[0] + cdef DTYPE_t* end_lhs_ptr = &args.lhs[args.lhs_len] + cdef DTYPE_t* end_rhs_ptr = &args.rhs[args.rhs_len] + cdef DTYPE_t* end_lhs_index_ptr = &lhs_index[lhs_index_len] + cdef DTYPE_t* end_rhs_index_ptr = &rhs_index[rhs_index_len] + cdef DTYPE_t* lhs_result_ptr = &args.lhs_out[0] + cdef DTYPE_t* rhs_result_ptr = &args.rhs_out[0] + cdef DTYPE_t lhs_idx = 0 + cdef DTYPE_t rhs_idx = 0 + cdef DTYPE_t lhs_curr_end = 0 + cdef DTYPE_t rhs_curr_end = 0 + cdef DTYPE_t index_lsb_mask = ~index_mask + + while lhs_index_ptr < end_lhs_index_ptr and rhs_index_ptr < end_rhs_index_ptr: + if (lhs_index_ptr[0] & index_mask) < (rhs_index_ptr[0] & index_mask): + lhs_index_ptr += 1 + elif (rhs_index_ptr[0] & index_mask) < (lhs_index_ptr[0] & index_mask): + rhs_index_ptr += 1 + else: + # Now two pointer intersect within lhs_index_ptr -> lhs_index_ptr[1] + lhs_idx = lhs_index_ptr[0] & index_lsb_mask + rhs_idx = rhs_index_ptr[0] & index_lsb_mask + lhs_curr_end = (end_lhs_ptr - &args.lhs[0]) + rhs_curr_end = (end_rhs_ptr - &args.rhs[0]) + if lhs_index_ptr + 1 < end_lhs_index_ptr: + lhs_curr_end = ((lhs_index_ptr + 1)[0] & index_lsb_mask) + if rhs_index_ptr + 1 < end_rhs_index_ptr: + rhs_curr_end = ((rhs_index_ptr + 1)[0] & index_lsb_mask) + + # Two pointer intesect between lhs_index_ptr and lhs_index_end w/ rhs_index_ptr and rhs_index_end + while lhs_idx < lhs_curr_end and rhs_idx < rhs_curr_end: + if (args.lhs[lhs_idx] & args.mask) < (args.rhs[rhs_idx] & args.mask): + lhs_idx += 1 + elif (args.rhs[rhs_idx] & args.mask) < (args.lhs[lhs_idx] & args.mask): + rhs_idx += 1 + else: + lhs_result_ptr[0] = lhs_idx + rhs_result_ptr[0] = rhs_idx + lhs_result_ptr += 1 + rhs_result_ptr += 1 + lhs_idx += 1 + rhs_idx += 1 + + lhs_index_ptr += 1 + rhs_index_ptr += 1 + + return lhs_result_ptr - &args.lhs_out[0] + + +cdef DTYPE_t _build_intersect_index(DTYPE_t* arr, + DTYPE_t arr_len, + DTYPE_t mask, + DTYPE_t* idx_out) nogil: + cdef DTYPE_t i = 0 + cdef DTYPE_t headerVal = 0xFFFFFFFFFFFFFFFF + cdef DTYPE_t lastHeaderVal = 0 + cdef DTYPE_t* currIdxOut = &idx_out[0] + for i in range(arr_len): + headerVal = arr[i] & mask + if headerVal != lastHeaderVal: + currIdxOut[0] = (headerVal | i) + currIdxOut += 1 + lastHeaderVal = headerVal + return currIdxOut - &idx_out[0] + + + +def int_w_index(np.ndarray[DTYPE_t, ndim=1] lhs, + np.ndarray[DTYPE_t, ndim=1] rhs, + np.ndarray[DTYPE_t, ndim=1] lhs_index, + np.ndarray[DTYPE_t, ndim=1] rhs_index, + DTYPE_t index_mask=ALL_BITS, + DTYPE_t mask=ALL_BITS): + cdef np.uint64_t[:] lhs_out + cdef np.uint64_t[:] rhs_out + cdef intersect_args_t args + cdef DTYPE_t adj_out_len = 0 + cdef DTYPE_t* lhs_index_ptr = &lhs_index[0] + cdef DTYPE_t* rhs_index_ptr = &rhs_index[0] + cdef DTYPE_t lhs_index_len = lhs_index.shape[0] + cdef DTYPE_t rhs_index_len = rhs_index.shape[0] + + lhs_out = np.empty(min(lhs.shape[0], rhs.shape[0]), dtype=np.uint64) + rhs_out = np.empty(min(lhs.shape[0], rhs.shape[0]), dtype=np.uint64) + + args.mask = mask + args.lhs = &lhs[0] + args.rhs = &rhs[0] + args.lhs_len = lhs.shape[0] * lhs.strides[0] / sizeof(DTYPE_t) + args.rhs_len = rhs.shape[0] * rhs.strides[0] / sizeof(DTYPE_t) + args.lhs_out = &lhs_out[0] + args.rhs_out = &rhs_out[0] + + with nogil: + amt_written = _int_w_index(args, + index_mask, + lhs_index_ptr, lhs_index_len, + rhs_index_ptr, rhs_index_len) + + return (np.asarray(lhs_out)[:amt_written], np.asarray(rhs_out)[:amt_written]) + + + + +def build_intersect_index(np.ndarray[DTYPE_t, ndim=1] arr, + DTYPE_t mask=ALL_BITS): + cdef np.uint64_t[:] idx_out = np.empty(arr.shape[0], dtype=np.uint64) + cdef DTYPE_t amt_written = 0 + cdef DTYPE_t* arr_ptr = &arr[0] + with nogil: + amt_written = _build_intersect_index(arr_ptr, arr.shape[0], mask, &idx_out[0]) + return np.asarray(idx_out[:amt_written]) diff --git a/test/test_snp_ops.py b/test/test_snp_ops.py index 54e9006..abe3044 100644 --- a/test/test_snp_ops.py +++ b/test/test_snp_ops.py @@ -5,7 +5,7 @@ from searcharray.roaringish.search import binary_search, galloping_search, count_odds from searcharray.roaringish.unique import unique from searcharray.roaringish.merge import merge -from searcharray.roaringish.intersect import intersect, adjacent, intersect_with_adjacents +from searcharray.roaringish.intersect import intersect, adjacent, intersect_with_adjacents, build_intersect_index, int_w_index from test_utils import w_scenarios from test_utils import Profiler, profile_enabled @@ -263,6 +263,23 @@ def intersect_many(): profiler.run(intersect_many) +@pytest.mark.parametrize("suffix", [128, 185, 24179, 27685, 44358, 45907, 90596]) +def test_indexed_intersect(suffix): + print(f"Running with {suffix}") + lhs = np.load(f"fixtures/lhs_{suffix}.npy") + rhs = np.load(f"fixtures/rhs_{suffix}.npy") + mask = np.load(f"fixtures/mask_{suffix}.npy") + # Set 28 bits to 1 + key_mask = np.uint64(0xFFFFFFF000000000) + lhs_index = build_intersect_index(lhs, key_mask) + rhs_index = build_intersect_index(rhs, key_mask) + + lhs_out_windex, rhs_out_windex = int_w_index(lhs, rhs, lhs_index, rhs_index, key_mask, mask) + lhs_out_int, rhs_out_out = intersect(lhs, rhs, mask) + assert np.all(lhs_out_windex == lhs_out_int) + assert np.all(rhs_out_windex == rhs_out_out) + + @pytest.mark.skipif(not profile_enabled, reason="Profiling disabled") def test_profile_masked_intersect_sparse_sparse(benchmark): profiler = Profiler(benchmark) @@ -329,10 +346,16 @@ def test_profile_masked_saved(suffix, benchmark): lhs = np.load(f"fixtures/lhs_{suffix}.npy") rhs = np.load(f"fixtures/rhs_{suffix}.npy") mask = np.load(f"fixtures/mask_{suffix}.npy") - print(lhs.shape, rhs.shape) + # Set 28 bits to 1 + key_mask = np.uint64(0xFFFFFFF000000000) + lhs_index = build_intersect_index(lhs, key_mask) + rhs_index = build_intersect_index(rhs, key_mask) + + def with_indexed(): + lhs_out, rhs_out = int_w_index(lhs, rhs, lhs_index, rhs_index, key_mask, mask) def with_snp_ops(): - intersect(lhs, rhs, mask) + lhs_out, rhs_out = intersect(lhs, rhs, mask) def with_snp(): snp.intersect(lhs >> 18, rhs >> 18, indices=True, duplicates=snp.DROP) @@ -345,6 +368,7 @@ def intersect_many(): with_snp_ops() with_snp() baseline() + with_indexed() profiler.run(intersect_many)