diff --git a/CMakeLists.txt b/CMakeLists.txt index fb15ad67..4b66cc3b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,6 +79,7 @@ endif() # Function to set compiler-specific flags function(set_compiler_flags target) + target_include_directories(${target} PRIVATE scripts) target_link_libraries(${target} PRIVATE ${STRINGZILLA_TARGET_NAME}) set_target_properties(${target} PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) @@ -104,14 +105,30 @@ function(set_compiler_flags target) endfunction() if(${STRINGZILLA_BUILD_BENCHMARK}) - add_executable(stringzilla_search_bench scripts/search_bench.cpp) - set_compiler_flags(stringzilla_search_bench) - add_test(NAME stringzilla_search_bench COMMAND stringzilla_search_bench) + add_executable(stringzilla_bench_search scripts/bench_search.cpp) + set_compiler_flags(stringzilla_bench_search) + add_test(NAME stringzilla_bench_search COMMAND stringzilla_bench_search) + + add_executable(stringzilla_bench_similarity scripts/bench_similarity.cpp) + set_compiler_flags(stringzilla_bench_similarity) + add_test(NAME stringzilla_bench_similarity COMMAND stringzilla_bench_similarity) + + add_executable(stringzilla_bench_sort scripts/bench_sort.cpp) + set_compiler_flags(stringzilla_bench_sort) + add_test(NAME stringzilla_bench_sort COMMAND stringzilla_bench_sort) + + add_executable(stringzilla_bench_token scripts/bench_token.cpp) + set_compiler_flags(stringzilla_bench_token) + add_test(NAME stringzilla_bench_token COMMAND stringzilla_bench_token) + + add_executable(stringzilla_bench_container scripts/bench_container.cpp) + set_compiler_flags(stringzilla_bench_container) + add_test(NAME stringzilla_bench_container COMMAND stringzilla_bench_container) endif() if(${STRINGZILLA_BUILD_TEST}) # Test target - add_executable(stringzilla_search_test scripts/search_test.cpp) - set_compiler_flags(stringzilla_search_test) - add_test(NAME stringzilla_search_test COMMAND stringzilla_search_test) + add_executable(stringzilla_test scripts/test.cpp) 
+ set_compiler_flags(stringzilla_test) + add_test(NAME stringzilla_test COMMAND stringzilla_test) endif() diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6bb02de4..fb75eed1 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -11,17 +11,36 @@ The project is split into the following parts: - `include/stringzilla/stringzilla.h` - single-header C implementation. - `include/stringzilla/stringzilla.hpp` - single-header C++ wrapper. -- `python/**` - Python bindings. -- `javascript/**` - JavaScript bindings. -- `scripts/**` - Scripts for benchmarking and testing. +- `python/*` - Python bindings. +- `javascript/*` - JavaScript bindings. +- `scripts/*` - Scripts for benchmarking and testing. -The scripts name convention is as follows: `_.`. -An example would be, `search_bench.cpp` or `similarity_fuzz.py`. -The nature of the script can be: +For minimal test coverage, check the following scripts: -- `bench` - bounded in time benchmarking, generally on user-provided data. -- `fuzz` - unbounded in time fuzzing, generally on randomly generated data. -- `test` - unit tests. +- `test.cpp` - tests C++ API (not underlying C) against STL. +- `test.py` - tests Python API against native strings. +- `test.js` - tests JavaScript API. + +At the C++ level all benchmarks also validate the results against the STL baseline, serving as tests on real-world data. +They have the broadest coverage of the library, and are the most important to keep up-to-date: + +- `bench_token.cpp` - token-level ops, like hashing, ordering, equality checks. +- `bench_search.cpp` - bidirectional substring search, both exact and fuzzy. +- `bench_similarity.cpp` - benchmark all edit distance backends. +- `bench_sort.cpp` - sorting, partitioning, merging. +- `bench_container.cpp` - STL containers with different string keys. + +The role of Python benchmarks is less to provide absolute numbers, and more to compare against popular tools in the Python ecosystem. + +- `bench_search.py` - compares against native Python `str`. 
+- `bench_sort.py` - compares against `pandas`. +- `bench_similarity.py` - compares against `jellyfish`, `editdistance`, etc. + +For presentation purposes, we also + +## IDE Integrations + +The project is developed in VS Code, and comes with debugger launchers in `.vscode/launch.json`. ## Contributing in C++ and C @@ -40,7 +59,7 @@ Using modern syntax, this is how you build and run the test suite: ```bash cmake -DSTRINGZILLA_BUILD_TEST=1 -B ./build_debug cmake --build ./build_debug --config Debug # Which will produce the following targets: -./build_debug/search_test # Unit test for substring search +./build_debug/stringzilla_test # Unit test for the entire library ``` For benchmarks, you can use the following commands: @@ -48,8 +67,8 @@ For benchmarks, you can use the following commands: ```bash cmake -DSTRINGZILLA_BUILD_BENCHMARK=1 -B ./build_release cmake --build ./build_release --config Release # Which will produce the following targets: -./build_release/search_bench # Benchmark for substring search -./build_release/sort_bench # Benchmark for sorting arrays of strings +./build_release/stringzilla_bench_search # Benchmark for substring search +./build_release/stringzilla_bench_sort # Benchmark for sorting arrays of strings ``` Running on modern hardware, you may want to compile the code for older generations to compare the relative performance. 
@@ -67,9 +86,9 @@ cmake -DCMAKE_BUILD_TYPE=Release -DSTRINGZILLA_BUILD_BENCHMARK=1 \ -DCMAKE_CXX_FLAGS="-march=sapphirerapids" -DCMAKE_C_FLAGS="-march=sapphirerapids" \ -B ./build_release/sapphirerapids && cmake --build build_release/sapphirerapids --config Release -./build_release/sandybridge/stringzilla_search_bench -./build_release/haswell/stringzilla_search_bench -./build_release/sapphirerapids/stringzilla_search_bench +./build_release/sandybridge/stringzilla_bench_search +./build_release/haswell/stringzilla_bench_search +./build_release/sapphirerapids/stringzilla_bench_search ``` Alternatively, you may want to compare the performance of the code compiled with different compilers. @@ -95,8 +114,8 @@ pip install -e . # To build locally from source For testing we use PyTest, which may not be installed on your system. ```bash -pip install pytest # To install PyTest -pytest scripts/ -s -x # To run the test suite +pip install pytest # To install PyTest +pytest scripts/unit_test.py -s -x # To run the test suite ``` For fuzzing we love the ability to call the native C implementation from Python bypassing the binding layer. @@ -110,8 +129,8 @@ python scripts/similarity_fuzz.py # To run the fuzzing script For benchmarking, the following scripts are provided. ```sh -python scripts/search_bench.py --haystack_path "your file" --needle "your pattern" # real data -python scripts/search_bench.py --haystack_pattern "abcd" --haystack_length 1e9 --needle "abce" # synthetic data +python scripts/bench_search.py --haystack_path "your file" --needle "your pattern" # real data +python scripts/bench_search.py --haystack_pattern "abcd" --haystack_length 1e9 --needle "abce" # synthetic data python scripts/similarity_bench.py --text_path "your file" # edit ditance computations ``` @@ -132,6 +151,7 @@ Future development plans include: - [x] [Reverse-order operations](https://github.com/ashvardanian/StringZilla/issues/12). 
- [ ] [Faster string sorting algorithm](https://github.com/ashvardanian/StringZilla/issues/45). - [ ] [Splitting with multiple separators at once](https://github.com/ashvardanian/StringZilla/issues/29). +- [ ] Add `.pyi` interface for Python. - [ ] Arm NEON backend. - [ ] Bindings for Rust. - [ ] Arm SVE backend. diff --git a/README.md b/README.md index a225294d..63cc9a78 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ Aside from exact search, the library also accelerates fuzzy search, edit distanc - Code in C? Replace LibC's `` with C 99 `` - [_more_](#quick-start-c-🛠️) - Code in C++? Replace STL's `` with C++ 11 `` - [_more_](#quick-start-cpp-🛠️) - Code in Python? Upgrade your `str` to faster `Str` - [_more_](#quick-start-python-🐍) +- Code in other languages? Let us know! __Features:__ @@ -131,7 +132,7 @@ import stringzilla as sz contains: bool = sz.contains("haystack", "needle", start=0, end=9223372036854775807) offset: int = sz.find("haystack", "needle", start=0, end=9223372036854775807) count: int = sz.count("haystack", "needle", start=0, end=9223372036854775807, allowoverlap=False) -levenshtein: int = sz.levenshtein("needle", "nidl") +edit_distance: int = sz.edit_distance("needle", "nidl") ``` ## Quick Start: C/C++ 🛠️ @@ -202,6 +203,19 @@ haystack.contains(needle) == true; // STL has this only from C++ 23 onwards haystack.compare(needle) == 1; // Or `haystack <=> needle` in C++ 20 and beyond ``` +StringZilla also provides string literals for automatic type resolution, [similar to STL][stl-literal]: + +```cpp +using sz::literals::operator""_sz; +using std::literals::operator""sv; + +auto a = "some string"; // char const * +auto b = "some string"sv; // std::string_view +auto c = "some string"_sz; // sz::string_view +``` + +[stl-literal]: https://en.cppreference.com/w/cpp/string/basic_string_view/operator%22%22sv + ### Memory Ownership and Small String Optimization Most operations in StringZilla don't assume any memory ownership. 
@@ -334,6 +348,73 @@ Debugging pointer offsets is not a pleasant exercise, so keep the following func - `haystack.[r]split_all(character_set(""))` For $N$ matches the split functions will report $N+1$ matches, potentially including empty strings. +Ranges have a few convenience methods as well: + +```cpp +range.size(); // -> std::size_t +range.empty(); // -> bool +range.template to>(); +range.template to>(); +``` + +### TODO: STL Containers with String Keys + +The C++ Standard Template Library provides several associative containers, often used with string keys. + +```cpp +std::map> sorted_words; +std::unordered_map, std::equal_to> words; +``` + +The performance of those containers is often limited by the performance of the string keys, especially on reads. +StringZilla can be used to accelerate containers with `std::string` keys, by overriding the default comparator and hash functions. + +```cpp +std::map sorted_words; +std::unordered_map words; +``` + +Alternatively, a better approach would be to use the `sz::string` class as a key. +The right hash function and comparator would be automatically selected and the performance gains would be more noticeable if the keys are short. + +```cpp +std::map sorted_words; +std::unordered_map words; +``` + +### TODO: Concatenating Strings + +Another common string operation is concatenation. +The STL provides `std::string::operator+` and `std::string::append`, but those are not the most efficient, if multiple invocations are performed. + +```cpp +std::string name, domain, tld; +auto email = name + "@" + domain + "." + tld; // 4 allocations +``` + +The efficient approach would be to pre-allocate the memory and copy the strings into it. + +```cpp +std::string email; +email.reserve(name.size() + domain.size() + tld.size() + 2); +email.append(name), email.append("@"), email.append(domain), email.append("."), email.append(tld); +``` + +That's a mouthful and error-prone. 
+StringZilla provides a more convenient `concat` function, which takes a variadic number of arguments. + +```cpp +auto email = sz::concat(name, "@", domain, ".", tld); +``` + +Moreover, if the first or second argument of the expression is a StringZilla string, the concatenation can be performed lazily using the same `operator+` syntax. +That behavior is disabled for compatibility by default, but can be enabled by defining the `SZ_LAZY_CONCAT` macro. + +```cpp +sz::string name, domain, tld; +auto email_expression = name + "@" + domain + "." + tld; // 0 allocations +sz::string email = name + "@" + domain + "." + tld; // 1 allocation +``` ### Debugging @@ -342,6 +423,12 @@ That behavior is controllable for both C and C++ interfaces via the `STRINGZILLA [faq-sso]: https://cpp-optimizations.netlify.app/small_strings/ +## Algorithms 📚 + +### Hashing + +### Substring Search + ## Contributing 👾 Please check out the [contributing guide](CONTRIBUTING.md) for more details on how to setup the development environment and contribute to this project. diff --git a/include/stringzilla/stringzilla.h b/include/stringzilla/stringzilla.h index 3ba44dcc..d6ce2509 100644 --- a/include/stringzilla/stringzilla.h +++ b/include/stringzilla/stringzilla.h @@ -104,9 +104,9 @@ * @brief Annotation for the public API symbols. */ #if defined(_WIN32) || defined(__CYGWIN__) -#define SZ_PUBLIC __declspec(dllexport) inline static +#define SZ_PUBLIC inline static #elif __GNUC__ >= 4 -#define SZ_PUBLIC __attribute__((visibility("default"))) inline static +#define SZ_PUBLIC inline static #else #define SZ_PUBLIC inline static #endif @@ -717,11 +717,11 @@ SZ_PUBLIC sz_cptr_t sz_find_last_bounded_regex(sz_cptr_t haystack, sz_size_t h_l * @return Unsigned edit distance. 
*/ SZ_PUBLIC sz_size_t sz_edit_distance(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t const *alloc); + sz_size_t bound, sz_memory_allocator_t const *alloc); /** @copydoc sz_edit_distance */ SZ_PUBLIC sz_size_t sz_edit_distance_serial(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // - sz_size_t bound, sz_memory_allocator_t const *alloc); + sz_size_t bound, sz_memory_allocator_t const *alloc); /** @copydoc sz_edit_distance */ SZ_PUBLIC sz_size_t sz_edit_distance_avx512(sz_cptr_t a, sz_size_t a_length, sz_cptr_t b, sz_size_t b_length, // diff --git a/include/stringzilla/stringzilla.hpp b/include/stringzilla/stringzilla.hpp index e3e4219e..c713ee59 100644 --- a/include/stringzilla/stringzilla.hpp +++ b/include/stringzilla/stringzilla.hpp @@ -19,6 +19,9 @@ #include #endif +#include // `assert` +#include // `std::size_t` + #include namespace ashvardanian { diff --git a/python/lib.c b/python/lib.c index 0ab40fa5..0ea3de84 100644 --- a/python/lib.c +++ b/python/lib.c @@ -1051,7 +1051,7 @@ static PyObject *Str_count(PyObject *self, PyObject *args, PyObject *kwargs) { return PyLong_FromSize_t(count); } -static PyObject *Str_levenshtein(PyObject *self, PyObject *args, PyObject *kwargs) { +static PyObject *Str_edit_distance(PyObject *self, PyObject *args, PyObject *kwargs) { int is_member = self != NULL && PyObject_TypeCheck(self, &StrType); Py_ssize_t nargs = PyTuple_Size(args); if (nargs < !is_member + 1 || nargs > !is_member + 2) { @@ -1093,7 +1093,7 @@ static PyObject *Str_levenshtein(PyObject *self, PyObject *args, PyObject *kwarg sz_memory_allocator_t reusing_allocator; reusing_allocator.allocate = &temporary_memory_allocate; reusing_allocator.free = &temporary_memory_free; - reusing_allocator.user_data = &temporary_memory; + reusing_allocator.handle = &temporary_memory; sz_size_t distance = sz_edit_distance(str1.start, str1.length, str2.start, str2.length, (sz_size_t)bound, 
&reusing_allocator); @@ -1469,7 +1469,7 @@ static PyMethodDef Str_methods[] = { {"splitlines", Str_splitlines, sz_method_flags_m, "Split a string by line breaks."}, {"startswith", Str_startswith, sz_method_flags_m, "Check if a string starts with a given prefix."}, {"endswith", Str_endswith, sz_method_flags_m, "Check if a string ends with a given suffix."}, - {"levenshtein", Str_levenshtein, sz_method_flags_m, "Calculate the Levenshtein distance between two strings."}, + {"edit_distance", Str_edit_distance, sz_method_flags_m, "Calculate the Levenshtein distance between two strings."}, {NULL, NULL, 0, NULL}}; static PyTypeObject StrType = { @@ -1763,7 +1763,7 @@ static PyMethodDef stringzilla_methods[] = { {"splitlines", Str_splitlines, sz_method_flags_m, "Split a string by line breaks."}, {"startswith", Str_startswith, sz_method_flags_m, "Check if a string starts with a given prefix."}, {"endswith", Str_endswith, sz_method_flags_m, "Check if a string ends with a given suffix."}, - {"levenshtein", Str_levenshtein, sz_method_flags_m, "Calculate the Levenshtein distance between two strings."}, + {"edit_distance", Str_edit_distance, sz_method_flags_m, "Calculate the Levenshtein distance between two strings."}, {NULL, NULL, 0, NULL}}; static PyModuleDef stringzilla_module = { diff --git a/scripts/bench.hpp b/scripts/bench.hpp new file mode 100644 index 00000000..b3cdc6c3 --- /dev/null +++ b/scripts/bench.hpp @@ -0,0 +1,311 @@ +/** + * @brief Helper structures for C++ benchmarks. 
+ */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#ifdef NDEBUG // Make debugging faster +#define default_seconds_m 10 +#else +#define default_seconds_m 10 +#endif + +namespace ashvardanian { +namespace stringzilla { +namespace scripts { + +using seconds_t = double; + +struct benchmark_result_t { + std::size_t iterations = 0; + std::size_t bytes_passed = 0; + seconds_t seconds = 0; +}; + +using unary_function_t = std::function; +using binary_function_t = std::function; + +/** + * @brief Wrapper for a single execution backend. + */ +template +struct tracked_function_gt { + std::string name {""}; + function_at function {nullptr}; + bool needs_testing {false}; + + std::size_t failed_count {0}; + std::vector failed_strings {}; + benchmark_result_t results {}; + + void print() const { + char const *format; + // Now let's print in the format: + // - name, up to 20 characters + // - throughput in GB/s with up to 3 significant digits, 10 characters + // - call latency in ns with up to 1 significant digit, 10 characters + // - number of failed tests, 10 characters + // - first example of a failed test, up to 20 characters + if constexpr (std::is_same()) + format = "%-20s %10.3f GB/s %10.1f ns %10zu %s %s\n"; + else + format = "%-20s %10.3f GB/s %10.1f ns %10zu %s\n"; + std::printf(format, name.c_str(), results.bytes_passed / results.seconds / 1.e9, + results.seconds * 1e9 / results.iterations, failed_count, + failed_strings.size() ? failed_strings[0].c_str() : "", + failed_strings.size() ? failed_strings[1].c_str() : ""); + } +}; + +using tracked_unary_functions_t = std::vector>; +using tracked_binary_functions_t = std::vector>; + +/** + * @brief Stops compilers from optimizing out the expression. + * Shamelessly stolen from Google Benchmark. 
+ */ +template +inline void do_not_optimize(value_at &&value) { + asm volatile("" : "+r"(value) : : "memory"); +} + +inline sz_string_view_t sz_string_view(std::string const &str) { return {str.data(), str.size()}; }; + +/** + * @brief Rounds the number down to the preceding power of two. + * Equivalent to `std::bit_floor`. + */ +inline std::size_t bit_floor(std::size_t n) { + if (n == 0) return 0; + std::size_t most_siginificant_bit_position = 0; + while (n > 1) n >>= 1, most_siginificant_bit_position++; + return static_cast(1) << most_siginificant_bit_position; +} + +inline std::string read_file(std::string path) { + std::ifstream stream(path); + if (!stream.is_open()) { throw std::runtime_error("Failed to open file: " + path); } + return std::string((std::istreambuf_iterator(stream)), std::istreambuf_iterator()); +} + +/** + * @brief Splits a string into words, using newlines, tabs, and whitespaces as delimiters. + */ +inline std::vector tokenize(std::string_view str) { + std::vector words; + std::size_t start = 0; + for (std::size_t end = 0; end <= str.length(); ++end) { + if (end == str.length() || std::isspace(str[end])) { + if (start < end) words.push_back({&str[start], end - start}); + start = end + 1; + } + } + return words; +} + +struct dataset_t { + std::string text; + std::vector tokens; + + inline std::vector tokens_of_length(std::size_t n) const { + std::vector result; + for (auto const &str : tokens) + if (str.size() == n) result.push_back(str); + return result; + } +}; + +/** + * @brief Loads a dataset from a file. 
+ */ +inline dataset_t make_dataset_from_path(std::string path) { + dataset_t data; + data.text = read_file(path); + data.text.resize(bit_floor(data.text.size())); + data.tokens = tokenize(data.text); + data.tokens.resize(bit_floor(data.tokens.size())); + +#ifdef NDEBUG // Shuffle only in release mode + std::random_device random_device; + std::mt19937 random_generator(random_device()); + std::shuffle(data.tokens.begin(), data.tokens.end(), random_generator); +#endif + + // Report some basic stats about the dataset + std::size_t mean_bytes = 0; + for (auto const &str : data.tokens) mean_bytes += str.size(); + mean_bytes /= data.tokens.size(); + std::printf("Parsed the file with %zu words of %zu mean length!\n", data.tokens.size(), mean_bytes); + + return data; +} + +/** + * @brief Loads a dataset, depending on the passed CLI arguments. + */ +inline dataset_t make_dataset(int argc, char const *argv[]) { + if (argc != 2) { throw std::runtime_error("Usage: " + std::string(argv[0]) + " "); } + return make_dataset_from_path(argv[1]); +} + +/** + * @brief Loop over all elements in a dataset in somewhat random order, benchmarking the function cost. + * @param strings Strings to loop over. Length must be a power of two. + * @param function Function to be applied to each `sz_string_view_t`. Must return the number of bytes processed. + * @return Number of seconds per iteration. 
+ */ +template +benchmark_result_t loop_over_words(strings_at &&strings, function_at &&function, + seconds_t max_time = default_seconds_m) { + + namespace stdc = std::chrono; + using stdcc = stdc::high_resolution_clock; + stdcc::time_point t1 = stdcc::now(); + benchmark_result_t result; + std::size_t lookup_mask = bit_floor(strings.size()) - 1; + + while (true) { + // Unroll a few iterations, to avoid some for-loops overhead and minimize impact of time-tracking + { + result.bytes_passed += function(sz_string_view(strings[(++result.iterations) & lookup_mask])); + result.bytes_passed += function(sz_string_view(strings[(++result.iterations) & lookup_mask])); + result.bytes_passed += function(sz_string_view(strings[(++result.iterations) & lookup_mask])); + result.bytes_passed += function(sz_string_view(strings[(++result.iterations) & lookup_mask])); + } + + stdcc::time_point t2 = stdcc::now(); + result.seconds = stdc::duration_cast(t2 - t1).count() / 1.e9; + if (result.seconds > max_time) break; + } + + return result; +} + +/** + * @brief Loop over all elements in a dataset, benchmarking the function cost. + * @param strings Strings to loop over. Length must be a power of two. + * @param function Function to be applied to pairs of `sz_string_view_t`. + * Must return the number of bytes processed. + * @return Number of seconds per iteration. 
+ */ +template +benchmark_result_t loop_over_pairs_of_words(strings_at &&strings, function_at &&function, + seconds_t max_time = default_seconds_m) { + + namespace stdc = std::chrono; + using stdcc = stdc::high_resolution_clock; + stdcc::time_point t1 = stdcc::now(); + benchmark_result_t result; + std::size_t lookup_mask = bit_floor(strings.size()) - 1; + + while (true) { + // Unroll a few iterations, to avoid some for-loops overhead and minimize impact of time-tracking + { + result.bytes_passed += + function(sz_string_view(strings[(++result.iterations) & lookup_mask]), + sz_string_view(strings[(result.iterations * 18446744073709551557ull) & lookup_mask])); + result.bytes_passed += + function(sz_string_view(strings[(++result.iterations) & lookup_mask]), + sz_string_view(strings[(result.iterations * 18446744073709551557ull) & lookup_mask])); + result.bytes_passed += + function(sz_string_view(strings[(++result.iterations) & lookup_mask]), + sz_string_view(strings[(result.iterations * 18446744073709551557ull) & lookup_mask])); + result.bytes_passed += + function(sz_string_view(strings[(++result.iterations) & lookup_mask]), + sz_string_view(strings[(result.iterations * 18446744073709551557ull) & lookup_mask])); + } + + stdcc::time_point t2 = stdcc::now(); + result.seconds = stdc::duration_cast(t2 - t1).count() / 1.e9; + if (result.seconds > max_time) break; + } + + return result; +} + +/** + * @brief Evaluation for unary string operations: hashing. 
+ */ +template +void evaluate_unary_operations(strings_at &&strings, tracked_unary_functions_t &&variants) { + + for (std::size_t variant_idx = 0; variant_idx != variants.size(); ++variant_idx) { + auto &variant = variants[variant_idx]; + + // Tests + if (variant.function && variant.needs_testing) { + loop_over_words(strings, [&](sz_string_view_t str) { + auto baseline = variants[0].function(str); + auto result = variant.function(str); + if (result != baseline) { + ++variant.failed_count; + if (variant.failed_strings.empty()) { variant.failed_strings.push_back({str.start, str.length}); } + } + return str.length; + }); + } + + // Benchmarks + if (variant.function) { + variant.results = loop_over_words(strings, [&](sz_string_view_t str) { + do_not_optimize(variant.function(str)); + return str.length; + }); + } + + variant.print(); + } +} + +/** + * @brief Evaluation for binary string operations: equality, ordering, prefix, suffix, distance. + */ +template +void evaluate_binary_operations(strings_at &&strings, tracked_binary_functions_t &&variants) { + + for (std::size_t variant_idx = 0; variant_idx != variants.size(); ++variant_idx) { + auto &variant = variants[variant_idx]; + + // Tests + if (variant.function && variant.needs_testing) { + loop_over_pairs_of_words(strings, [&](sz_string_view_t str_a, sz_string_view_t str_b) { + auto baseline = variants[0].function(str_a, str_b); + auto result = variant.function(str_a, str_b); + if (result != baseline) { + ++variant.failed_count; + if (variant.failed_strings.empty()) { + variant.failed_strings.push_back({str_a.start, str_a.length}); + variant.failed_strings.push_back({str_b.start, str_b.length}); + } + } + return str_a.length + str_b.length; + }); + } + + // Benchmarks + if (variant.function) { + variant.results = loop_over_pairs_of_words(strings, [&](sz_string_view_t str_a, sz_string_view_t str_b) { + do_not_optimize(variant.function(str_a, str_b)); + return str_a.length + str_b.length; + }); + } + + variant.print(); 
+ } +} + +} // namespace scripts +} // namespace stringzilla +} // namespace ashvardanian \ No newline at end of file diff --git a/scripts/bench_container.cpp b/scripts/bench_container.cpp new file mode 100644 index 00000000..1d9f3edc --- /dev/null +++ b/scripts/bench_container.cpp @@ -0,0 +1,24 @@ +/** + * @file bench_container.cpp + * @brief Benchmarks STL associative containers with string keys. + * + * This file is the sibling of `bench_sort.cpp`, `bench_search.cpp` and `bench_token.cpp`. + * It accepts a file with a list of words, constructs associative containers with string keys, + * using `std::string`, `std::string_view`, `sz::string_view`, and `sz::string`, and then + * evaluates the latency of lookups. + */ +#include +#include + +#include + +using namespace ashvardanian::stringzilla::scripts; + +int main(int argc, char const **argv) { + std::printf("StringZilla. Starting STL container benchmarks.\n"); + + // dataset_t dataset = make_dataset(argc, argv); + + std::printf("All benchmarks passed.\n"); + return 0; +} \ No newline at end of file diff --git a/scripts/bench_search.cpp b/scripts/bench_search.cpp new file mode 100644 index 00000000..fa3b5fc3 --- /dev/null +++ b/scripts/bench_search.cpp @@ -0,0 +1,251 @@ +/** + * @file bench_search.cpp + * @brief Benchmarks for bidirectional string search operations - exact and approximate. + * + * This file is the sibling of `bench_sort.cpp`, `bench_token.cpp` and `bench_similarity.cpp`. + * It accepts a file with a list of words, and benchmarks the search operations on them. + * Outside of present tokens also tries missing tokens. + */ +#include + +using namespace ashvardanian::stringzilla::scripts; + +tracked_binary_functions_t find_functions() { + auto wrap_sz = [](auto function) -> binary_function_t { + return binary_function_t([function](sz_string_view_t h, sz_string_view_t n) { + sz_cptr_t match = function(h.start, h.length, n.start, n.length); + return (sz_ssize_t)(match ? 
match - h.start : h.length); + }); + }; + tracked_binary_functions_t result = { + {"std::string_view.find", + [](sz_string_view_t h, sz_string_view_t n) { + auto h_view = std::string_view(h.start, h.length); + auto n_view = std::string_view(n.start, n.length); + auto match = h_view.find(n_view); + return (sz_ssize_t)(match == std::string_view::npos ? h.length : match); + }}, + {"sz_find_serial", wrap_sz(sz_find_serial), true}, +#if SZ_USE_X86_AVX512 + {"sz_find_avx512", wrap_sz(sz_find_avx512), true}, +#endif +#if SZ_USE_ARM_NEON + {"sz_find_neon", wrap_sz(sz_find_neon), true}, +#endif + {"strstr", + [](sz_string_view_t h, sz_string_view_t n) { + sz_cptr_t match = strstr(h.start, n.start); + return (sz_ssize_t)(match ? match - h.start : h.length); + }}, + {"std::search", + [](sz_string_view_t h, sz_string_view_t n) { + auto match = std::search(h.start, h.start + h.length, n.start, n.start + n.length); + return (sz_ssize_t)(match - h.start); + }}, + {"std::search", + [](sz_string_view_t h, sz_string_view_t n) { + auto match = + std::search(h.start, h.start + h.length, std::boyer_moore_searcher(n.start, n.start + n.length)); + return (sz_ssize_t)(match - h.start); + }}, + {"std::search", + [](sz_string_view_t h, sz_string_view_t n) { + auto match = std::search(h.start, h.start + h.length, + std::boyer_moore_horspool_searcher(n.start, n.start + n.length)); + return (sz_ssize_t)(match - h.start); + }}, + }; + return result; +} + +tracked_binary_functions_t find_last_functions() { + auto wrap_sz = [](auto function) -> binary_function_t { + return binary_function_t([function](sz_string_view_t h, sz_string_view_t n) { + sz_cptr_t match = function(h.start, h.length, n.start, n.length); + return (sz_ssize_t)(match ? 
match - h.start : h.length); + }); + }; + tracked_binary_functions_t result = { + {"std::string_view.rfind", + [](sz_string_view_t h, sz_string_view_t n) { + auto h_view = std::string_view(h.start, h.length); + auto n_view = std::string_view(n.start, n.length); + auto match = h_view.rfind(n_view); + return (sz_ssize_t)(match == std::string_view::npos ? h.length : match); + }}, + {"sz_find_last_serial", wrap_sz(sz_find_last_serial), true}, +#if SZ_USE_X86_AVX512 + {"sz_find_last_avx512", wrap_sz(sz_find_last_avx512), true}, +#endif +#if SZ_USE_ARM_NEON + {"sz_find_last_neon", wrap_sz(sz_find_last_neon), true}, +#endif + {"std::search", + [](sz_string_view_t h, sz_string_view_t n) { + auto h_view = std::string_view(h.start, h.length); + auto n_view = std::string_view(n.start, n.length); + auto match = std::search(h_view.rbegin(), h_view.rend(), n_view.rbegin(), n_view.rend()); + auto offset_from_end = (sz_ssize_t)(match - h_view.rbegin()); + return h.length - offset_from_end; + }}, + {"std::search", + [](sz_string_view_t h, sz_string_view_t n) { + auto h_view = std::string_view(h.start, h.length); + auto n_view = std::string_view(n.start, n.length); + auto match = + std::search(h_view.rbegin(), h_view.rend(), std::boyer_moore_searcher(n_view.rbegin(), n_view.rend())); + auto offset_from_end = (sz_ssize_t)(match - h_view.rbegin()); + return h.length - offset_from_end; + }}, + {"std::search", + [](sz_string_view_t h, sz_string_view_t n) { + auto h_view = std::string_view(h.start, h.length); + auto n_view = std::string_view(n.start, n.length); + auto match = std::search(h_view.rbegin(), h_view.rend(), + std::boyer_moore_horspool_searcher(n_view.rbegin(), n_view.rend())); + auto offset_from_end = (sz_ssize_t)(match - h_view.rbegin()); + return h.length - offset_from_end; + }}, + }; + return result; +} + +/** + * @brief Evaluation for search string operations: find. 
+ */ +template +void evaluate_find_operations(std::string_view content_original, strings_at &&strings, + tracked_binary_functions_t &&variants) { + + for (std::size_t variant_idx = 0; variant_idx != variants.size(); ++variant_idx) { + auto &variant = variants[variant_idx]; + + // Tests + if (variant.function && variant.needs_testing) { + loop_over_words(strings, [&](sz_string_view_t str_n) { + sz_string_view_t str_h = {content_original.data(), content_original.size()}; + while (true) { + auto baseline = variants[0].function(str_h, str_n); + auto result = variant.function(str_h, str_n); + if (result != baseline) { + ++variant.failed_count; + if (variant.failed_strings.empty()) { + variant.failed_strings.push_back({str_h.start, baseline + str_n.length}); + variant.failed_strings.push_back({str_n.start, str_n.length}); + } + } + + if (baseline == str_h.length) break; + str_h.start += baseline + 1; + str_h.length -= baseline + 1; + } + + return content_original.size(); + }); + } + + // Benchmarks + if (variant.function) { + variant.results = loop_over_words(strings, [&](sz_string_view_t str_n) { + sz_string_view_t str_h = {content_original.data(), content_original.size()}; + auto result = variant.function(str_h, str_n); + while (result != str_h.length) { + str_h.start += result + 1, str_h.length -= result + 1; + result = variant.function(str_h, str_n); + do_not_optimize(result); + } + return result; + }); + } + + variant.print(); + } +} + +/** + * @brief Evaluation for reverse order search string operations: find. 
+ */ +template +void evaluate_find_last_operations(std::string_view content_original, strings_at &&strings, + tracked_binary_functions_t &&variants) { + + for (std::size_t variant_idx = 0; variant_idx != variants.size(); ++variant_idx) { + auto &variant = variants[variant_idx]; + + // Tests + if (variant.function && variant.needs_testing) { + loop_over_words(strings, [&](sz_string_view_t str_n) { + sz_string_view_t str_h = {content_original.data(), content_original.size()}; + while (true) { + auto baseline = variants[0].function(str_h, str_n); + auto result = variant.function(str_h, str_n); + if (result != baseline) { + ++variant.failed_count; + if (variant.failed_strings.empty()) { + variant.failed_strings.push_back({str_h.start + baseline, str_h.start + str_h.length}); + variant.failed_strings.push_back({str_n.start, str_n.length}); + } + } + + if (baseline == str_h.length) break; + str_h.length = baseline; + } + + return content_original.size(); + }); + } + + // Benchmarks + if (variant.function) { + std::size_t bytes_processed = 0; + std::size_t mask = content_original.size() - 1; + variant.results = loop_over_words(strings, [&](sz_string_view_t str_n) { + sz_string_view_t str_h = {content_original.data(), content_original.size()}; + str_h.length -= bytes_processed & mask; + auto result = variant.function(str_h, str_n); + bytes_processed += (str_h.length - result) + str_n.length; + return result; + }); + } + + variant.print(); + } +} + +template +void evaluate_all(std::string_view content_original, strings_at &&strings) { + if (strings.size() == 0) return; + + evaluate_find_operations(content_original, strings, find_functions()); + evaluate_find_last_operations(content_original, strings, find_last_functions()); +} + +int main(int argc, char const **argv) { + std::printf("StringZilla. 
Starting search benchmarks.\n"); + + dataset_t dataset = make_dataset(argc, argv); + + // Baseline benchmarks for real words, coming in all lengths + std::printf("Benchmarking on real words:\n"); + evaluate_all(dataset.text, dataset.tokens); + + // Run benchmarks on tokens of different length + for (std::size_t token_length : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32}) { + std::printf("Benchmarking on real words of length %zu:\n", token_length); + evaluate_all(dataset.text, dataset.tokens_of_length(token_length)); + } + + // Run bechnmarks on abstract tokens of different length + for (std::size_t token_length : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32}) { + std::printf("Benchmarking for missing tokens of length %zu:\n", token_length); + evaluate_all(dataset.text, std::vector { + std::string(token_length, '\1'), + std::string(token_length, '\2'), + std::string(token_length, '\3'), + std::string(token_length, '\4'), + }); + } + + std::printf("All benchmarks passed.\n"); + return 0; +} \ No newline at end of file diff --git a/scripts/search_bench.py b/scripts/bench_search.py similarity index 87% rename from scripts/search_bench.py rename to scripts/bench_search.py index a7c864fb..9a65d7ff 100644 --- a/scripts/search_bench.py +++ b/scripts/bench_search.py @@ -21,11 +21,6 @@ def log_functionality( stringzilla_str: Str, stringzilla_file: File, ): - log("str.contains", bytes_length, lambda: pattern in pythonic_str) - log("Str.contains", bytes_length, lambda: pattern in stringzilla_str) - if stringzilla_file: - log("File.contains", bytes_length, lambda: pattern in stringzilla_file) - log("str.count", bytes_length, lambda: pythonic_str.count(pattern)) log("Str.count", bytes_length, lambda: stringzilla_str.count(pattern)) if stringzilla_file: @@ -43,7 +38,7 @@ def log_functionality( def bench( - needle: str, + needle: str = None, haystack_path: str = None, haystack_pattern: str = None, haystack_length: int = None, diff --git a/scripts/bench_similarity.cpp b/scripts/bench_similarity.cpp new 
file mode 100644 index 00000000..b8cad4ae --- /dev/null +++ b/scripts/bench_similarity.cpp @@ -0,0 +1,87 @@ +/** + * @file bench_similarity.cpp + * @brief Benchmarks string similarity computations. + * + * This file is the sibling of `bench_sort.cpp`, `bench_search.cpp` and `bench_token.cpp`. + * It accepts a file with a list of words, and benchmarks the levenshtein edit-distance computations, + * alignment scores, and fingerprinting techniques combined with the Hamming distance. + */ +#include + +using namespace ashvardanian::stringzilla::scripts; + +using temporary_memory_t = std::vector; +temporary_memory_t temporary_memory; + +static sz_ptr_t allocate_from_vector(sz_size_t length, void *handle) { + temporary_memory_t &vec = *reinterpret_cast(handle); + if (vec.size() < length) vec.resize(length); + return vec.data(); +} + +static void free_from_vector(sz_ptr_t buffer, sz_size_t length, void *handle) {} + +tracked_binary_functions_t distance_functions() { + // Populate the unary substitutions matrix + static constexpr std::size_t max_length = 256; + static std::vector unary_substitution_costs; + unary_substitution_costs.resize(max_length * max_length); + for (std::size_t i = 0; i != max_length; ++i) + for (std::size_t j = 0; j != max_length; ++j) unary_substitution_costs[i * max_length + j] = (i == j ? 
0 : 1); + + // Two rows of the Levenshtein matrix will occupy this much: + temporary_memory.resize((max_length + 1) * 2 * sizeof(sz_size_t)); + sz_memory_allocator_t alloc; + alloc.allocate = &allocate_from_vector; + alloc.free = &free_from_vector; + alloc.handle = &temporary_memory; + + auto wrap_sz_distance = [alloc](auto function) -> binary_function_t { + return binary_function_t([function, alloc](sz_string_view_t a, sz_string_view_t b) { + a.length = sz_min_of_two(a.length, max_length); + b.length = sz_min_of_two(b.length, max_length); + return (sz_ssize_t)function(a.start, a.length, b.start, b.length, max_length, &alloc); + }); + }; + auto wrap_sz_scoring = [alloc](auto function) -> binary_function_t { + return binary_function_t([function, alloc](sz_string_view_t a, sz_string_view_t b) { + a.length = sz_min_of_two(a.length, max_length); + b.length = sz_min_of_two(b.length, max_length); + return (sz_ssize_t)function(a.start, a.length, b.start, b.length, 1, unary_substitution_costs.data(), + &alloc); + }); + }; + tracked_binary_functions_t result = { + {"sz_edit_distance", wrap_sz_distance(sz_edit_distance_serial)}, + {"sz_alignment_score", wrap_sz_scoring(sz_alignment_score_serial), true}, +#if SZ_USE_X86_AVX512 + {"sz_edit_distance_avx512", wrap_sz_distance(sz_edit_distance_avx512), true}, +#endif + }; + return result; +} + +template +void evaluate_all(strings_at &&strings) { + if (strings.size() == 0) return; + evaluate_binary_operations(strings, distance_functions()); +} + +int main(int argc, char const **argv) { + std::printf("StringZilla. 
Starting similarity benchmarks.\n"); + + dataset_t dataset = make_dataset(argc, argv); + + // Baseline benchmarks for real words, coming in all lengths + std::printf("Benchmarking on real words:\n"); + evaluate_all(dataset.tokens); + + // Run benchmarks on tokens of different length + for (std::size_t token_length : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32}) { + std::printf("Benchmarking on real words of length %zu:\n", token_length); + evaluate_all(dataset.tokens_of_length(token_length)); + } + + std::printf("All benchmarks passed.\n"); + return 0; +} \ No newline at end of file diff --git a/scripts/similarity_bench.py b/scripts/bench_similarity.py similarity index 100% rename from scripts/similarity_bench.py rename to scripts/bench_similarity.py diff --git a/scripts/sort_bench.cpp b/scripts/bench_sort.cpp similarity index 88% rename from scripts/sort_bench.cpp rename to scripts/bench_sort.cpp index ea52156e..cf34e6e0 100644 --- a/scripts/sort_bench.cpp +++ b/scripts/bench_sort.cpp @@ -1,15 +1,14 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include + +/** + * @file bench_sort.cpp + * @brief Benchmarks sorting, partitioning, and merging operations on string sequences. + * + * This file is the sibling of `bench_similarity.cpp`, `bench_search.cpp` and `bench_token.cpp`. + * It accepts a file with a list of words, and benchmarks the sorting operations on them. 
+ */ +#include + +using namespace ashvardanian::stringzilla::scripts; using strings_t = std::vector; using idx_t = sz_size_t; @@ -27,14 +26,14 @@ static sz_size_t get_length(sz_sequence_t const *array_c, sz_size_t i) { return array[i].size(); } -static int is_less(sz_sequence_t const *array_c, sz_size_t i, sz_size_t j) { +static sz_bool_t is_less(sz_sequence_t const *array_c, sz_size_t i, sz_size_t j) { strings_t const &array = *reinterpret_cast(array_c->handle); - return array[i] < array[j]; + return (sz_bool_t)(array[i] < array[j]); } -static int has_under_four_chars(sz_sequence_t const *array_c, sz_size_t i) { +static sz_bool_t has_under_four_chars(sz_sequence_t const *array_c, sz_size_t i) { strings_t const &array = *reinterpret_cast(array_c->handle); - return array[i].size() < 4; + return (sz_bool_t)(array[i].size() < 4); } #pragma endregion @@ -145,19 +144,10 @@ void bench_permute(char const *name, strings_t &strings, permute_t &permute, alg std::printf("Elapsed time is %.2lf miliseconds/iteration for %s.\n", milisecs, name); } -int main(int, char const **) { - std::printf("Hey, Ash!\n"); - - strings_t strings; - populate_from_file("leipzig1M.txt", strings, 1000000); - std::size_t mean_bytes = 0; - for (std::string const &str : strings) mean_bytes += str.size(); - mean_bytes /= strings.size(); - std::printf("Parsed the file with %zu words of %zu mean length!\n", strings.size(), mean_bytes); - - std::string full_text; - full_text.reserve(mean_bytes + strings.size() * 2); - for (std::string const &str : strings) full_text.append(str), full_text.push_back(' '); +int main(int argc, char const **argv) { + std::printf("StringZilla. 
Starting sorting benchmarks.\n"); + dataset_t dataset = make_dataset(argc, argv); + strings_t &strings = dataset.tokens; permute_t permute_base, permute_new; permute_base.resize(strings.size()); diff --git a/scripts/bench_token.cpp b/scripts/bench_token.cpp new file mode 100644 index 00000000..17066f28 --- /dev/null +++ b/scripts/bench_token.cpp @@ -0,0 +1,104 @@ +/** + * @file bench_token.cpp + * @brief Benchmarks token-level operations like hashing, equality, ordering, and copies. + * + * This file is the sibling of `bench_sort.cpp`, `bench_search.cpp` and `bench_similarity.cpp`. + */ +#include + +using namespace ashvardanian::stringzilla::scripts; + +tracked_unary_functions_t hashing_functions() { + auto wrap_sz = [](auto function) -> unary_function_t { + return unary_function_t([function](sz_string_view_t s) { return (sz_ssize_t)function(s.start, s.length); }); + }; + tracked_unary_functions_t result = { + {"sz_hash_serial", wrap_sz(sz_hash_serial)}, +#if SZ_USE_X86_AVX512 + {"sz_hash_avx512", wrap_sz(sz_hash_avx512), true}, +#endif +#if SZ_USE_ARM_NEON + {"sz_hash_neon", wrap_sz(sz_hash_neon), true}, +#endif + {"std::hash", + [](sz_string_view_t s) { + return (sz_ssize_t)std::hash {}({s.start, s.length}); + }}, + }; + return result; +} + +tracked_binary_functions_t equality_functions() { + auto wrap_sz = [](auto function) -> binary_function_t { + return binary_function_t([function](sz_string_view_t a, sz_string_view_t b) { + return (sz_ssize_t)(a.length == b.length && function(a.start, b.start, a.length)); + }); + }; + tracked_binary_functions_t result = { + {"std::string_view.==", + [](sz_string_view_t a, sz_string_view_t b) { + return (sz_ssize_t)(std::string_view(a.start, a.length) == std::string_view(b.start, b.length)); + }}, + {"sz_equal_serial", wrap_sz(sz_equal_serial), true}, +#if SZ_USE_X86_AVX512 + {"sz_equal_avx512", wrap_sz(sz_equal_avx512), true}, +#endif + {"memcmp", + [](sz_string_view_t a, sz_string_view_t b) { + return (sz_ssize_t)(a.length 
== b.length && memcmp(a.start, b.start, a.length) == 0); + }}, + }; + return result; +} + +tracked_binary_functions_t ordering_functions() { + auto wrap_sz = [](auto function) -> binary_function_t { + return binary_function_t([function](sz_string_view_t a, sz_string_view_t b) { + return (sz_ssize_t)function(a.start, a.length, b.start, b.length); + }); + }; + tracked_binary_functions_t result = { + {"std::string_view.compare", + [](sz_string_view_t a, sz_string_view_t b) { + auto order = std::string_view(a.start, a.length).compare(std::string_view(b.start, b.length)); + return (sz_ssize_t)(order == 0 ? sz_equal_k : (order < 0 ? sz_less_k : sz_greater_k)); + }}, + {"sz_order_serial", wrap_sz(sz_order_serial), true}, + {"memcmp", + [](sz_string_view_t a, sz_string_view_t b) { + auto order = memcmp(a.start, b.start, a.length < b.length ? a.length : b.length); + return order != 0 ? (a.length == b.length ? (order < 0 ? sz_less_k : sz_greater_k) + : (a.length < b.length ? sz_less_k : sz_greater_k)) + : sz_equal_k; + }}, + }; + return result; +} + +template +void evaluate_all(strings_at &&strings) { + if (strings.size() == 0) return; + + evaluate_unary_operations(strings, hashing_functions()); + evaluate_binary_operations(strings, equality_functions()); + evaluate_binary_operations(strings, ordering_functions()); +} + +int main(int argc, char const **argv) { + std::printf("StringZilla. 
Starting token-level benchmarks.\n"); + + dataset_t dataset = make_dataset(argc, argv); + + // Baseline benchmarks for real words, coming in all lengths + std::printf("Benchmarking on real words:\n"); + evaluate_all(dataset.tokens); + + // Run benchmarks on tokens of different length + for (std::size_t token_length : {1, 2, 3, 4, 5, 6, 7, 8, 16, 32}) { + std::printf("Benchmarking on real words of length %zu:\n", token_length); + evaluate_all(dataset.tokens_of_length(token_length)); + } + + std::printf("All benchmarks passed.\n"); + return 0; +} \ No newline at end of file diff --git a/scripts/similarity_fuzz.py b/scripts/fuzz.py similarity index 100% rename from scripts/similarity_fuzz.py rename to scripts/fuzz.py diff --git a/scripts/random_baseline.py b/scripts/random_baseline.py deleted file mode 100644 index 3d62d820..00000000 --- a/scripts/random_baseline.py +++ /dev/null @@ -1,15 +0,0 @@ -import random, time -from typing import Union, Optional -from random import choice, randint -from string import ascii_lowercase - - -def get_random_string( - length: Optional[int] = None, - variability: Optional[int] = None, -) -> str: - if length is None: - length = randint(3, 300) - if variability is None: - variability = len(ascii_lowercase) - return "".join(choice(ascii_lowercase[:variability]) for _ in range(length)) diff --git a/scripts/random_stress.py b/scripts/random_stress.py deleted file mode 100644 index 0f2daad2..00000000 --- a/scripts/random_stress.py +++ /dev/null @@ -1,20 +0,0 @@ -# PyTest + Cppyy test of the random string generators and related utility functions -# -import pytest -import cppyy - -cppyy.include("include/stringzilla/stringzilla.h") -cppyy.cppdef( - """ -sz_u32_t native_division(sz_u8_t number, sz_u8_t divisor) { - return sz_u8_divide(number, divisor); -} -""" -) - - -@pytest.mark.parametrize("number", range(0, 256)) -@pytest.mark.parametrize("divisor", range(2, 256)) -def test_fast_division(number: int, divisor: int): - sz_u8_divide = 
cppyy.gbl.native_division - assert (number // divisor) == sz_u8_divide(number, divisor) diff --git a/scripts/search_bench.cpp b/scripts/search_bench.cpp deleted file mode 100644 index 0e44d1a6..00000000 --- a/scripts/search_bench.cpp +++ /dev/null @@ -1,643 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -using seconds_t = double; -using unary_function_t = std::function; -using binary_function_t = std::function; - -struct loop_over_words_result_t { - std::size_t iterations = 0; - std::size_t bytes_passed = 0; - seconds_t seconds = 0; -}; - -/** - * @brief Wrapper for a single execution backend. - */ -template -struct tracked_function_gt { - std::string name {""}; - function_at function {nullptr}; - bool needs_testing {false}; - - std::size_t failed_count {0}; - std::vector failed_strings {}; - loop_over_words_result_t results {}; - - void print() const { - char const *format; - // Now let's print in the format: - // - name, up to 20 characters - // - throughput in GB/s with up to 3 significant digits, 10 characters - // - call latency in ns with up to 1 significant digit, 10 characters - // - number of failed tests, 10 characters - // - first example of a failed test, up to 20 characters - if constexpr (std::is_same()) - format = "%-20s %10.3f GB/s %10.1f ns %10zu %s %s\n"; - else - format = "%-20s %10.3f GB/s %10.1f ns %10zu %s\n"; - std::printf(format, name.c_str(), results.bytes_passed / results.seconds / 1.e9, - results.seconds * 1e9 / results.iterations, failed_count, - failed_strings.size() ? failed_strings[0].c_str() : "", - failed_strings.size() ? 
failed_strings[1].c_str() : ""); - } -}; - -using tracked_unary_functions_t = std::vector>; -using tracked_binary_functions_t = std::vector>; - -#ifdef NDEBUG // Make debugging faster -#define run_tests_m 1 -#define default_seconds_m 10 -#else -#define run_tests_m 1 -#define default_seconds_m 10 -#endif - -using temporary_memory_t = std::vector; - -std::string content_original; -std::vector content_words; -std::vector unary_substitution_costs; -temporary_memory_t temporary_memory; - -template -inline void do_not_optimize(value_at &&value) { - asm volatile("" : "+r"(value) : : "memory"); -} - -inline sz_string_view_t sz_string_view(std::string const &str) { return {str.data(), str.size()}; }; - -sz_ptr_t _sz_memory_allocate_from_vector(sz_size_t length, void *handle) { - temporary_memory_t &vec = *reinterpret_cast(handle); - if (vec.size() < length) vec.resize(length); - return vec.data(); -} - -void _sz_memory_free_from_vector(sz_ptr_t buffer, sz_size_t length, void *handle) {} - -std::string read_file(std::string path) { - std::ifstream stream(path); - if (!stream.is_open()) { throw std::runtime_error("Failed to open file: " + path); } - return std::string((std::istreambuf_iterator(stream)), std::istreambuf_iterator()); -} - -std::vector tokenize(std::string_view str) { - std::vector words; - std::size_t start = 0; - for (std::size_t end = 0; end <= str.length(); ++end) { - if (end == str.length() || std::isspace(str[end])) { - if (start < end) words.push_back({&str[start], end - start}); - start = end + 1; - } - } - return words; -} - -sz_string_view_t random_slice(sz_string_view_t full_text, std::size_t min_length = 2, std::size_t max_length = 8) { - std::size_t length = std::rand() % (max_length - min_length) + min_length; - std::size_t offset = std::rand() % (full_text.length - length); - return {full_text.start + offset, length}; -} - -std::size_t round_down_to_power_of_two(std::size_t n) { - if (n == 0) return 0; - std::size_t most_siginificant_bit_position 
= 0; - while (n > 1) n >>= 1, most_siginificant_bit_position++; - return static_cast(1) << most_siginificant_bit_position; -} - -tracked_unary_functions_t hashing_functions() { - auto wrap_sz = [](auto function) -> unary_function_t { - return unary_function_t([function](sz_string_view_t s) { return (sz_ssize_t)function(s.start, s.length); }); - }; - return { - {"sz_hash_serial", wrap_sz(sz_hash_serial)}, -#if SZ_USE_X86_AVX512 - {"sz_hash_avx512", wrap_sz(sz_hash_avx512), true}, -#endif -#if SZ_USE_ARM_NEON - {"sz_hash_neon", wrap_sz(sz_hash_neon), true}, -#endif - {"std::hash", [](sz_string_view_t s) { - return (sz_ssize_t)std::hash {}({s.start, s.length}); - }}, - }; -} - -inline tracked_binary_functions_t equality_functions() { - auto wrap_sz = [](auto function) -> binary_function_t { - return binary_function_t([function](sz_string_view_t a, sz_string_view_t b) { - return (sz_ssize_t)(a.length == b.length && function(a.start, b.start, a.length)); - }); - }; - return { - {"std::string_view.==", - [](sz_string_view_t a, sz_string_view_t b) { - return (sz_ssize_t)(std::string_view(a.start, a.length) == std::string_view(b.start, b.length)); - }}, - {"sz_equal_serial", wrap_sz(sz_equal_serial), true}, -#if SZ_USE_X86_AVX512 - {"sz_equal_avx512", wrap_sz(sz_equal_avx512), true}, -#endif - {"memcmp", [](sz_string_view_t a, sz_string_view_t b) { - return (sz_ssize_t)(a.length == b.length && memcmp(a.start, b.start, a.length) == 0); - }}, - }; -} - -inline tracked_binary_functions_t ordering_functions() { - auto wrap_sz = [](auto function) -> binary_function_t { - return binary_function_t([function](sz_string_view_t a, sz_string_view_t b) { - return (sz_ssize_t)function(a.start, a.length, b.start, b.length); - }); - }; - return { - {"std::string_view.compare", - [](sz_string_view_t a, sz_string_view_t b) { - auto order = std::string_view(a.start, a.length).compare(std::string_view(b.start, b.length)); - return (sz_ssize_t)(order == 0 ? sz_equal_k : (order < 0 ? 
sz_less_k : sz_greater_k)); - }}, - {"sz_order_serial", wrap_sz(sz_order_serial), true}, - {"memcmp", - [](sz_string_view_t a, sz_string_view_t b) { - auto order = memcmp(a.start, b.start, a.length < b.length ? a.length : b.length); - return order != 0 ? (a.length == b.length ? (order < 0 ? sz_less_k : sz_greater_k) - : (a.length < b.length ? sz_less_k : sz_greater_k)) - : sz_equal_k; - }}, - }; -} - -inline tracked_binary_functions_t find_functions() { - auto wrap_sz = [](auto function) -> binary_function_t { - return binary_function_t([function](sz_string_view_t h, sz_string_view_t n) { - sz_cptr_t match = function(h.start, h.length, n.start, n.length); - return (sz_ssize_t)(match ? match - h.start : h.length); - }); - }; - return { - {"std::string_view.find", - [](sz_string_view_t h, sz_string_view_t n) { - auto h_view = std::string_view(h.start, h.length); - auto n_view = std::string_view(n.start, n.length); - auto match = h_view.find(n_view); - return (sz_ssize_t)(match == std::string_view::npos ? h.length : match); - }}, - {"sz_find_serial", wrap_sz(sz_find_serial), true}, -#if SZ_USE_X86_AVX512 - {"sz_find_avx512", wrap_sz(sz_find_avx512), true}, -#endif -#if SZ_USE_ARM_NEON - {"sz_find_neon", wrap_sz(sz_find_neon), true}, -#endif - {"strstr", - [](sz_string_view_t h, sz_string_view_t n) { - sz_cptr_t match = strstr(h.start, n.start); - return (sz_ssize_t)(match ? 
match - h.start : h.length); - }}, - {"std::search", - [](sz_string_view_t h, sz_string_view_t n) { - auto match = std::search(h.start, h.start + h.length, n.start, n.start + n.length); - return (sz_ssize_t)(match - h.start); - }}, - {"std::search", - [](sz_string_view_t h, sz_string_view_t n) { - auto match = - std::search(h.start, h.start + h.length, std::boyer_moore_searcher(n.start, n.start + n.length)); - return (sz_ssize_t)(match - h.start); - }}, - {"std::search", [](sz_string_view_t h, sz_string_view_t n) { - auto match = std::search(h.start, h.start + h.length, - std::boyer_moore_horspool_searcher(n.start, n.start + n.length)); - return (sz_ssize_t)(match - h.start); - }}, - }; -} - -inline tracked_binary_functions_t find_last_functions() { - auto wrap_sz = [](auto function) -> binary_function_t { - return binary_function_t([function](sz_string_view_t h, sz_string_view_t n) { - sz_cptr_t match = function(h.start, h.length, n.start, n.length); - return (sz_ssize_t)(match ? match - h.start : h.length); - }); - }; - return { - {"std::string_view.rfind", - [](sz_string_view_t h, sz_string_view_t n) { - auto h_view = std::string_view(h.start, h.length); - auto n_view = std::string_view(n.start, n.length); - auto match = h_view.rfind(n_view); - return (sz_ssize_t)(match == std::string_view::npos ? 
h.length : match); - }}, - {"sz_find_last_serial", wrap_sz(sz_find_last_serial), true}, -#if SZ_USE_X86_AVX512 - {"sz_find_last_avx512", wrap_sz(sz_find_last_avx512), true}, -#endif -#if SZ_USE_ARM_NEON - {"sz_find_last_neon", wrap_sz(sz_find_last_neon), true}, -#endif - {"std::search", - [](sz_string_view_t h, sz_string_view_t n) { - auto h_view = std::string_view(h.start, h.length); - auto n_view = std::string_view(n.start, n.length); - auto match = std::search(h_view.rbegin(), h_view.rend(), n_view.rbegin(), n_view.rend()); - auto offset_from_end = (sz_ssize_t)(match - h_view.rbegin()); - return h.length - offset_from_end; - }}, - {"std::search", - [](sz_string_view_t h, sz_string_view_t n) { - auto h_view = std::string_view(h.start, h.length); - auto n_view = std::string_view(n.start, n.length); - auto match = std::search(h_view.rbegin(), h_view.rend(), - std::boyer_moore_searcher(n_view.rbegin(), n_view.rend())); - auto offset_from_end = (sz_ssize_t)(match - h_view.rbegin()); - return h.length - offset_from_end; - }}, - {"std::search", [](sz_string_view_t h, sz_string_view_t n) { - auto h_view = std::string_view(h.start, h.length); - auto n_view = std::string_view(n.start, n.length); - auto match = std::search(h_view.rbegin(), h_view.rend(), - std::boyer_moore_horspool_searcher(n_view.rbegin(), n_view.rend())); - auto offset_from_end = (sz_ssize_t)(match - h_view.rbegin()); - return h.length - offset_from_end; - }}, - }; -} - -inline tracked_binary_functions_t distance_functions() { - // Populate the unary substitutions matrix - static constexpr std::size_t max_length = 256; - unary_substitution_costs.resize(max_length * max_length); - for (std::size_t i = 0; i != max_length; ++i) - for (std::size_t j = 0; j != max_length; ++j) unary_substitution_costs[i * max_length + j] = (i == j ? 
0 : 1); - - // Two rows of the Levenshtein matrix will occupy this much: - temporary_memory.resize((max_length + 1) * 2 * sizeof(sz_size_t)); - sz_memory_allocator_t alloc; - alloc.allocate = _sz_memory_allocate_from_vector; - alloc.free = _sz_memory_free_from_vector; - alloc.handle = &temporary_memory; - - auto wrap_sz_distance = [alloc](auto function) -> binary_function_t { - return binary_function_t([function, alloc](sz_string_view_t a, sz_string_view_t b) { - a.length = sz_min_of_two(a.length, max_length); - b.length = sz_min_of_two(b.length, max_length); - return (sz_ssize_t)function(a.start, a.length, b.start, b.length, max_length, &alloc); - }); - }; - auto wrap_sz_scoring = [alloc](auto function) -> binary_function_t { - return binary_function_t([function, alloc](sz_string_view_t a, sz_string_view_t b) { - a.length = sz_min_of_two(a.length, max_length); - b.length = sz_min_of_two(b.length, max_length); - return (sz_ssize_t)function(a.start, a.length, b.start, b.length, 1, unary_substitution_costs.data(), - &alloc); - }); - }; - return { - {"sz_edit_distance", wrap_sz_distance(sz_edit_distance_serial)}, - {"sz_alignment_score", wrap_sz_scoring(sz_alignment_score_serial), true}, -#if SZ_USE_X86_AVX512 - {"sz_edit_distance_avx512", wrap_sz_distance(sz_edit_distance_avx512), true}, -#endif - }; -} - -/** - * @brief Loop over all elements in a dataset in somewhat random order, benchmarking the function cost. - * @param strings Strings to loop over. Length must be a power of two. - * @param function Function to be applied to each `sz_string_view_t`. Must return the number of bytes processed. - * @return Number of seconds per iteration. 
- */ -template -loop_over_words_result_t loop_over_words(strings_at &&strings, function_at &&function, - seconds_t max_time = default_seconds_m) { - - namespace stdc = std::chrono; - using stdcc = stdc::high_resolution_clock; - stdcc::time_point t1 = stdcc::now(); - loop_over_words_result_t result; - std::size_t lookup_mask = round_down_to_power_of_two(strings.size()) - 1; - - while (true) { - // Unroll a few iterations, to avoid some for-loops overhead and minimize impact of time-tracking - { - result.bytes_passed += function(sz_string_view(strings[(++result.iterations) & lookup_mask])); - result.bytes_passed += function(sz_string_view(strings[(++result.iterations) & lookup_mask])); - result.bytes_passed += function(sz_string_view(strings[(++result.iterations) & lookup_mask])); - result.bytes_passed += function(sz_string_view(strings[(++result.iterations) & lookup_mask])); - } - - stdcc::time_point t2 = stdcc::now(); - result.seconds = stdc::duration_cast(t2 - t1).count() / 1.e9; - if (result.seconds > max_time) break; - } - - return result; -} - -/** - * @brief Evaluation for unary string operations: hashing. 
- */ -template -void evaluate_unary_operations(strings_at &&strings, tracked_unary_functions_t &&variants) { - - for (std::size_t variant_idx = 0; variant_idx != variants.size(); ++variant_idx) { - auto &variant = variants[variant_idx]; - - // Tests - if (variant.function && variant.needs_testing) { - loop_over_words(strings, [&](sz_string_view_t str) { - auto baseline = variants[0].function(str); - auto result = variant.function(str); - if (result != baseline) { - ++variant.failed_count; - if (variant.failed_strings.empty()) { variant.failed_strings.push_back({str.start, str.length}); } - } - return str.length; - }); - } - - // Benchmarks - if (variant.function) { - variant.results = loop_over_words(strings, [&](sz_string_view_t str) { - do_not_optimize(variant.function(str)); - return str.length; - }); - } - - variant.print(); - } -} - -/** - * @brief Loop over all elements in a dataset, benchmarking the function cost. - * @param strings Strings to loop over. Length must be a power of two. - * @param function Function to be applied to pairs of `sz_string_view_t`. Must return the number of bytes - * processed. - * @return Number of seconds per iteration. 
- */ -template -loop_over_words_result_t loop_over_pairs_of_words(strings_at &&strings, function_at &&function, - seconds_t max_time = default_seconds_m) { - - namespace stdc = std::chrono; - using stdcc = stdc::high_resolution_clock; - stdcc::time_point t1 = stdcc::now(); - loop_over_words_result_t result; - std::size_t lookup_mask = round_down_to_power_of_two(strings.size()) - 1; - - while (true) { - // Unroll a few iterations, to avoid some for-loops overhead and minimize impact of time-tracking - { - result.bytes_passed += - function(sz_string_view(strings[(++result.iterations) & lookup_mask]), - sz_string_view(strings[(result.iterations * 18446744073709551557ull) & lookup_mask])); - result.bytes_passed += - function(sz_string_view(strings[(++result.iterations) & lookup_mask]), - sz_string_view(strings[(result.iterations * 18446744073709551557ull) & lookup_mask])); - result.bytes_passed += - function(sz_string_view(strings[(++result.iterations) & lookup_mask]), - sz_string_view(strings[(result.iterations * 18446744073709551557ull) & lookup_mask])); - result.bytes_passed += - function(sz_string_view(strings[(++result.iterations) & lookup_mask]), - sz_string_view(strings[(result.iterations * 18446744073709551557ull) & lookup_mask])); - } - - stdcc::time_point t2 = stdcc::now(); - result.seconds = stdc::duration_cast(t2 - t1).count() / 1.e9; - if (result.seconds > max_time) break; - } - - return result; -} - -/** - * @brief Evaluation for binary string operations: equality, ordering, prefix, suffix, distance. 
- */ -template -void evaluate_binary_operations(strings_at &&strings, tracked_binary_functions_t &&variants) { - - for (std::size_t variant_idx = 0; variant_idx != variants.size(); ++variant_idx) { - auto &variant = variants[variant_idx]; - - // Tests - if (variant.function && variant.needs_testing) { - loop_over_pairs_of_words(strings, [&](sz_string_view_t str_a, sz_string_view_t str_b) { - auto baseline = variants[0].function(str_a, str_b); - auto result = variant.function(str_a, str_b); - if (result != baseline) { - ++variant.failed_count; - if (variant.failed_strings.empty()) { - variant.failed_strings.push_back({str_a.start, str_a.length}); - variant.failed_strings.push_back({str_b.start, str_b.length}); - } - } - return str_a.length + str_b.length; - }); - } - - // Benchmarks - if (variant.function) { - variant.results = loop_over_pairs_of_words(strings, [&](sz_string_view_t str_a, sz_string_view_t str_b) { - do_not_optimize(variant.function(str_a, str_b)); - return str_a.length + str_b.length; - }); - } - - variant.print(); - } -} - -/** - * @brief Evaluation for search string operations: find. 
- */ -template -void evaluate_find_operations(strings_at &&strings, tracked_binary_functions_t &&variants) { - - for (std::size_t variant_idx = 0; variant_idx != variants.size(); ++variant_idx) { - auto &variant = variants[variant_idx]; - - // Tests - if (variant.function && variant.needs_testing) { - loop_over_words(strings, [&](sz_string_view_t str_n) { - sz_string_view_t str_h = {content_original.data(), content_original.size()}; - while (true) { - auto baseline = variants[0].function(str_h, str_n); - auto result = variant.function(str_h, str_n); - if (result != baseline) { - ++variant.failed_count; - if (variant.failed_strings.empty()) { - variant.failed_strings.push_back({str_h.start, baseline + str_n.length}); - variant.failed_strings.push_back({str_n.start, str_n.length}); - } - } - - if (baseline == str_h.length) break; - str_h.start += baseline + 1; - str_h.length -= baseline + 1; - } - - return content_original.size(); - }); - } - - // Benchmarks - if (variant.function) { - variant.results = loop_over_words(strings, [&](sz_string_view_t str_n) { - sz_string_view_t str_h = {content_original.data(), content_original.size()}; - auto result = variant.function(str_h, str_n); - while (result != str_h.length) { - str_h.start += result + 1, str_h.length -= result + 1; - result = variant.function(str_h, str_n); - do_not_optimize(result); - } - return result; - }); - } - - variant.print(); - } -} - -/** - * @brief Evaluation for reverse order search string operations: find. 
- */ -template -void evaluate_find_last_operations(strings_at &&strings, tracked_binary_functions_t &&variants) { - - for (std::size_t variant_idx = 0; variant_idx != variants.size(); ++variant_idx) { - auto &variant = variants[variant_idx]; - - // Tests - if (variant.function && variant.needs_testing) { - loop_over_words(strings, [&](sz_string_view_t str_n) { - sz_string_view_t str_h = {content_original.data(), content_original.size()}; - while (true) { - auto baseline = variants[0].function(str_h, str_n); - auto result = variant.function(str_h, str_n); - if (result != baseline) { - ++variant.failed_count; - if (variant.failed_strings.empty()) { - variant.failed_strings.push_back({str_h.start + baseline, str_h.start + str_h.length}); - variant.failed_strings.push_back({str_n.start, str_n.length}); - } - } - - if (baseline == str_h.length) break; - str_h.length = baseline; - } - - return content_original.size(); - }); - } - - // Benchmarks - if (variant.function) { - std::size_t bytes_processed = 0; - std::size_t mask = content_original.size() - 1; - variant.results = loop_over_words(strings, [&](sz_string_view_t str_n) { - sz_string_view_t str_h = {content_original.data(), content_original.size()}; - str_h.length -= bytes_processed & mask; - auto result = variant.function(str_h, str_n); - bytes_processed += (str_h.length - result) + str_n.length; - return result; - }); - } - - variant.print(); - } -} - -template -void evaluate_all_operations(strings_at &&strings) { - evaluate_unary_operations(strings, hashing_functions()); - evaluate_binary_operations(strings, equality_functions()); - evaluate_binary_operations(strings, ordering_functions()); - evaluate_binary_operations(strings, distance_functions()); - evaluate_find_operations(strings, find_functions()); - evaluate_find_last_operations(strings, find_last_functions()); - - // evaluate_binary_operations(strings, prefix_functions()); - // evaluate_binary_operations(strings, suffix_functions()); -} - -int main(int, 
char const **) { - std::printf("Hi Ash! ... or is it someone else?!\n"); - - content_original = read_file("leipzig1M.txt"); - content_original.resize(round_down_to_power_of_two(content_original.size())); - - content_words = tokenize(content_original); - content_words.resize(round_down_to_power_of_two(content_words.size())); - -#ifdef NDEBUG // Shuffle only in release mode - std::random_device random_device; - std::mt19937 random_generator(random_device()); - std::shuffle(content_words.begin(), content_words.end(), random_generator); -#endif - - // Report some basic stats about the dataset - std::size_t mean_bytes = 0; - for (auto const &str : content_words) mean_bytes += str.size(); - mean_bytes /= content_words.size(); - std::printf("Parsed the file with %zu words of %zu mean length!\n", content_words.size(), mean_bytes); - - // Baseline benchmarks for real words, coming in all lengths - { - std::printf("Benchmarking for real words:\n"); - evaluate_all_operations(content_words); - } - - // Produce benchmarks for different word lengths, both real and impossible - for (std::size_t word_length : {1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 33, 65}) { - - // Generate some impossible words of that length - std::printf("\n\n"); - std::printf("Benchmarking for abstract tokens of length %zu:\n", word_length); - std::vector words = { - std::string(word_length, '\1'), - std::string(word_length, '\2'), - std::string(word_length, '\3'), - std::string(word_length, '\4'), - }; - evaluate_all_operations(words); - - // Check for some real words of that length - for (auto const &str : words) - if (str.size() == word_length) words.push_back(str); - if (!words.size()) continue; - std::printf("Benchmarking for real words of length %zu:\n", word_length); - evaluate_all_operations(words); - } - - // Now lets test our functionality on longer biological sequences. - // A single human gene is from 300 to 15,000 base pairs long. - // Thole whole human genome is about 3 billion base pairs long. 
- // The genomes of bacteria are relatively small - E. coli genome is about 4.6 million base pairs long. - // In techniques like PCR (Polymerase Chain Reaction), short DNA sequences called primers are used. - // These are usually 18 to 25 base pairs long. - char aminoacids[] = "ATCG"; - for (std::size_t dna_length : {300, 2000, 15000}) { - std::vector dna_sequences(16); - for (std::size_t i = 0; i != 16; ++i) { - dna_sequences[i].resize(dna_length); - for (std::size_t j = 0; j != dna_length; ++j) dna_sequences[i][j] = aminoacids[std::rand() % 4]; - } - std::printf("Benchmarking for DNA-like sequences of length %zu:\n", dna_length); - evaluate_all_operations(dna_sequences); - } - - return 0; -} \ No newline at end of file diff --git a/scripts/similarity_baseline.py b/scripts/similarity_baseline.py deleted file mode 100644 index 693bd573..00000000 --- a/scripts/similarity_baseline.py +++ /dev/null @@ -1,30 +0,0 @@ -import numpy as np - - -def levenshtein(str1: str, str2: str, whole_matrix: bool = False) -> int: - """Naive Levenshtein edit distance computation using NumPy. 
Quadratic complexity in time and space.""" - rows = len(str1) + 1 - cols = len(str2) + 1 - distance_matrix = np.zeros((rows, cols), dtype=int) - distance_matrix[0, :] = np.arange(cols) - distance_matrix[:, 0] = np.arange(rows) - for i in range(1, rows): - for j in range(1, cols): - if str1[i - 1] == str2[j - 1]: - cost = 0 - else: - cost = 1 - - distance_matrix[i, j] = min( - distance_matrix[i - 1, j] + 1, # Deletion - distance_matrix[i, j - 1] + 1, # Insertion - distance_matrix[i - 1, j - 1] + cost, # Substitution - ) - - if whole_matrix: - return distance_matrix - return distance_matrix[-1, -1] - - -if __name__ == "__main__": - print(levenshtein("aaaba", "aaaca", True)) diff --git a/scripts/search_test.cpp b/scripts/test.cpp similarity index 93% rename from scripts/search_test.cpp rename to scripts/test.cpp index b3144a4d..ea853532 100644 --- a/scripts/search_test.cpp +++ b/scripts/test.cpp @@ -16,6 +16,12 @@ namespace sz = ashvardanian::stringzilla; using sz::literals::operator""_sz; +/** + * Evaluates the correctness of a "matcher", searching for all the occurences of the `needle_stl` + * in a haystack formed of `haystack_pattern` repeated from one to `max_repeats` times. + * + * @param misalignment The number of bytes to misalign the haystack within the cacheline. + */ template void eval(std::string_view haystack_pattern, std::string_view needle_stl, std::size_t misalignment) { constexpr std::size_t max_repeats = 128; @@ -54,6 +60,10 @@ void eval(std::string_view haystack_pattern, std::string_view needle_stl, std::s } } +/** + * Evaluates the correctness of a "matcher", searching for all the occurences of the `needle_stl`, + * as a substring, as a set of allowed characters, or as a set of disallowed characters, in a haystack. 
+ */ void eval(std::string_view haystack_pattern, std::string_view needle_stl, std::size_t misalignment) { eval< // @@ -94,7 +104,7 @@ void eval(std::string_view haystack_pattern, std::string_view needle_stl) { eval(haystack_pattern, needle_stl, 3); } -int main(int, char const **) { +int main(int argc, char const **argv) { std::printf("Hi Ash! ... or is it someone else?!\n"); std::string_view alphabet = "abcdefghijklmnopqrstuvwxyz"; // 26 characters diff --git a/scripts/unit_test.js b/scripts/test.js similarity index 100% rename from scripts/unit_test.js rename to scripts/test.js diff --git a/scripts/search_test.py b/scripts/test.py similarity index 50% rename from scripts/search_test.py rename to scripts/test.py index ca1841bc..2fd5541a 100644 --- a/scripts/search_test.py +++ b/scripts/test.py @@ -1,14 +1,115 @@ -import random, time -from typing import Union, Optional -from random import choice, randint -from string import ascii_lowercase - -import numpy as np import pytest import stringzilla as sz -from stringzilla import Str, Strs -from scripts.similarity_baseline import levenshtein +from stringzilla import Str + + +def test_unit_construct(): + native = "aaaaa" + big = Str(native) + assert len(big) == len(native) + + +def test_unit_indexing(): + native = "abcdef" + big = Str(native) + for i in range(len(native)): + assert big[i] == native[i] + + +def test_unit_count(): + native = "aaaaa" + big = Str(native) + assert big.count("a") == 5 + assert big.count("aa") == 2 + assert big.count("aa", allowoverlap=True) == 4 + + +def test_unit_contains(): + big = Str("abcdef") + assert "a" in big + assert "ab" in big + assert "xxx" not in big + + +def test_unit_rich_comparisons(): + assert Str("aa") == "aa" + assert Str("aa") < "b" + assert Str("abb")[1:] == "bb" + + +def test_unit_buffer_protocol(): + import numpy as np + + my_str = Str("hello") + arr = np.array(my_str) + assert arr.dtype == np.dtype("c") + assert arr.shape == (len("hello"),) + assert 
"".join([c.decode("utf-8") for c in arr.tolist()]) == "hello" + + +def test_unit_split(): + native = "token1\ntoken2\ntoken3" + big = Str(native) + assert native.splitlines() == list(big.splitlines()) + assert native.splitlines(True) == list(big.splitlines(keeplinebreaks=True)) + assert native.split("token3") == list(big.split("token3")) + + words = sz.split(big, "\n") + assert len(words) == 3 + assert str(words[0]) == "token1" + assert str(words[2]) == "token3" + + parts = sz.split(big, "\n", keepseparator=True) + assert len(parts) == 3 + assert str(parts[0]) == "token1\n" + assert str(parts[2]) == "token3" + + +def test_unit_sequence(): + native = "p3\np2\np1" + big = Str(native) + + lines = big.splitlines() + assert [2, 1, 0] == list(lines.order()) + + lines.sort() + assert [0, 1, 2] == list(lines.order()) + assert ["p1", "p2", "p3"] == list(lines) + + # Reverse order + assert [2, 1, 0] == list(lines.order(reverse=True)) + lines.sort(reverse=True) + assert ["p3", "p2", "p1"] == list(lines) + + +def test_unit_globals(): + """Validates that the previously unit-tested member methods are also visible as global functions.""" + + assert sz.find("abcdef", "bcdef") == 1 + assert sz.find("abcdef", "x") == -1 + + assert sz.count("abcdef", "x") == 0 + assert sz.count("aaaaa", "a") == 5 + assert sz.count("aaaaa", "aa") == 2 + assert sz.count("aaaaa", "aa", allowoverlap=True) == 4 + + assert sz.edit_distance("aaa", "aaa") == 0 + assert sz.edit_distance("aaa", "bbb") == 3 + assert sz.edit_distance("abababab", "aaaaaaaa") == 4 + assert sz.edit_distance("abababab", "aaaaaaaa", 2) == 2 + assert sz.edit_distance("abababab", "aaaaaaaa", bound=2) == 2 + + +def get_random_string( + length: Optional[int] = None, + variability: Optional[int] = None, +) -> str: + if length is None: + length = randint(3, 300) + if variability is None: + variability = len(ascii_lowercase) + return "".join(choice(ascii_lowercase[:variability]) for _ in range(length)) def is_equal_strings(native_strings, 
big_strings): @@ -79,7 +180,7 @@ def test_fuzzy_substrings(pattern_length: int, haystack_length: int, variability @pytest.mark.repeat(100) @pytest.mark.parametrize("max_edit_distance", [150]) -def test_levenshtein_insertions(max_edit_distance: int): +def test_edit_distance_insertions(max_edit_distance: int): # Create a new string by slicing and concatenating def insert_char_at(s, char_to_insert, index): return s[:index] + char_to_insert + s[index:] @@ -90,14 +191,42 @@ def insert_char_at(s, char_to_insert, index): source_offset = randint(0, len(ascii_lowercase) - 1) target_offset = randint(0, len(b) - 1) b = insert_char_at(b, ascii_lowercase[source_offset], target_offset) - assert sz.levenshtein(a, b, 200) == i + 1 + assert sz.edit_distance(a, b, 200) == i + 1 @pytest.mark.repeat(1000) -def test_levenshtein_randos(): +def test_edit_distance_randos(): a = get_random_string(length=20) b = get_random_string(length=20) - assert sz.levenshtein(a, b, 200) == levenshtein(a, b) + assert sz.edit_distance(a, b, 200) == edit_distance(a, b) + + +@pytest.mark.parametrize("list_length", [10, 20, 30, 40, 50]) +@pytest.mark.parametrize("part_length", [5, 10]) +@pytest.mark.parametrize("variability", [2, 3]) +def test_fuzzy_sorting(list_length: int, part_length: int, variability: int): + native_list = [ + get_random_string(variability=variability, length=part_length) + for _ in range(list_length) + ] + native_joined = ".".join(native_list) + big_joined = Str(native_joined) + big_list = big_joined.split(".") + + native_ordered = sorted(native_list) + native_order = big_list.order() + for i in range(list_length): + assert native_ordered[i] == native_list[native_order[i]], "Order is wrong" + assert native_ordered[i] == str( + big_list[int(native_order[i])] + ), "Split is wrong?!" 
+ + native_list.sort() + big_list.sort() + + assert len(native_list) == len(big_list) + for native_str, big_str in zip(native_list, big_list): + assert native_str == str(big_str), "Order is wrong" @pytest.mark.parametrize("list_length", [10, 20, 30, 40, 50]) diff --git a/scripts/unit_test.py b/scripts/unit_test.py deleted file mode 100644 index 369df430..00000000 --- a/scripts/unit_test.py +++ /dev/null @@ -1,105 +0,0 @@ -from typing import Union, Optional -from random import choice, randint -from string import ascii_lowercase - -import pytest - -import stringzilla as sz -from stringzilla import Str, Strs - - -def test_unit_construct(): - native = "aaaaa" - big = Str(native) - assert len(big) == len(native) - - -def test_unit_indexing(): - native = "abcdef" - big = Str(native) - for i in range(len(native)): - assert big[i] == native[i] - - -def test_unit_count(): - native = "aaaaa" - big = Str(native) - assert big.count("a") == 5 - assert big.count("aa") == 2 - assert big.count("aa", allowoverlap=True) == 4 - - -def test_unit_contains(): - big = Str("abcdef") - assert "a" in big - assert "ab" in big - assert "xxx" not in big - - -def test_unit_rich_comparisons(): - assert Str("aa") == "aa" - assert Str("aa") < "b" - assert Str("abb")[1:] == "bb" - - -def test_unit_buffer_protocol(): - import numpy as np - - my_str = Str("hello") - arr = np.array(my_str) - assert arr.dtype == np.dtype("c") - assert arr.shape == (len("hello"),) - assert "".join([c.decode("utf-8") for c in arr.tolist()]) == "hello" - - -def test_unit_split(): - native = "token1\ntoken2\ntoken3" - big = Str(native) - assert native.splitlines() == list(big.splitlines()) - assert native.splitlines(True) == list(big.splitlines(keeplinebreaks=True)) - assert native.split("token3") == list(big.split("token3")) - - words = sz.split(big, "\n") - assert len(words) == 3 - assert str(words[0]) == "token1" - assert str(words[2]) == "token3" - - parts = sz.split(big, "\n", keepseparator=True) - assert len(parts) 
== 3 - assert str(parts[0]) == "token1\n" - assert str(parts[2]) == "token3" - - -def test_unit_sequence(): - native = "p3\np2\np1" - big = Str(native) - - lines = big.splitlines() - assert [2, 1, 0] == list(lines.order()) - - lines.sort() - assert [0, 1, 2] == list(lines.order()) - assert ["p1", "p2", "p3"] == list(lines) - - # Reverse order - assert [2, 1, 0] == list(lines.order(reverse=True)) - lines.sort(reverse=True) - assert ["p3", "p2", "p1"] == list(lines) - - -def test_unit_globals(): - """Validates that the previously unit-tested member methods are also visible as global functions.""" - - assert sz.find("abcdef", "bcdef") == 1 - assert sz.find("abcdef", "x") == -1 - - assert sz.count("abcdef", "x") == 0 - assert sz.count("aaaaa", "a") == 5 - assert sz.count("aaaaa", "aa") == 2 - assert sz.count("aaaaa", "aa", allowoverlap=True) == 4 - - assert sz.levenshtein("aaa", "aaa") == 0 - assert sz.levenshtein("aaa", "bbb") == 3 - assert sz.levenshtein("abababab", "aaaaaaaa") == 4 - assert sz.levenshtein("abababab", "aaaaaaaa", 2) == 2 - assert sz.levenshtein("abababab", "aaaaaaaa", bound=2) == 2