diff --git a/test/buffers_utest.cpp b/test/buffers_utest.cpp index db22e312..8002eee3 100644 --- a/test/buffers_utest.cpp +++ b/test/buffers_utest.cpp @@ -34,6 +34,7 @@ #include "quadiron.h" namespace vec = quadiron::vec; +namespace gf = quadiron::gf; template class BuffersTest : public ::testing::Test { @@ -58,6 +59,73 @@ class BuffersTest : public ::testing::Test { return vec; } + + std::unique_ptr> + gen_rand_vector(gf::RingModN& gf, size_t n, T max_val) + { + std::vector mem; + mem.reserve(max_val); + + for (size_t i = 0; i < max_val; i++) { + mem.push_back(i); + } + std::random_shuffle(mem.begin(), mem.end()); + + auto vec = std::unique_ptr>(new vec::Vector(gf, n)); + for (size_t i = 0; i < n; i++) { + vec->set(i, mem[i]); + } + return vec; + } + + bool check_eq(const T* buf1, const T* buf2, size_t len) + { + return memcmp(buf1, buf2, len * sizeof(T)) == 0; + } + + bool check_all_zeros(const T* buf, size_t len) + { + for (size_t i = 0; i < len; ++i) { + if (buf[i] != 0) { + return false; + } + } + return true; + } + + bool check_shuffled_bufs( + const vec::Buffers& input, + const vec::Buffers& output, + const vec::Vector& map) + { + const size_t input_len = input.get_n(); + const size_t output_len = output.get_n(); + const size_t map_len = map.get_n(); + + const size_t size = input.get_size(); + + std::vector check(output.get_n(), false); + + for (unsigned i = 0; i < map_len; ++i) { + if (!check_eq(input.get(i), output.get(map.get(i)), size)) { + return false; + } + check[map.get(i)] = true; + } + + // Check zero-extended if it's necessary + if (output_len > input_len) { + for (size_t i = 0; i < output_len; ++i) { + if (!check[i]) { + if (!check_all_zeros(output.get(i), size)) { + return false; + } + } + } + } + + return true; + } }; using TestedTypes = ::testing::Types; @@ -68,7 +136,7 @@ TYPED_TEST(BuffersTest, TestConstructors) // NOLINT const int n = 16; const int begin = 5; const int end = 12; - const int size = 32; + const int size = 4; auto vec1 = this->gen_buffers_rand_data(n, size); vec::Buffers vec2(*vec1, begin, end); @@ -92,6 +160,20 @@ TYPED_TEST(BuffersTest, TestConstructors) // NOLINT vec::Buffers vec3(end - begin, size, mem3); ASSERT_EQ(vec2, vec3); + + auto gf(gf::create>(65537)); + + // no-extension + const int out_n_1 = n - 5; + auto map_1 = this->gen_rand_vector(gf, out_n_1, out_n_1); + vec::Buffers vec4(*vec1, *map_1, out_n_1); + ASSERT_TRUE(this->check_shuffled_bufs(*vec1, vec4, *map_1)); + + // extension + const int out_n_2 = n + 10; + auto map_2 = this->gen_rand_vector(gf, n, out_n_2); + vec::Buffers vec5(*vec1, *map_2, out_n_2); + ASSERT_TRUE(this->check_shuffled_bufs(*vec1, vec5, *map_2)); } TYPED_TEST(BuffersTest, TestEvenOddSeparation) // NOLINT