Skip to content

Commit

Permalink
update longint.hpp and add flag for longint asserts to the CMakeLists…
Browse files Browse the repository at this point in the history
….txt
  • Loading branch information
i80287 committed Oct 24, 2024
1 parent 23c9bff commit 00f91f6
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 10 deletions.
79 changes: 70 additions & 9 deletions number_theory/longint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@
#error "Current implementation works only with GCC"
#endif

#if defined(ENABLE_LONGINT_DEBUG_ASSERTS) && ENABLE_LONGINT_DEBUG_ASSERTS
#define LONGINT_DEBUG_ASSERT(expr) assert(expr)
#else
#define LONGINT_DEBUG_ASSERT(expr)
#endif

namespace longint_allocator {

// #define DEBUG_LI_ALLOC_PRINTING 1
Expand Down Expand Up @@ -424,7 +430,7 @@ struct longint {
ATTRIBUTE_ALWAYS_INLINE ATTRIBUTE_PURE constexpr ssize_type size() const noexcept {
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wtype-limits"
#pragma clang diagnostic ignored "-Wtautological-type-limit-compare"
#endif
const ssize_type value = size_;
if (value > static_cast<ssize_type>(max_size())) {
Expand Down Expand Up @@ -667,11 +673,11 @@ struct longint {
const bool find_sum = (size_ ^ other.size_) >= 0;
const size_type usize2 = other.usize();
static_assert(max_size() + 1 > max_size());
const size_type min_size_boundary = std::max(usize(), usize2) + (find_sum ? 1 : 0);
ATTRIBUTE_ASSUME(usize() < min_size_boundary);
const size_type usize1 = set_size_at_least(min_size_boundary);
const size_type usize1 = set_size_at_least(std::max(usize(), usize2) + (find_sum ? 1 : 0));
LONGINT_DEBUG_ASSERT(usize1 >= usize2);
ATTRIBUTE_ASSUME(usize1 >= usize2);
if (find_sum) {
LONGINT_DEBUG_ASSERT(usize1 > usize2);
ATTRIBUTE_ASSUME(usize1 > usize2);
longint_add_with_free_space(nums_, usize1, other.nums_, usize2);
} else {
Expand Down Expand Up @@ -1231,8 +1237,15 @@ struct longint {
break;
}

#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wtautological-type-limit-compare"
#endif
static_assert(math_functions::nearest_greater_equal_power_of_two(max_size()) <=
std::numeric_limits<size_t>::max());
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
const size_t n = size_t{math_functions::nearest_greater_equal_power_of_two(usize_value)};
ensureBinBasePowsCapacity(math_functions::log2_floor(n));
digit_t* const knums = allocate(n);
Expand Down Expand Up @@ -1421,6 +1434,7 @@ struct longint {
set_zero();
return *this;
}
LONGINT_DEBUG_ASSERT(1 <= m && m <= k);
ATTRIBUTE_ASSUME(1 <= m && m <= k);
const dec_size_type new_size = m + k;
if (m <= 16 || m * k <= 1024) {
Expand Down Expand Up @@ -1545,6 +1559,7 @@ struct longint {
const dec_size_type m_size,
const dec_digit_t k_digits[],
const dec_size_type k_size, Decimal& product_result) {
LONGINT_DEBUG_ASSERT(m_size <= k_size);
ATTRIBUTE_ASSUME(m_size <= k_size);
const dec_size_type new_size = m_size + k_size;
dec_digit_t* const ans = allocate(new_size);
Expand Down Expand Up @@ -1660,6 +1675,7 @@ struct longint {
ATTRIBUTE_CONST
ATTRIBUTE_ALWAYS_INLINE
static constexpr dec_size_type polys_size(const dec_size_type size_value) noexcept {
LONGINT_DEBUG_ASSERT(size_value <= kMaxDecFFTSize);
ATTRIBUTE_ASSUME(size_value <= kMaxDecFFTSize);
static_assert(3 * kMaxDecFFTSize > kMaxDecFFTSize);
static_assert(
Expand Down Expand Up @@ -1817,6 +1833,7 @@ struct longint {
ATTRIBUTE_ALWAYS_INLINE
constexpr dec_size_type poly_size() const noexcept {
const auto value = poly_size_;
LONGINT_DEBUG_ASSERT(math_functions::is_power_of_two(value));
ATTRIBUTE_ASSUME(math_functions::is_power_of_two(value));
return value;
}
Expand Down Expand Up @@ -1857,6 +1874,7 @@ struct longint {
static constexpr void multiply_and_store_to(const digit_t m_ptr[], const size_type m,
const digit_t k_ptr[], const size_type k,
digit_t* const ans) noexcept {
LONGINT_DEBUG_ASSERT(m <= k);
ATTRIBUTE_ASSUME(m <= k);
digit_t* ans_store_ptr = ans;
for (size_type j = 0; j < m; ans_store_ptr++, j++) {
Expand Down Expand Up @@ -1888,6 +1906,7 @@ struct longint {
poly_size_type n = 2 * math_functions::nearest_greater_equal_power_of_two(product_size);
const bool need_high_precision = n > kFFTPrecisionBorder;
n <<= need_high_precision;
LONGINT_DEBUG_ASSERT(math_functions::is_power_of_two(n));
ATTRIBUTE_ASSUME(math_functions::is_power_of_two(n));
return {n, need_high_precision};
}
Expand Down Expand Up @@ -1940,14 +1959,23 @@ struct longint {
const digit_t k_ptr[], const size_type k,
fft::complex* p, const poly_size_type n,
bool need_high_precision) noexcept {
LONGINT_DEBUG_ASSERT(0 < m);
ATTRIBUTE_ASSUME(0 < m);
LONGINT_DEBUG_ASSERT(m <= k);
ATTRIBUTE_ASSUME(m <= k);
LONGINT_DEBUG_ASSERT(m <= max_size());
ATTRIBUTE_ASSUME(m <= max_size());
LONGINT_DEBUG_ASSERT(k <= max_size());
ATTRIBUTE_ASSUME(k <= max_size());
LONGINT_DEBUG_ASSERT(m + k <= max_size());
ATTRIBUTE_ASSUME(m + k <= max_size());
LONGINT_DEBUG_ASSERT(m + k <= n);
ATTRIBUTE_ASSUME(m + k <= n);
LONGINT_DEBUG_ASSERT(need_high_precision || n <= kFFTPrecisionBorder);
ATTRIBUTE_ASSUME(need_high_precision || n <= kFFTPrecisionBorder);
LONGINT_DEBUG_ASSERT(!need_high_precision || n > kFFTPrecisionBorder * 2);
ATTRIBUTE_ASSUME(!need_high_precision || n > kFFTPrecisionBorder * 2);
LONGINT_DEBUG_ASSERT(math_functions::is_power_of_two(n));
ATTRIBUTE_ASSUME(math_functions::is_power_of_two(n));

static_assert(kNumsBits == 32);
Expand Down Expand Up @@ -2024,11 +2052,17 @@ struct longint {
const size_type nums_size, fft::complex* p,
const poly_size_type n,
bool need_high_precision) noexcept {
LONGINT_DEBUG_ASSERT(0 < nums_size);
ATTRIBUTE_ASSUME(0 < nums_size);
LONGINT_DEBUG_ASSERT(nums_size <= max_size());
ATTRIBUTE_ASSUME(nums_size <= max_size());
LONGINT_DEBUG_ASSERT(nums_size * 2 <= n);
ATTRIBUTE_ASSUME(nums_size * 2 <= n);
LONGINT_DEBUG_ASSERT(need_high_precision || n <= kFFTPrecisionBorder);
ATTRIBUTE_ASSUME(need_high_precision || n <= kFFTPrecisionBorder);
LONGINT_DEBUG_ASSERT(!need_high_precision || n > kFFTPrecisionBorder * 2);
ATTRIBUTE_ASSUME(!need_high_precision || n > kFFTPrecisionBorder * 2);
LONGINT_DEBUG_ASSERT(math_functions::is_power_of_two(n));
ATTRIBUTE_ASSUME(math_functions::is_power_of_two(n));

static_assert(kNumsBits == 32);
Expand Down Expand Up @@ -2130,13 +2164,13 @@ struct longint {
const digit_t* const u = nums_;
const digit_t* const v = other.nums_;
const digit_t last_v_num = v[n - 1];
assert(last_v_num > 0);
LONGINT_DEBUG_ASSERT(last_v_num > 0);
ATTRIBUTE_ASSUME(last_v_num > 0);
static_assert(kNumsBits == 32);
// 0 <= s < kNumsBits
const auto s = static_cast<std::uint32_t>(math_functions::countl_zero(last_v_num));
longint::divmod_normalize_vn(vn, v, n, s);
assert(vn[n - 1] >= digit_t{1} << (kNumsBits - 1));
LONGINT_DEBUG_ASSERT(vn[n - 1] >= digit_t{1} << (kNumsBits - 1));
ATTRIBUTE_ASSUME(vn[n - 1] >= digit_t{1} << (kNumsBits - 1));
longint::divmod_normalize_un(un, u, m, m + 1, s);
longint::divmod_impl_unchecked(
Expand All @@ -2163,26 +2197,32 @@ struct longint {
const digit_t* RESTRICT_QUALIFIER const vn,
const size_type vn_size,
digit_t* RESTRICT_QUALIFIER const quot) noexcept {
LONGINT_DEBUG_ASSERT(vn_size >= 2);
ATTRIBUTE_ASSUME(vn_size >= 2);
LONGINT_DEBUG_ASSERT(un_size > vn_size);
ATTRIBUTE_ASSUME(un_size > vn_size);
for (size_type j = un_size - vn_size - 1; static_cast<ssize_type>(j) >= 0; j--) {
// Compute estimate qhat of q[j].
const double_digit_t cur =
(double_digit_t{un[j + vn_size]} << kNumsBits) | un[j + vn_size - 1];
const digit_t last_vn = vn[vn_size - 1];
LONGINT_DEBUG_ASSERT(last_vn >= digit_t{1} << (kNumsBits - 1));
ATTRIBUTE_ASSUME(last_vn >= digit_t{1} << (kNumsBits - 1));
double_digit_t qhat = cur / last_vn;
double_digit_t rhat = cur % last_vn;
LONGINT_DEBUG_ASSERT(qhat * last_vn + rhat == cur);
ATTRIBUTE_ASSUME(qhat * last_vn + rhat == cur);
while (qhat >= kNumsBase ||
qhat * vn[vn_size - 2] > kNumsBase * rhat + un[j + vn_size - 2]) {
qhat--;
rhat += last_vn;
LONGINT_DEBUG_ASSERT(qhat * last_vn + rhat == cur);
ATTRIBUTE_ASSUME(qhat * last_vn + rhat == cur);
if (rhat >= kNumsBase) {
break;
}
}
LONGINT_DEBUG_ASSERT(qhat * last_vn + rhat == cur);
ATTRIBUTE_ASSUME(qhat * last_vn + rhat == cur);
// Multiply and subtract
double_digit_t t = divmod_mult_sub(un + j, vn, vn_size, qhat);
Expand Down Expand Up @@ -2213,6 +2253,7 @@ struct longint {
const digit_t* RESTRICT_QUALIFIER const vn,
const size_type vn_size,
const double_digit_t qhat) noexcept {
LONGINT_DEBUG_ASSERT(vn_size >= 2);
ATTRIBUTE_ASSUME(vn_size >= 2);
double_digit_t carry = 0;
for (size_type i = 0; i < vn_size; i++) {
Expand All @@ -2233,6 +2274,7 @@ struct longint {
static constexpr void divmod_add_back(digit_t* RESTRICT_QUALIFIER const un,
const digit_t* RESTRICT_QUALIFIER const vn,
const size_type vn_size) noexcept {
LONGINT_DEBUG_ASSERT(vn_size >= 2);
ATTRIBUTE_ASSUME(vn_size >= 2);
double_digit_t carry = 0;
for (size_type i = 0; i < vn_size; i++) {
Expand All @@ -2249,7 +2291,9 @@ struct longint {
ATTRIBUTE_ALWAYS_INLINE
static constexpr void divmod_normalize_vn(digit_t vn[], const digit_t v[], size_type n,
std::uint32_t s) noexcept {
LONGINT_DEBUG_ASSERT(n > 1);
ATTRIBUTE_ASSUME(n > 1);
LONGINT_DEBUG_ASSERT(s < 32);
ATTRIBUTE_ASSUME(s < 32);
for (size_type i = n - 1; i > 0; i--) {
vn[i] = (v[i] << s) | static_cast<digit_t>(double_digit_t{v[i - 1]} >> (kNumsBits - s));
Expand All @@ -2264,8 +2308,11 @@ struct longint {
static constexpr void divmod_normalize_un(digit_t un[], const digit_t u[], size_type m,
ATTRIBUTE_MAYBE_UNUSED size_type m_plus_one,
std::uint32_t s) noexcept {
LONGINT_DEBUG_ASSERT(m > 1);
ATTRIBUTE_ASSUME(m > 1);
LONGINT_DEBUG_ASSERT(s < 32);
ATTRIBUTE_ASSUME(s < 32);
LONGINT_DEBUG_ASSERT(m + 1 == m_plus_one);
ATTRIBUTE_ASSUME(m + 1 == m_plus_one);
un[m] = static_cast<digit_t>(double_digit_t{u[m - 1]} >> (kNumsBits - s));
for (size_type i = m - 1; i > 0; i--) {
Expand All @@ -2282,8 +2329,11 @@ struct longint {
size_type n,
ATTRIBUTE_MAYBE_UNUSED size_type n_plus_one,
std::uint32_t s) noexcept {
LONGINT_DEBUG_ASSERT(n > 1);
ATTRIBUTE_ASSUME(n > 1);
LONGINT_DEBUG_ASSERT(s < 32);
ATTRIBUTE_ASSUME(s < 32);
LONGINT_DEBUG_ASSERT(n + 1 == n_plus_one);
ATTRIBUTE_ASSUME(n + 1 == n_plus_one);
for (size_type i = 0; i < n; i++) {
rem[i] =
Expand Down Expand Up @@ -2339,6 +2389,7 @@ struct longint {
static void convertDecBaseMultAdd(digit_t conv_digits[], const size_type conv_len,
const longint& conv_base_pow, digit_t mult_add_buffer[],
fft::complex fft_poly_buffer[]) {
LONGINT_DEBUG_ASSERT(0 < conv_base_pow.size_);
ATTRIBUTE_ASSUME(0 < conv_base_pow.size_);
const size_type m_size = conv_base_pow.usize();
const digit_t* const m_ptr = conv_base_pow.nums_;
Expand All @@ -2356,9 +2407,13 @@ struct longint {
digit_t mult_add_buffer[],
fft::complex fft_poly_buffer[]) {
const size_type half_conv_len = conv_len / 2;
LONGINT_DEBUG_ASSERT(0 < m_size);
ATTRIBUTE_ASSUME(0 < m_size);
LONGINT_DEBUG_ASSERT(m_size <= half_conv_len);
ATTRIBUTE_ASSUME(m_size <= half_conv_len);
LONGINT_DEBUG_ASSERT(math_functions::is_power_of_two(half_conv_len));
ATTRIBUTE_ASSUME(math_functions::is_power_of_two(half_conv_len));
LONGINT_DEBUG_ASSERT(conv_len <= max_size());
ATTRIBUTE_ASSUME(conv_len <= max_size());
const digit_t* num_hi = conv_digits + half_conv_len;
static_assert(max_size() + max_size() > max_size());
Expand Down Expand Up @@ -2396,6 +2451,7 @@ struct longint {

ATTRIBUTE_SIZED_ACCESS(read_only, 1, 2)
static Decimal convertBinBaseImpl(const digit_t nums[], const std::size_t size) {
LONGINT_DEBUG_ASSERT(math_functions::is_power_of_two(size));
ATTRIBUTE_ASSUME(math_functions::is_power_of_two(size));
switch (size) {
case 0:
Expand All @@ -2411,7 +2467,7 @@ struct longint {
Decimal high_dec = convertBinBaseImpl(nums + size / 2, size / 2);

const uint32_t idx = math_functions::log2_floor(size) - 1;
assert(idx < conv_bin_base_pows.size());
LONGINT_DEBUG_ASSERT(idx < conv_bin_base_pows.size());
high_dec *= conv_bin_base_pows[idx];
high_dec += low_dec;
return high_dec;
Expand All @@ -2421,7 +2477,7 @@ struct longint {
ATTRIBUTE_ALWAYS_INLINE
ATTRIBUTE_SIZED_ACCESS(read_only, 1, 2)
static Decimal convertBinBase(const digit_t nums[], const std::size_t size) {
assert(math_functions::is_power_of_two(size));
LONGINT_DEBUG_ASSERT(math_functions::is_power_of_two(size));
return convertBinBaseImpl(nums, size);
}

Expand All @@ -2441,12 +2497,12 @@ struct longint {
const size_type current_capacity = capacity();
static_assert(max_size() * 2 > max_size());
const size_type new_capacity = (current_capacity * 2) | (current_capacity == 0);
LONGINT_DEBUG_ASSERT(capacity_ < new_capacity);
ATTRIBUTE_ASSUME(capacity_ < new_capacity);
reserve(new_capacity);
}

ATTRIBUTE_NODISCARD_WITH_MESSAGE("impl error")
ATTRIBUTE_ALWAYS_INLINE
size_type set_size_at_least(size_type new_size) {
size_type cur_size = usize();
if (new_size <= cur_size) {
Expand Down Expand Up @@ -2539,6 +2595,7 @@ struct longint {
static constexpr void longint_add_with_free_space(digit_t lhs[], const size_type lhs_size,
const digit_t rhs[],
const size_type rhs_size) noexcept {
LONGINT_DEBUG_ASSERT(lhs_size > rhs_size);
ATTRIBUTE_ASSUME(lhs_size > rhs_size);
double_digit_t carry = 0;
const digit_t* const lhs_end = lhs + lhs_size;
Expand All @@ -2563,6 +2620,7 @@ struct longint {
static constexpr bool longint_subtract_with_free_space(digit_t lhs[], const size_type lhs_size,
const digit_t rhs[],
const size_type rhs_size) noexcept {
LONGINT_DEBUG_ASSERT(lhs_size >= rhs_size);
ATTRIBUTE_ASSUME(lhs_size >= rhs_size);
bool overflowed = longint_subtract_with_carry(lhs, lhs_size, rhs, rhs_size);
if (overflowed) {
Expand All @@ -2580,6 +2638,7 @@ struct longint {
static constexpr bool longint_subtract_with_carry(digit_t lhs[], const size_type lhs_size,
const digit_t rhs[],
const size_type rhs_size) noexcept {
LONGINT_DEBUG_ASSERT(lhs_size >= rhs_size);
ATTRIBUTE_ASSUME(lhs_size >= rhs_size);
const digit_t* const lhs_end = lhs + lhs_size;
const digit_t* const rhs_end = rhs + rhs_size;
Expand Down Expand Up @@ -2795,6 +2854,7 @@ inline void longint::set_str_impl(const unsigned char* str, const std::size_t st
(digits_count + kStrConvBaseDigits - 1) / kStrConvBaseDigits;
const size_type aligned_str_conv_digits_size = check_size(
math_functions::nearest_greater_equal_power_of_two(uint64_t{str_conv_digits_size}));
LONGINT_DEBUG_ASSERT(str_conv_digits_size <= aligned_str_conv_digits_size);
ATTRIBUTE_ASSUME(str_conv_digits_size <= aligned_str_conv_digits_size);
reserveUninitializedWithoutCopy(aligned_str_conv_digits_size);
digit_t* str_conv_digits = nums_;
Expand Down Expand Up @@ -2855,6 +2915,7 @@ inline void longint::set_str_impl(const unsigned char* str, const std::size_t st
static_assert(max_size() * 2 > max_size());
for (size_type conv_len = 2; conv_len <= aligned_str_conv_digits_size;
conv_len *= 2, ++conv_dec_base_pows_iter) {
LONGINT_DEBUG_ASSERT(math_functions::is_power_of_two(conv_len));
ATTRIBUTE_ASSUME(math_functions::is_power_of_two(conv_len));
for (size_type pos = 0; pos < aligned_str_conv_digits_size; pos += conv_len) {
convertDecBaseMultAdd(str_conv_digits + pos, conv_len, *conv_dec_base_pows_iter,
Expand Down
3 changes: 2 additions & 1 deletion tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ function(configure_gcc_or_clang_gcc_options)
-U_GLIBCXX_USE_DEPRECATED)
set(LOCAL_FN_TEST_COMPILE_DEFINITIONS
${LOCAL_FN_TEST_COMPILE_DEFINITIONS}
_GLIBCXX_SANITIZE_VECTOR)
_GLIBCXX_SANITIZE_VECTOR
ENABLE_LONGINT_DEBUG_ASSERTS=1)
if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 10.0)
# In gcc with version < 10.0 these checks break `constexpr`-tivity of some std:: functions
set(LOCAL_FN_TEST_COMPILE_DEFINITIONS
Expand Down

0 comments on commit 00f91f6

Please sign in to comment.