Skip to content

Commit

Permalink
Merge branch 'main-dev-neon' into main-dev
Browse files Browse the repository at this point in the history
  • Loading branch information
ashvardanian committed Jan 16, 2024
2 parents 07b669f + b6887ed commit 4cc0f5f
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 7 deletions.
171 changes: 167 additions & 4 deletions include/stringzilla/stringzilla.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@

#ifndef SZ_USE_ARM_NEON
#ifdef __ARM_NEON
#define SZ_USE_ARM_NEON 0
#define SZ_USE_ARM_NEON 1
#else
#define SZ_USE_ARM_NEON 0
#endif
Expand Down Expand Up @@ -679,6 +679,9 @@ SZ_PUBLIC sz_cptr_t sz_find_byte_serial(sz_cptr_t haystack, sz_size_t h_length,
/** @copydoc sz_find_byte */
SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);

/** @copydoc sz_find_byte */
SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);

/**
* @brief Locates last matching byte in a string. Equivalent to `memrchr(haystack, *needle, h_length)` in LibC.
*
Expand All @@ -698,6 +701,9 @@ SZ_PUBLIC sz_cptr_t sz_find_last_byte_serial(sz_cptr_t haystack, sz_size_t h_len
/** @copydoc sz_find_last_byte */
SZ_PUBLIC sz_cptr_t sz_find_last_byte_avx512(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);

/** @copydoc sz_find_last_byte */
SZ_PUBLIC sz_cptr_t sz_find_last_byte_neon(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle);

/**
* @brief Locates first matching substring.
* Equivalent to `memmem(haystack, h_length, needle, n_length)` in LibC.
Expand Down Expand Up @@ -757,6 +763,9 @@ SZ_PUBLIC sz_cptr_t sz_find_from_set_serial(sz_cptr_t text, sz_size_t length, sz
/** @copydoc sz_find_from_set */
SZ_PUBLIC sz_cptr_t sz_find_from_set_avx512(sz_cptr_t text, sz_size_t length, sz_u8_set_t const *set);

/** @copydoc sz_find_from_set */
SZ_PUBLIC sz_cptr_t sz_find_from_set_neon(sz_cptr_t text, sz_size_t length, sz_u8_set_t const *set);

/**
* @brief Finds the last character present from the ::set, present in ::text.
* Equivalent to `strspn(text, accepted)` and `strcspn(text, rejected)` in LibC.
Expand All @@ -780,6 +789,9 @@ SZ_PUBLIC sz_cptr_t sz_find_last_from_set_serial(sz_cptr_t text, sz_size_t lengt
/** @copydoc sz_find_last_from_set */
SZ_PUBLIC sz_cptr_t sz_find_last_from_set_avx512(sz_cptr_t text, sz_size_t length, sz_u8_set_t const *set);

/** @copydoc sz_find_last_from_set */
SZ_PUBLIC sz_cptr_t sz_find_last_from_set_neon(sz_cptr_t text, sz_size_t length, sz_u8_set_t const *set);

#pragma endregion

#pragma region String Similarity Measures
Expand Down Expand Up @@ -2721,9 +2733,6 @@ SZ_PUBLIC void sz_move_avx512(sz_ptr_t target, sz_cptr_t source, sz_size_t lengt
}
}

/**
* @brief Variation of AVX-512 exact search for patterns up to 1 bytes included.
*/
SZ_PUBLIC sz_cptr_t sz_find_byte_avx512(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
__mmask64 mask;
sz_u512_vec_t h_vec, n_vec;
Expand Down Expand Up @@ -3283,6 +3292,156 @@ SZ_PUBLIC sz_size_t sz_edit_distance_avx512( //

#pragma endregion

/* @brief Implementation of the string search algorithms using the Arm NEON instruction set, available on 64-bit
* Arm processors. Implements: {substring search, character search, character set search} x {forward, reverse}.
*/
#pragma region ARM NEON

#if SZ_USE_ARM_NEON
#include <arm_neon.h>

/**
* @brief Helper structure to simplify work with 64-bit words.
*/
typedef union sz_u128_vec_t {
uint8x16_t u8x16;
uint32x4_t u32x4;
sz_u64_t u64s[2];
sz_u32_t u32s[4];
sz_u16_t u16s[8];
sz_u8_t u8s[16];
} sz_u128_vec_t;

SZ_PUBLIC sz_cptr_t sz_find_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
sz_u8_t offsets[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
sz_u128_vec_t h_vec, n_vec, offsets_vec, matches_vec;
n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n);
offsets_vec.u8x16 = vld1q_u8(offsets);

while (h_length >= 16) {
h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h);
matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16);
// In Arm NEON we don't have a `movemask` to combine it with `ctz` and get the offset of the match.
// But assuming the `vmaxvq` is cheap, we can use it to find the first match, by blending (bitwise selecting)
// the vector with a relative offsets array.
if (vmaxvq_u8(matches_vec.u8x16)) {
matches_vec.u8x16 = vbslq_u8(matches_vec.u8x16, offsets_vec.u8x16, vdupq_n_u8(0xFF));
return h + vminvq_u8(matches_vec.u8x16);
}
h += 16, h_length -= 16;
}

return sz_find_byte_serial(h, h_length, n);
}

SZ_PUBLIC sz_cptr_t sz_find_last_byte_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n) {
sz_u8_t offsets[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};

sz_u128_vec_t h_vec, n_vec, offsets_vec, matches_vec;
n_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)n);
offsets_vec.u8x16 = vld1q_u8(offsets);

while (h_length >= 16) {
h_vec.u8x16 = vld1q_u8((sz_u8_t const *)h + h_length - 16);
matches_vec.u8x16 = vceqq_u8(h_vec.u8x16, n_vec.u8x16);
// In Arm NEON we don't have a `movemask` to combine it with `clz` and get the offset of the match.
// But assuming the `vmaxvq` is cheap, we can use it to find the first match, by blending (bitwise selecting)
// the vector with a relative offsets array.
if (vmaxvq_u8(matches_vec.u8x16)) {
matches_vec.u8x16 = vbslq_u8(matches_vec.u8x16, offsets_vec.u8x16, vdupq_n_u8(0));
return h + h_length - 16 + vmaxvq_u8(matches_vec.u8x16);
}
h_length -= 16;
}

return sz_find_last_byte_serial(h, h_length, n);
}

SZ_PUBLIC sz_cptr_t sz_find_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
if (n_length == 1) return sz_find_byte_neon(h, h_length, n);

// Will contain 4 bits per character.
sz_u64_t matches;
sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec;
n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]);
n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[n_length / 2]);
n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[n_length - 1]);

for (; h_length >= n_length + 16; h += 16, h_length -= 16) {
h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h));
h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + n_length / 2));
h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + n_length - 1));
matches_vec.u8x16 = vandq_u8( //
vandq_u8( //
vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), //
vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)),
vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16));
if (vmaxvq_u8(matches_vec.u8x16)) {
// Use `vshrn` to produce a bitmask, similar to `movemask` in SSE.
// https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
matches = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(matches_vec.u8x16), 4)), 0) &
0x8888888888888888ull;
while (matches) {
int potential_offset = sz_u64_ctz(matches) / 4;
if (sz_equal(h + potential_offset + 1, n + 1, n_length - 2)) return h + potential_offset;
matches &= matches - 1;
}
}
}

return sz_find_serial(h, h_length, n, n_length);
}

SZ_PUBLIC sz_cptr_t sz_find_last_neon(sz_cptr_t h, sz_size_t h_length, sz_cptr_t n, sz_size_t n_length) {
if (n_length == 1) return sz_find_last_byte_neon(h, h_length, n);

// Will contain 4 bits per character.
sz_u64_t matches;
sz_u128_vec_t h_first_vec, h_mid_vec, h_last_vec, n_first_vec, n_mid_vec, n_last_vec, matches_vec;
n_first_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[0]);
n_mid_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[n_length / 2]);
n_last_vec.u8x16 = vld1q_dup_u8((sz_u8_t const *)&n[n_length - 1]);

for (; h_length >= n_length + 16; h_length -= 16) {
h_first_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + h_length - n_length - 16 + 1));
h_mid_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + h_length - n_length - 16 + 1 + n_length / 2));
h_last_vec.u8x16 = vld1q_u8((sz_u8_t const *)(h + h_length - 16));
matches_vec.u8x16 = vandq_u8( //
vandq_u8( //
vceqq_u8(h_first_vec.u8x16, n_first_vec.u8x16), //
vceqq_u8(h_mid_vec.u8x16, n_mid_vec.u8x16)),
vceqq_u8(h_last_vec.u8x16, n_last_vec.u8x16));
if (vmaxvq_u8(matches_vec.u8x16)) {
// Use `vshrn` to produce a bitmask, similar to `movemask` in SSE.
// https://community.arm.com/arm-community-blogs/b/infrastructure-solutions-blog/posts/porting-x86-vector-bitmask-optimizations-to-arm-neon
matches = vget_lane_u64(vreinterpret_u64_u8(vshrn_n_u16(vreinterpretq_u16_u8(matches_vec.u8x16), 4)), 0) &
0x8888888888888888ull;
while (matches) {
int potential_offset = sz_u64_clz(matches) / 4;
if (sz_equal(h + h_length - n_length - potential_offset + 1, n + 1, n_length - 2))
return h + h_length - n_length - potential_offset;
sz_assert((matches & (1ull << (63 - potential_offset * 4))) != 0 &&
"The bit must be set before we squash it");
matches &= ~(1ull << (63 - potential_offset * 4));
}
}
}

return sz_find_last_serial(h, h_length, n, n_length);
}

SZ_PUBLIC sz_cptr_t sz_find_from_set_neon(sz_cptr_t h, sz_size_t h_length, sz_u8_set_t const *set) {
return sz_find_from_set_serial(h, h_length, set);
}

SZ_PUBLIC sz_cptr_t sz_find_last_from_set_neon(sz_cptr_t h, sz_size_t h_length, sz_u8_set_t const *set) {
return sz_find_last_from_set_serial(h, h_length, set);
}

#endif // Arm Neon

#pragma endregion

/*
* @brief Pick the right implementation for the string search algorithms.
*/
Expand Down Expand Up @@ -3333,6 +3492,8 @@ SZ_PUBLIC sz_ordering_t sz_order(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, s
SZ_PUBLIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) {
#if SZ_USE_X86_AVX512
return sz_find_byte_avx512(haystack, h_length, needle);
#elif SZ_USE_ARM_NEON
return sz_find_byte_neon(haystack, h_length, needle);
#else
return sz_find_byte_serial(haystack, h_length, needle);
#endif
Expand All @@ -3341,6 +3502,8 @@ SZ_PUBLIC sz_cptr_t sz_find_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr
SZ_PUBLIC sz_cptr_t sz_find_last_byte(sz_cptr_t haystack, sz_size_t h_length, sz_cptr_t needle) {
#if SZ_USE_X86_AVX512
return sz_find_last_byte_avx512(haystack, h_length, needle);
#elif SZ_USE_ARM_NEON
return sz_find_last_byte_neon(haystack, h_length, needle);
#else
return sz_find_last_byte_serial(haystack, h_length, needle);
#endif
Expand Down
34 changes: 31 additions & 3 deletions scripts/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ static void test_api_readonly() {

// Constructors.
assert(str().empty()); // Test default constructor
assert(str().size() == 0); // Test default constructor
assert(str("").empty()); // Test default constructor
assert(str("").size() == 0); // Test default constructor
assert(str("hello").size() == 5); // Test constructor with c-string
assert(str("hello", 4) == "hell"); // Construct from substring

Expand All @@ -166,7 +169,7 @@ static void test_api_readonly() {
assert(*str("rbegin").rbegin() == 'n' && *str("crbegin").crbegin() == 'n');
assert(str("size").size() == 4 && str("length").length() == 6);

// Slices... out-of-bounds exceptions are asymetric!
// Slices... out-of-bounds exceptions are asymmetric!
// Moreover, `std::string` has no `remove_prefix` and `remove_suffix` methods.
// assert_scoped(str s = "hello", s.remove_prefix(1), s == "ello");
// assert_scoped(str s = "hello", s.remove_suffix(1), s == "hell");
Expand All @@ -188,7 +191,7 @@ static void test_api_readonly() {
assert(str("hello").rfind("l", 2) == 2);
assert(str("hello").rfind("l", 1) == str::npos);

// ! `rfind` and `find_last_of` are not consitent in meaning of their arguments.
// ! `rfind` and `find_last_of` are not consistent in meaning of their arguments.
assert(str("hello").find_first_of("le") == 1);
assert(str("hello").find_first_of("le", 1) == 1);
assert(str("hello").find_last_of("le") == 3);
Expand All @@ -198,6 +201,22 @@ static void test_api_readonly() {
assert(str("hello").find_last_not_of("hel") == 4);
assert(str("hello").find_last_not_of("hel", 4) == 4);

// Try longer strings to enforce SIMD.
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").find('x') == 23); // first byte
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").find('X') == 49); // first byte
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").rfind('x') == 23); // last byte
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").rfind('X') == 49); // last byte

assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").find("xyz") == 23); // first match
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").find("XYZ") == 49); // first match
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").rfind("xyz") == 23); // last match
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").rfind("XYZ") == 49); // last match

assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").find_first_of("xyz") == 23); // sets
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").find_first_of("XYZ") == 49); // sets
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").find_last_of("xyz") == 25); // sets
assert(str("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789+-").find_last_of("XYZ") == 51); // sets

// Comparisons.
assert(str("a") != str("b"));
assert(str("a") < str("b"));
Expand Down Expand Up @@ -270,7 +289,7 @@ static void test_api_readonly() {
#endif

#if SZ_DETECT_CPP_23 && __cpp_lib_string_contains
// Checking basic substring presense.
// Checking basic substring presence.
assert(str("hello").contains(str("ell")) == true);
assert(str("hello").contains(str("oll")) == false);
assert(str("hello").contains('l') == true);
Expand Down Expand Up @@ -318,6 +337,9 @@ static void test_api_mutable() {

// Constructors.
assert(str().empty()); // Test default constructor
assert(str().size() == 0); // Test default constructor
assert(str("").empty()); // Test default constructor
assert(str("").size() == 0); // Test default constructor
assert(str("hello").size() == 5); // Test constructor with c-string
assert(str("hello", 4) == "hell"); // Construct from substring
assert(str(5, 'a') == "aaaaa"); // Construct with count and character
Expand Down Expand Up @@ -956,6 +978,12 @@ static void test_search_with_misaligned_repetitions() {
test_search_with_misaligned_repetitions("ab", "ba");
test_search_with_misaligned_repetitions("abc", "ca");
test_search_with_misaligned_repetitions("abcd", "da");

// Examples targeted exactly against the Raita heuristic,
// which matches the first, the last, and the middle characters with SIMD.
test_search_with_misaligned_repetitions("aaabbccc", "aaabbccc");
test_search_with_misaligned_repetitions("axabbcxc", "aaabbccc");
test_search_with_misaligned_repetitions("axabbcxcaaabbccc", "aaabbccc");
}

#endif
Expand Down

0 comments on commit 4cc0f5f

Please sign in to comment.