Skip to content

Commit

Permalink
better ntt and mint (#170)
Browse files Browse the repository at this point in the history
* better ntt and mint

* removi a.cpp

* better mint

* assert mint
  • Loading branch information
joaomarcosth9 authored Sep 11, 2024
1 parent e3a7ce1 commit 1ec8361
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 53 deletions.
62 changes: 30 additions & 32 deletions Codigos/Matemática/NTT/NTT-Big-Modulo/big_ntt.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
template <int MOD, typename T = Mint<MOD>>
template <auto MOD, typename T = Mint<MOD>>
void ntt(vector<T> &a, bool inv = 0) {
int n = (int)a.size();
auto b = a;
Expand All @@ -23,25 +23,17 @@ void ntt(vector<T> &a, bool inv = 0) {
}
}

template <int MOD>
vector<int> multiply(vector<int> &ta, vector<int> &tb) {
using T = Mint<MOD>;
int n = (int)ta.size(), m = (int)tb.size();
template <auto MOD, typename T = Mint<MOD>>
vector<T> multiply(vector<T> a, vector<T> b) {
int n = (int)a.size(), m = (int)b.size();
int t = n + m - 1, sz = 1;
while (sz < t) sz <<= 1;

vector<T> a(sz), b(sz), c(sz);
for (int i = 0; i < n; i++) a[i] = ta[i];
for (int i = 0; i < m; i++) b[i] = tb[i];

a.resize(sz), b.resize(sz);
ntt<MOD>(a, 0), ntt<MOD>(b, 0);
for (int i = 0; i < sz; i++) c[i] = a[i] * b[i];
ntt<MOD>(c, 1);

vector<int> res(sz);
for (int i = 0; i < sz; i++) res[i] = c[i].v;
while ((int)res.size() > t && res.back() == 0) res.pop_back();
return res;
for (int i = 0; i < sz; i++) a[i] *= b[i];
ntt<MOD>(a, 1);
while ((int)a.size() > t) a.pop_back();
return a;
}

ll extended_gcd(ll a, ll b, ll &x, ll &y) {
Expand All @@ -58,23 +50,29 @@ ll extended_gcd(ll a, ll b, ll &x, ll &y) {

ll crt(array<int, 2> rem, array<int, 2> mod) {
__int128 ans = rem[0], m = mod[0];
for (int i = 1; i < 2; i++) {
ll x, y;
ll g = extended_gcd(mod[i], (ll)m, x, y);
if ((ans - rem[i]) % g != 0) return -1;
ans = ans + (__int128)1 * (rem[i] - ans) * (m / g) * y;
m = (__int128)(mod[i] / g) * (m / g) * g;
ans = (ans % m + m) % m;
}
ll x, y;
ll g = extended_gcd(mod[1], (ll)m, x, y);
if ((ans - rem[1]) % g != 0) return -1;
ans = ans + (__int128)1 * (rem[1] - ans) * (m / g) * y;
m = (__int128)(mod[1] / g) * (m / g) * g;
ans = (ans % m + m) % m;
return (ll)ans;
}

vector<ll> big_multiply(vector<int> &a, vector<int> &b) {
const int MOD1 = 1004535809;
const int MOD2 = 1092616193;
vector<int> c1 = multiply<MOD1>(a, b);
vector<int> c2 = multiply<MOD2>(a, b);
template <auto MOD1, auto MOD2, typename T = Mint<MOD1>, typename U = Mint<MOD2>>
vector<ll> big_multiply(vector<ll> ta, vector<ll> tb) {
vector<T> a1(ta.size()), b1(tb.size());
vector<U> a2(ta.size()), b2(tb.size());
for (int i = 0; i < (int)ta.size(); i++) a1[i] = ta[i];
for (int i = 0; i < (int)tb.size(); i++) b1[i] = tb[i];
for (int i = 0; i < (int)ta.size(); i++) a2[i] = ta[i];
for (int i = 0; i < (int)tb.size(); i++) b2[i] = tb[i];
auto c1 = multiply<MOD1>(a1, b1);
vector<ll> res(c1.size());
for (int i = 0; i < (int)res.size(); i++) res[i] = crt({c1[i], c2[i]}, {MOD1, MOD2});
for (int i = 0; i < (int)res.size(); i++)
res[i] = crt({c1[i].v, c2[i].v}, {MOD1, MOD2});
return res;
}
}

const int MOD1 = 1004535809;
const int MOD2 = 1092616193;
2 changes: 1 addition & 1 deletion Codigos/Matemática/NTT/NTT/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# [NTT Big Modulo](big_ntt.cpp)
# [NTT](ntt.cpp)

Computa a multiplicação de polinômios com coeficientes inteiros módulo um número primo em $\mathcal{O}(N \cdot \log N)$. Exatamente o mesmo algoritmo da FFT, mas com inteiros.

Expand Down
26 changes: 9 additions & 17 deletions Codigos/Matemática/NTT/NTT/ntt.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
template <int MOD, typename T = Mint<MOD>>
template <auto MOD, typename T = Mint<MOD>>
void ntt(vector<T> &a, bool inv = 0) {
int n = (int)a.size();
auto b = a;
Expand All @@ -23,23 +23,15 @@ void ntt(vector<T> &a, bool inv = 0) {
}
}

template <int MOD>
vector<int> multiply(vector<int> &ta, vector<int> &tb) {
using T = Mint<MOD>;
int n = (int)ta.size(), m = (int)tb.size();
template <auto MOD, typename T = Mint<MOD>>
vector<T> multiply(vector<T> a, vector<T> b) {
int n = (int)a.size(), m = (int)b.size();
int t = n + m - 1, sz = 1;
while (sz < t) sz <<= 1;

vector<T> a(sz), b(sz), c(sz);
for (int i = 0; i < n; i++) a[i] = ta[i];
for (int i = 0; i < m; i++) b[i] = tb[i];

a.resize(sz), b.resize(sz);
ntt<MOD>(a, 0), ntt<MOD>(b, 0);
for (int i = 0; i < sz; i++) c[i] = a[i] * b[i];
ntt<MOD>(c, 1);

vector<int> res(sz);
for (int i = 0; i < sz; i++) res[i] = c[i].v;
while ((int)res.size() > t && res.back() == 0) res.pop_back();
return res;
for (int i = 0; i < sz; i++) a[i] *= b[i];
ntt<MOD>(a, 1);
while ((int)a.size() > t) a.pop_back();
return a;
}
3 changes: 3 additions & 0 deletions Codigos/Matemática/NTT/Taylor-Shift/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# [Taylor Shift](taylor_shift.cpp)

Usa NTT para computar o polinômio $p(x + k)$, dados $p$ e $k$. A complexidade é $O(n \log n)$.
18 changes: 18 additions & 0 deletions Codigos/Matemática/NTT/Taylor-Shift/taylor_shift.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
template <auto MOD, typename T = Mint<MOD>>
vector<T> shift(vector<T> a, int k) {
int n = (int)a.size();
vector<T> fat(n, 1), ifat(n), shifting(n);
for (int i = 1; i < n; i++) fat[i] = fat[i - 1] * i;
ifat[n - 1] = T(1) / fat[n - 1];
for (int i = n - 1; i > 0; i--) ifat[i - 1] = ifat[i] * i;
for (int i = 0; i < n; i++) a[i] *= fat[i];
T pk = 1;
for (int i = 0; i < n; i++) {
shifting[n - i - 1] = pk * ifat[i];
pk *= k;
}
auto ans = multiply<MOD>(a, shifting);
ans.erase(ans.begin(), ans.begin() + n - 1);
for (int i = 0; i < n; i++) ans[i] *= ifat[i];
return ans;
}
14 changes: 11 additions & 3 deletions Codigos/Primitivas/Modular-Int/mint.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
// se o modulo for long long, usar U = __int128
template <auto MOD, typename T = decltype(MOD), typename U = ll>
template <auto MOD, typename T = decltype(MOD)>
struct Mint {
using U = long long;
// se o modulo for long long, usar U = __int128
using m = Mint<MOD, T>;
T v;
Mint(T val = 0) : v(val) {
assert(sizeof(T) * 2 <= sizeof(U));
if (v < -MOD || v >= 2 * MOD) v %= MOD;
if (v < 0) v += MOD;
if (v >= MOD) v -= MOD;
}
Mint(U val) : v(T(val % MOD)) {
assert(sizeof(T) * 2 <= sizeof(U));
if (v < 0) v += MOD;
}
bool operator==(m o) const { return v == o.v; }
Expand Down Expand Up @@ -38,4 +46,4 @@ struct Mint {
friend m operator*(m a, m b) { return a *= b; }
friend m operator/(m a, m b) { return a /= b; }
friend m operator^(m a, U e) { return a.pwr(a, e); }
};
};

0 comments on commit 1ec8361

Please sign in to comment.