Skip to content

Commit

Permalink
Improve: SWAR search for normal order search
Browse files Browse the repository at this point in the history
On AWS Graviton 3 for `std::string`:
- needle of length 2: 1.1 GB/s
- needle of length 3: 1.3 GB/s
- needle of length 4: 2.0 GB/s

On AWS Graviton 3 for `memmem` LibC func:
- needle of length 2: 1.2 GB/s
- needle of length 3: 1.5 GB/s
- needle of length 4: 2.6 GB/s

On AWS Graviton 3 for StringZilla SWAR:
- needle of length 2: 2.7 GB/s
- needle of length 3: 1.1 GB/s
- needle of length 4: 2.4 GB/s

On AWS Graviton 3 for StringZilla NEON:
- needle of length 2: 4.6 GB/s
- needle of length 3: 6.1 GB/s
- needle of length 4: 11 GB/s
  • Loading branch information
ashvardanian committed Jan 16, 2024
1 parent 4cc0f5f commit 8d0bca6
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 44 deletions.
160 changes: 117 additions & 43 deletions include/stringzilla/stringzilla.h
Original file line number Diff line number Diff line change
Expand Up @@ -1299,15 +1299,12 @@ SZ_PUBLIC sz_cptr_t sz_find_last_from_set_serial(sz_cptr_t text, sz_size_t lengt
return NULL;
}

/**
* @brief Byte-level lexicographic order comparison of two strings.
*/
SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length) {
sz_ordering_t ordering_lookup[2] = {sz_greater_k, sz_less_k};
#if SZ_USE_MISALIGNED_LOADS
sz_bool_t a_shorter = (sz_bool_t)(a_length < b_length);
sz_size_t min_length = a_shorter ? a_length : b_length;
sz_cptr_t min_end = a + min_length;
#if SZ_USE_MISALIGNED_LOADS
for (sz_u64_vec_t a_vec, b_vec; a + 8 <= min_end; a += 8, b += 8) {
a_vec.u64 = sz_u64_bytes_reverse(sz_u64_load(a).u64);
b_vec.u64 = sz_u64_bytes_reverse(sz_u64_load(b).u64);
Expand All @@ -1323,14 +1320,14 @@ SZ_PUBLIC sz_ordering_t sz_order_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr
* @brief Byte-level equality comparison between two 64-bit integers.
* @return 64-bit integer, where every top bit in each byte signifies a match.
*/
SZ_INTERNAL sz_u64_t sz_u64_each_byte_equal(sz_u64_t a, sz_u64_t b) {
sz_u64_t match_indicators = ~(a ^ b);
SZ_INTERNAL sz_u64_vec_t _sz_u64_each_byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) {
sz_u64_vec_t vec;
vec.u64 = ~(a.u64 ^ b.u64);
// The match is valid, if every bit within each byte is set.
// For that take the bottom 7 bits of each byte, add one to them,
// and if this sets the top bit to one, then all the 7 bits are ones as well.
match_indicators = ((match_indicators & 0x7F7F7F7F7F7F7F7Full) + 0x0101010101010101ull) &
((match_indicators & 0x8080808080808080ull));
return match_indicators;
vec.u64 = ((vec.u64 & 0x7F7F7F7F7F7F7F7Full) + 0x0101010101010101ull) & ((vec.u64 & 0x8080808080808080ull));
return vec;
}

/**
Expand All @@ -1343,18 +1340,21 @@ SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr
if (!h_length) return NULL;
sz_cptr_t const h_end = h + h_length;

#if !SZ_USE_MISALIGNED_LOADS
// Process the misaligned head, to void UB on unaligned 64-bit loads.
for (; ((sz_size_t)h & 7ull) && h < h_end; ++h)
if (*h == *n) return h;
#endif

// Broadcast the n into every byte of a 64-bit integer to use SWAR
// techniques and process eight characters at a time.
sz_u64_vec_t h_vec, n_vec;
sz_u64_vec_t h_vec, n_vec, match_vec;
match_vec.u64 = 0;
n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull;
for (; h + 8 <= h_end; h += 8) {
h_vec.u64 = *(sz_u64_t const *)h;
sz_u64_t match_indicators = sz_u64_each_byte_equal(h_vec.u64, n_vec.u64);
if (match_indicators != 0) return h + sz_u64_ctz(match_indicators) / 8;
match_vec = _sz_u64_each_byte_equal(h_vec, n_vec);
if (match_vec.u64) return h + sz_u64_ctz(match_vec.u64) / 8;
}

// Handle the misaligned tail.
Expand All @@ -1368,78 +1368,150 @@ SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr
* This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time.
* Identical to `memrchr(haystack, needle[0], haystack_length)`.
*/
sz_cptr_t sz_find_last_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t needle) {
sz_cptr_t sz_find_last_byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {

if (!h_length) return NULL;
sz_cptr_t const h_start = h;

// Reposition the `h` pointer to the end, as we will be walking backwards.
h = h + h_length - 1;

#if !SZ_USE_MISALIGNED_LOADS
// Process the misaligned head, to void UB on unaligned 64-bit loads.
for (; ((sz_size_t)(h + 1) & 7ull) && h >= h_start; --h)
if (*h == *needle) return h;
if (*h == *n) return h;
#endif

// Broadcast the needle into every byte of a 64-bit integer to use SWAR
// Broadcast the n into every byte of a 64-bit integer to use SWAR
// techniques and process eight characters at a time.
sz_u64_vec_t h_vec, n_vec;
n_vec.u64 = (sz_u64_t)needle[0] * 0x0101010101010101ull;
sz_u64_vec_t h_vec, n_vec, match_vec;
n_vec.u64 = (sz_u64_t)n[0] * 0x0101010101010101ull;
for (; h >= h_start + 7; h -= 8) {
h_vec.u64 = *(sz_u64_t const *)(h - 7);
sz_u64_t match_indicators = sz_u64_each_byte_equal(h_vec.u64, n_vec.u64);
if (match_indicators != 0) return h - sz_u64_clz(match_indicators) / 8;
match_vec = _sz_u64_each_byte_equal(h_vec, n_vec);
if (match_vec.u64) return h - sz_u64_clz(match_vec.u64) / 8;
}

for (; h >= h_start; --h)
if (*h == *needle) return h;
if (*h == *n) return h;
return NULL;
}

/**
* @brief 2Byte-level equality comparison between two 64-bit integers.
* @return 64-bit integer, where every top bit in each 2byte signifies a match.
*/
SZ_INTERNAL sz_u64_t sz_u64_each_2byte_equal(sz_u64_t a, sz_u64_t b) {
sz_u64_t match_indicators = ~(a ^ b);
SZ_INTERNAL sz_u64_vec_t _sz_u64_each_2byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) {
sz_u64_vec_t vec;
vec.u64 = ~(a.u64 ^ b.u64);
// The match is valid, if every bit within each 2byte is set.
// For that take the bottom 15 bits of each 2byte, add one to them,
// and if this sets the top bit to one, then all the 15 bits are ones as well.
match_indicators = ((match_indicators & 0x7FFF7FFF7FFF7FFFull) + 0x0001000100010001ull) &
((match_indicators & 0x8000800080008000ull));
return match_indicators;
vec.u64 = ((vec.u64 & 0x7FFF7FFF7FFF7FFFull) + 0x0001000100010001ull) & ((vec.u64 & 0x8000800080008000ull));
return vec;
}

/**
* @brief Find the first occurrence of a @b two-character needle in an arbitrary length haystack.
* This implementation uses hardware-agnostic SWAR technique, to process 8 characters at a time.
* This implementation uses hardware-agnostic SWAR technique, to process 8 offsets at a time.
*/
SZ_INTERNAL sz_cptr_t sz_find_2byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {

sz_cptr_t const h_end = h + h_length;
SZ_INTERNAL sz_cptr_t _sz_find_2byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {

// This is an internal method, and the haystack is guaranteed to be at least 2 bytes long.
sz_assert(h_length >= 2 && "The haystack is too short.");
sz_cptr_t const h_end = h + h_length;

#if !SZ_USE_MISALIGNED_LOADS
// Process the misaligned head, to void UB on unaligned 64-bit loads.
for (; ((sz_size_t)h & 7ull) && h < h_end; ++h)
if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h;
#endif

// This code simulates hyper-scalar execution, analyzing 7 offsets at a time.
sz_u64_vec_t h_vec, n_vec, matches_odd_vec, matches_even_vec;
sz_u64_vec_t h_even_vec, h_odd_vec, n_vec, matches_even_vec, matches_odd_vec;
n_vec.u64 = 0;
n_vec.u8s[0] = n[0];
n_vec.u8s[1] = n[1];
n_vec.u64 *= 0x0001000100010001ull;
n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1];
n_vec.u64 *= 0x0001000100010001ull; // broadcast

for (; h + 8 <= h_end; h += 7) {
h_vec = sz_u64_load(h);
matches_even_vec.u64 = sz_u64_each_2byte_equal(h_vec.u64, n_vec.u64);
matches_odd_vec.u64 = sz_u64_each_2byte_equal(h_vec.u64 >> 8, n_vec.u64);
// This code simulates hyper-scalar execution, analyzing 8 offsets at a time.
for (; h + 9 <= h_end; h += 8) {
h_even_vec.u64 = *(sz_u64_t *)h;
h_odd_vec.u64 = (h_even_vec.u64 >> 8) | (*(sz_u64_t *)&h[8] << 56);
matches_even_vec = _sz_u64_each_2byte_equal(h_even_vec, n_vec);
matches_odd_vec = _sz_u64_each_2byte_equal(h_odd_vec, n_vec);

if (matches_even_vec.u64 + matches_odd_vec.u64) {
sz_u64_t match_indicators = (matches_even_vec.u64 >> 8) | (matches_odd_vec.u64);
matches_even_vec.u64 >>= 8;
sz_u64_t match_indicators = matches_even_vec.u64 | matches_odd_vec.u64;
return h + sz_u64_ctz(match_indicators) / 8;
}
}

for (; h + 2 <= h_end; ++h)
if (h[0] == n[0] && h[1] == n[1]) return h;
if ((h[0] == n[0]) + (h[1] == n[1]) == 2) return h;
return NULL;
}

/**
* @brief 4Byte-level equality comparison between two 64-bit integers.
* @return 64-bit integer, where every top bit in each 4byte signifies a match.
*/
SZ_INTERNAL sz_u64_vec_t _sz_u64_each_4byte_equal(sz_u64_vec_t a, sz_u64_vec_t b) {
sz_u64_vec_t vec;
vec.u64 = ~(a.u64 ^ b.u64);
// The match is valid, if every bit within each 4byte is set.
// For that take the bottom 31 bits of each 4byte, add one to them,
// and if this sets the top bit to one, then all the 31 bits are ones as well.
vec.u64 = ((vec.u64 & 0x7FFFFFFF7FFFFFFFull) + 0x0000000100000001ull) & ((vec.u64 & 0x8000000080000000ull));
return vec;
}

/**
* @brief Find the first occurrence of a @b four-character needle in an arbitrary length haystack.
* This implementation uses hardware-agnostic SWAR technique, to process 8 offsets at a time.
*/
SZ_INTERNAL sz_cptr_t _sz_find_4byte_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {

// This is an internal method, and the haystack is guaranteed to be at least 4 bytes long.
sz_assert(h_length >= 4 && "The haystack is too short.");
sz_cptr_t const h_end = h + h_length;

#if !SZ_USE_MISALIGNED_LOADS
// Process the misaligned head, to void UB on unaligned 64-bit loads.
for (; ((sz_size_t)h & 7ull) && h < h_end; ++h)
if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h;
#endif

sz_u64_vec_t h0_vec, h1_vec, h2_vec, h3_vec, n_vec, matches0_vec, matches1_vec, matches2_vec, matches3_vec;
n_vec.u64 = 0;
n_vec.u8s[0] = n[0], n_vec.u8s[1] = n[1], n_vec.u8s[2] = n[2], n_vec.u8s[3] = n[3];
n_vec.u64 *= 0x0000000100000001ull; // broadcast

// This code simulates hyper-scalar execution, analyzing 8 offsets at a time using four 64-bit words.
// We load the subsequent word at onceto minimize the data dependency.
sz_u64_t h_page_current, h_page_next;
for (; h + 16 <= h_end; h += 8) {
h_page_current = *(sz_u64_t *)h;
h_page_next = *(sz_u64_t *)(h + 8);
h0_vec.u64 = (h_page_current);
h1_vec.u64 = (h_page_current >> 8) | (h_page_next << 56);
h2_vec.u64 = (h_page_current >> 16) | (h_page_next << 48);
h3_vec.u64 = (h_page_current >> 24) | (h_page_next << 40);
matches0_vec = _sz_u64_each_4byte_equal(h0_vec, n_vec);
matches1_vec = _sz_u64_each_4byte_equal(h1_vec, n_vec);
matches2_vec = _sz_u64_each_4byte_equal(h2_vec, n_vec);
matches3_vec = _sz_u64_each_4byte_equal(h3_vec, n_vec);

if (matches0_vec.u64 + matches1_vec.u64 + matches2_vec.u64 + matches3_vec.u64) {
matches0_vec.u64 >>= 24;
matches1_vec.u64 >>= 16;
matches2_vec.u64 >>= 8;
sz_u64_t match_indicators = matches0_vec.u64 | matches1_vec.u64 | matches2_vec.u64 | matches3_vec.u64;
return h + sz_u64_ctz(match_indicators) / 8;
}
}

for (; h + 4 <= h_end; ++h)
if ((h[0] == n[0]) + (h[1] == n[1]) + (h[2] == n[2]) + (h[3] == n[3]) == 4) return h;
return NULL;
}

Expand Down Expand Up @@ -1773,7 +1845,9 @@ SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n,
sz_find_t backends[] = {
// For very short strings brute-force SWAR makes sense.
(sz_find_t)sz_find_byte_serial,
(sz_find_t)sz_find_2byte_serial,
(sz_find_t)_sz_find_2byte_serial,
(sz_find_t)_sz_find_bitap_upto_8bytes_serial,
(sz_find_t)_sz_find_4byte_serial,
// For needle lengths up to 64, use the Bitap algorithm variation for exact search.
(sz_find_t)_sz_find_bitap_upto_8bytes_serial,
(sz_find_t)_sz_find_bitap_upto_16bytes_serial,
Expand All @@ -1786,9 +1860,9 @@ SZ_PUBLIC sz_cptr_t sz_find_serial(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n,

return backends[
// For very short strings brute-force SWAR makes sense.
(n_length > 1) +
(n_length > 1) + (n_length > 2) + (n_length > 3) +
// For needle lengths up to 64, use the Bitap algorithm variation for exact search.
(n_length > 2) + (n_length > 8) + (n_length > 16) + (n_length > 32) +
(n_length > 4) + (n_length > 8) + (n_length > 16) + (n_length > 32) +
// For longer needles - use skip tables.
(n_length > 64) + (n_length > 256)](h, h_length, n, n_length);
}
Expand Down
20 changes: 19 additions & 1 deletion scripts/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,15 @@ static void test_api_readonly() {
assert_throws(str("hello world").substr(-1, 5), std::out_of_range); // -1 casts to unsigned without any warnings...
assert(str("hello world").substr(0, -1) == "hello world"); // -1 casts to unsigned without any warnings...

// Substring and character search in normal and reverse directions.
// Character search in normal and reverse directions.
assert(str("hello").find('e') == 1);
assert(str("hello").find('e', 1) == 1);
assert(str("hello").find('e', 2) == str::npos);
assert(str("hello").rfind('l') == 3);
assert(str("hello").rfind('l', 2) == 2);
assert(str("hello").rfind('l', 1) == str::npos);

// Substring search in normal and reverse directions.
assert(str("hello").find("ell") == 1);
assert(str("hello").find("ell", 1) == 1);
assert(str("hello").find("ell", 2) == str::npos);
Expand All @@ -207,11 +215,21 @@ static void test_api_readonly() {
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").rfind('x') == 23); // last byte
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").rfind('X') == 49); // last byte

assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").find("xy") == 23); // first match
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").find("XY") == 49); // first match
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").rfind("xy") == 23); // last match
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").rfind("XY") == 49); // last match

assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").find("xyz") == 23); // first match
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").find("XYZ") == 49); // first match
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").rfind("xyz") == 23); // last match
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").rfind("XYZ") == 49); // last match

assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").find("xyzA") == 23); // first match
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").find("XYZ0") == 49); // first match
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").rfind("xyzA") == 23); // last match
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").rfind("XYZ0") == 49); // last match

assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").find_first_of("xyz") == 23); // sets
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").find_first_of("XYZ") == 49); // sets
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").find_last_of("xyz") == 25); // sets
Expand Down

0 comments on commit 8d0bca6

Please sign in to comment.