Skip to content

Commit

Permalink
update math functions
Browse files Browse the repository at this point in the history
  • Loading branch information
i80287 committed Jul 20, 2024
1 parent ba94213 commit bcb0b09
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 94 deletions.
180 changes: 92 additions & 88 deletions number_theory/math_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,18 +163,7 @@ ATTRIBUTE_CONST inline I128_CONSTEXPR uint64_t bin_pow_mod(uint64_t n, uint64_t

ATTRIBUTE_CONST constexpr uint32_t isqrt(uint32_t n) noexcept {
uint32_t y = 0;

#if (defined(__cpp_lib_is_constant_evaluated) && __cpp_lib_is_constant_evaluated >= 201811L) || \
CONFIG_HAS_BUILTIN(__builtin_is_constant_evaluated) || \
CONFIG_HAS_BUILTIN(__builtin_constant_p)
#if defined(__cpp_lib_is_constant_evaluated) && __cpp_lib_is_constant_evaluated >= 201811L
if (std::is_constant_evaluated())
#elif CONFIG_HAS_BUILTIN(__builtin_is_constant_evaluated)
if (__builtin_is_constant_evaluated())
#else
if (__builtin_constant_p(n))
#endif
{
if (config_is_constant_evaluated()) {
/**
* See Hackers Delight Chapter 11.
*/
Expand All @@ -186,9 +175,7 @@ ATTRIBUTE_CONST constexpr uint32_t isqrt(uint32_t n) noexcept {
y |= m;
}
}
} else
#endif
{
} else {
y = static_cast<uint32_t>(std::sqrt(static_cast<double>(n)));
}

Expand Down Expand Up @@ -239,17 +226,21 @@ ATTRIBUTE_CONST inline I128_CONSTEXPR uint64_t isqrt(uint128_t n) noexcept {
#endif

ATTRIBUTE_CONST constexpr uint32_t icbrt(uint32_t n) noexcept {
/**
* See Hackers Delight Chapter 11.
*/
uint32_t y = 0;
for (int32_t s = 30; s >= 0; s -= 3) {
y *= 2;
uint32_t b = (3 * y * (y + 1) | 1) << s;
if (n >= b) {
n -= b;
y++;
if (config_is_constant_evaluated()) {
/**
* See Hackers Delight Chapter 11.
*/
for (int32_t s = 30; s >= 0; s -= 3) {
y *= 2;
uint32_t b = (3 * y * (y + 1) | 1) << s;
if (n >= b) {
n -= b;
y++;
}
}
} else {
y = static_cast<uint32_t>(std::cbrt(static_cast<double>(n)));
}
// 1625^3 = 4291015625 < 2^32 - 1 = 4294967295 < 4298942376 = 1626^3
ATTRIBUTE_ASSUME(y <= 1625u);
Expand Down Expand Up @@ -1148,21 +1139,21 @@ ATTRIBUTE_CONST constexpr uint64_t nearest_pow2_ge(uint64_t n) noexcept {
/// @brief If @a n != 0, return number that is power of 2 and
/// whose only bit is the lowest bit set in the @a n
/// Otherwise, return 0
/// @tparam TInt
/// @param[in] n
/// @return
/// @tparam TInt
/// @param[in] n
/// @return
template <class TInt>
ATTRIBUTE_CONST constexpr TInt least_bit_set(TInt n) noexcept {
namespace helper_ns =
#ifdef INTEGERS_128_BIT_HPP
type_traits_helper_int128_t;
type_traits_helper_int128_t;
#else
std;
std;
#endif
static_assert(helper_ns::is_integral_v<TInt>, "integral type expected");
using TUInt = helper_ns::make_unsigned_t<TInt>;
using TUInt = helper_ns::make_unsigned_t<TInt>;
using TUIntAtLeastUInt = std::common_type_t<TUInt, uint32_t>;
auto unsigned_n = static_cast<TUIntAtLeastUInt>(static_cast<TUInt>(n));
auto unsigned_n = static_cast<TUIntAtLeastUInt>(static_cast<TUInt>(n));
return static_cast<TInt>(unsigned_n & -unsigned_n);
}

Expand Down Expand Up @@ -1661,51 +1652,54 @@ inline void prime_divisors_to_map(NumericType n, std::map<NumericType, uint32_t>
}
}

class [[nodiscard]] Factorizer final {
public:
#if CONFIG_HAS_AT_LEAST_CXX_20 && !defined(_GLIBCXX_DEBUG)
constexpr
#define CONSTEXPR_VECTOR constexpr
#else
#define CONSTEXPR_VECTOR
#endif
Factorizer(std::uint32_t n)
: least_prime_factor(std::size_t(n) + 1) {

/// @brief https://cp-algorithms.com/algebra/prime-sieve-linear.html
class [[nodiscard]] Factorizer final {
public:
CONSTEXPR_VECTOR Factorizer(std::uint32_t n) : least_prime_factor(std::size_t(n) + 1) {
for (uint32_t i = 2; i <= n; i++) {
if (least_prime_factor[i] == 0) {
least_prime_factor[i] = i;
primes.push_back(i);
}
for (std::size_t prime_index = 0; std::size_t(primes[prime_index]) * i <= n;
prime_index++) {
for (std::size_t prime_index = 0;; prime_index++) {
const auto p = primes[prime_index];
least_prime_factor[p * i] = primes[prime_index];
const auto lpf = least_prime_factor[i];
ATTRIBUTE_ASSUME(p <= lpf);
if (p == lpf) {
const auto x = std::size_t(p) * i;
if (x > n) {
break;
}
least_prime_factor[x] = p;
// assert(p <= least_prime_factor[i]);
if (p == least_prime_factor[i]) {
break;
}
}
}
}

[[nodiscard]] constexpr const auto& sorted_primes() const noexcept {
[[nodiscard]] constexpr const auto& sorted_primes() const noexcept ATTRIBUTE_LIFETIME_BOUND {
return primes;
}
[[nodiscard]] constexpr const auto& least_prime_factors() const noexcept {
[[nodiscard]] constexpr const auto& least_prime_factors() const noexcept
ATTRIBUTE_LIFETIME_BOUND {
return least_prime_factor;
}
[[nodiscard]]
#if CONFIG_HAS_AT_LEAST_CXX_20 && !defined(_GLIBCXX_DEBUG)
constexpr
#endif
bool is_prime(std::uint32_t n) const noexcept {
[[nodiscard]] CONSTEXPR_VECTOR bool is_prime(std::uint32_t n) const noexcept {
return least_prime_factor[n] == n && n >= 2;
}
[[nodiscard]]
#if CONFIG_HAS_AT_LEAST_CXX_20 && !defined(_GLIBCXX_DEBUG)
constexpr
#endif
auto prime_factors(std::uint32_t n) const noexcept {
[[nodiscard]] CONSTEXPR_VECTOR auto prime_factors(std::uint32_t n) const noexcept {
std::vector<PrimeFactor<std::uint32_t>> pfs;
while (n >= 2) {
if (n % 2 == 0 && n > 0) {
const auto [n_div_pow_of_2, power_of_2] = extract_pow2(n);
pfs.emplace_back(std::uint32_t(2), power_of_2);
n = n_div_pow_of_2;
}
while (n >= 3) {
const auto lpf = least_prime_factor[n];
if (pfs.empty() || pfs.back().factor != lpf) {
// assert(pfs.empty() || pfs.back().factor < lpf);
Expand All @@ -1726,14 +1720,10 @@ class [[nodiscard]] Factorizer final {
};

/// @brief Find all prime numbers in [2; n]
/// @param N exclusive upper bound
/// @return vector<bool>, for which bvec[n] == true \iff n is prime
[[nodiscard]]
#if CONFIG_HAS_AT_LEAST_CXX_20 && !defined(_GLIBCXX_DEBUG)
constexpr
#endif
std::vector<bool> primes_sieve_as_bvector(uint32_t n) {
std::vector<bool> primes(std::size_t(n) + 1, true);
/// @param n inclusive upper bound
/// @return vector bvec, for which bvec[n] == true \iff n is prime
[[nodiscard]] CONSTEXPR_VECTOR auto dynamic_primes_sieve(uint32_t n) {
std::vector<std::uint8_t> primes(std::size_t(n) + 1, true);
primes[0] = false;
if (likely(n > 0)) {
primes[1] = false;
Expand All @@ -1750,41 +1740,55 @@ constexpr
return primes;
}

#undef CONSTEXPR_VECTOR

// https://en.cppreference.com/w/cpp/feature_test
#if defined(__cpp_lib_constexpr_bitset) && \
(__cpp_lib_constexpr_bitset >= 202207L || CONFIG_HAS_AT_LEAST_CXX_23)
#define CONSTEXPR_BITSET_OPS constexpr
#if defined(__cpp_constexpr) && __cpp_constexpr >= 202211L
#define CONSTEXPR_FIXED_PRIMES_SIEVE constexpr
#define CONSTEXPR_PRIMES_SIEVE constexpr
#else
#define CONSTEXPR_FIXED_PRIMES_SIEVE
#define CONSTEXPR_PRIMES_SIEVE constinit
#endif
#else
#define CONSTEXPR_BITSET_OPS
#define CONSTEXPR_FIXED_PRIMES_SIEVE
#define CONSTEXPR_PRIMES_SIEVE
#endif

/// @brief Find all prime numbers in [2; N]
/// @tparam N exclusive upper bound
/// @return bitset, for which bset[n] == true \iff n is prime
template <uint32_t N>
[[nodiscard]]
#if defined(__cpp_constexpr) && __cpp_constexpr >= 202211L && \
defined(__cpp_lib_constexpr_bitset) && __cpp_lib_constexpr_bitset >= 202207L
constexpr
#endif
const auto& primes_sieve_as_bitset() noexcept {
// https://en.cppreference.com/w/cpp/feature_test
#if defined(__cpp_constexpr) && __cpp_constexpr >= 202211L && \
defined(__cpp_lib_constexpr_bitset) && __cpp_lib_constexpr_bitset >= 202207L
constexpr
#endif
static auto primes = []() constexpr noexcept {
std::bitset<std::size_t(N) + 1> primes_bs;
if constexpr (N + 1 > 2) {
primes_bs.set();
primes_bs[0] = false;
primes_bs[1] = false;
const uint32_t root = math_functions::isqrt(N);
for (uint32_t i = 2; i <= root; i++) {
if (primes_bs[i]) {
for (std::size_t j = i * i; j <= N; j += i) {
primes_bs[j] = false;
}
[[nodiscard]] CONSTEXPR_FIXED_PRIMES_SIEVE const auto& fixed_primes_sieve() noexcept {
constexpr auto kSize = std::size_t(N) + 1;
static CONSTEXPR_PRIMES_SIEVE std::bitset<kSize> primes = []() CONSTEXPR_BITSET_OPS noexcept {
std::bitset<kSize> primes_bs{};
primes_bs.set();
primes_bs[0] = false;
if constexpr (kSize > 1) {
primes_bs[1] = false;
const uint32_t root = math_functions::isqrt(N);
for (uint32_t i = 2; i <= root; i++) {
if (primes_bs[i]) {
for (std::size_t j = i * i; j <= N; j += i) {
primes_bs[j] = false;
}
}
}
return primes_bs;
}();
}
return primes_bs;
}();
return primes;
}

#undef CONSTEXPR_PRIMES_SIEVE
#undef CONSTEXPR_FIXED_PRIMES_SIEVE
#undef CONSTEXPR_BITSET_OPS

} // namespace math_functions

#if defined(INTEGERS_128_BIT_HPP)
Expand Down
17 changes: 11 additions & 6 deletions number_theory/test_math_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -749,13 +749,18 @@ static_assert(least_bit_set(0b10000000000) == 0b10000000000, "least_bit_set");
static_assert(least_bit_set(0b100000000u) == 0b100000000u, "least_bit_set");
static_assert(least_bit_set(0b1000000000u) == 0b1000000000u, "least_bit_set");
static_assert(least_bit_set(0b10000000000u) == 0b10000000000u, "least_bit_set");
static_assert(least_bit_set(0b1000000000000000000000000000000000000000000000000000000000000000ull) == 0b1000000000000000000000000000000000000000000000000000000000000000ull, "least_bit_set");
static_assert(
least_bit_set(0b1000000000000000000000000000000000000000000000000000000000000000ull) ==
0b1000000000000000000000000000000000000000000000000000000000000000ull,
"least_bit_set");
static_assert(least_bit_set(0b110101010101010101011001) == 0b1, "least_bit_set");
static_assert(least_bit_set(0b1010101011001101011100010100000ll) == 0b100000ll, "least_bit_set");
static_assert(least_bit_set(0b1010111001010101101010110101001101011100110011000ll) == 0b1000ll, "least_bit_set");
static_assert(least_bit_set(0b1010111001010101101010110101001101011100110011000ll) == 0b1000ll,
"least_bit_set");
static_assert(least_bit_set(0b110101010101010101011001u) == 0b1u, "least_bit_set");
static_assert(least_bit_set(0b1010101011001101011100010100000llu) == 0b100000llu, "least_bit_set");
static_assert(least_bit_set(0b1010111001010101101010110101001101011100110011000llu) == 0b1000llu, "least_bit_set");
static_assert(least_bit_set(0b1010111001010101101010110101001101011100110011000llu) == 0b1000llu,
"least_bit_set");

static_assert(log2_floor(uint32_t(0)) == uint32_t(-1), "log2_floor");
static_assert(log2_floor(uint32_t(1)) == 0, "log2_floor");
Expand Down Expand Up @@ -1266,8 +1271,8 @@ static void test_prime_bitarrays() {
log_tests_started();

constexpr size_t N = 1000;
std::vector<bool> primes_as_bvector = math_functions::primes_sieve_as_bvector(N);
const std::bitset<N + 1>& primes_bset = math_functions::primes_sieve_as_bitset<N>();
std::vector primes_as_bvector = math_functions::dynamic_primes_sieve(N);
const std::bitset<N + 1>& primes_bset = math_functions::fixed_primes_sieve<N>();
constexpr uint32_t primes[] = {
2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59,
61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139,
Expand Down Expand Up @@ -1298,7 +1303,7 @@ static void test_factorizer() {

constexpr auto N = uint32_t(1e7);
Factorizer fact(N);
const auto is_prime = primes_sieve_as_bvector(N);
const auto is_prime = dynamic_primes_sieve(N);
assert(is_prime.size() == N + 1);
for (std::uint32_t i = 0; i <= N; i++) {
assert(is_prime[i] == fact.is_prime(i));
Expand Down

0 comments on commit bcb0b09

Please sign in to comment.