简介

FFT / NTT / MTT 均基于(离散)傅里叶变换的思想来加速多项式乘法,将多项式乘法的时间复杂度从 $O(n^2)$ 降到了 $O(n\log n)$。

三者的区别

FFT:快速傅里叶变换,用于浮点数多项式卷积。

NTT:快速数论变换,用于模意义下的多项式卷积,要求模数为 NTT 模数,即形如 $P = 2^A \cdot X + 1$ 的质数,且变换长度不超过 $2^A$(例如 $998244353 = 2^{23} \cdot 119 + 1$)。

MTT:任意模数多项式乘法。下方实现先在三个 NTT 模数下分别做卷积,再用中国剩余定理合并,因此只要卷积结果的精确值小于三个模数之积(约 $4.71 \times 10^{26}$),即可对任意模数取模。

模板题链接

下方模板可直接取用,相关证明见下方链接中对应题目的题解。

洛谷 P3803 【模板】多项式乘法

洛谷 P4245 【模板】任意模数多项式乘法

FFT 快速傅里叶变换

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
// Fast Fourier Transform (FFT) template. Time: O(n log n).
constexpr int MAX = 100005;      // sizing constant for the static buffers below (MAX << 2)
const double PI = acos(-1.0);    // pi, computed as arccos(-1)

// Minimal complex-number type for the FFT: addition, subtraction, multiplication,
// and division by a real scalar.
struct Complex {
    double real;  // real part
    double img;   // imaginary part

    Complex(double real = 0, double img = 0) : real(real), img(img) {}

    Complex operator+(Complex other) {
        return Complex(real + other.real, img + other.img);
    }

    Complex operator-(Complex other) {
        return Complex(real - other.real, img - other.img);
    }

    // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
    Complex operator*(Complex other) {
        const double re = real * other.real - img * other.img;
        const double im = real * other.img + img * other.real;
        return Complex(re, im);
    }

    Complex operator/(double other) {
        return Complex(real / other, img / other);
    }
};

// Bit-reversal permutation table, filled by the caller before each transform.
// Sized MAX << 2 for headroom; the transform only reads indices [0, lim).
int tax[MAX << 2];

// In-place iterative FFT over A[0, lim). mode: 1 = forward transform, -1 = inverse.
// lim must be a power of two and tax[] must hold the bit-reversal permutation for lim.
// The inverse direction does not divide by lim; the caller is responsible for that scaling.
void _FFT(Complex A[], int lim, int mode) {
    // Put the input into bit-reversed order; the i < tax[i] test swaps each pair exactly once.
    for (int i = 0; i < lim; i++) {
        if (i < tax[i])
            swap(A[i], A[tax[i]]);
    }
    // Combine adjacent blocks of length `half` into blocks of length 2 * half.
    for (int half = 1; half < lim; half <<= 1) {
        // Root of unity for length 2 * half; the sign of the imaginary part picks the direction.
        Complex step(cos(PI / half), mode * sin(PI / half));
        const int span = half << 1;
        for (int base = 0; base < lim; base += span) {
            Complex w(1, 0);
            for (int k = 0; k < half; k++) {
                Complex lo = A[base + k];
                Complex hi = w * A[base + half + k];
                A[base + k] = lo + hi;
                A[base + half + k] = lo - hi;
                w = w * step;
            }
        }
    }
}

// Fast Fourier convolution. Time: O((n + m) log(n + m)).
// Computes res[k] = sum_{i+j=k} a[i] * b[j] for k in [0, n + m - 1), rounding each
// coefficient to the nearest integer, and returns n + m - 1 (0 if either input is empty).
// Requires: res has room for n + m - 1 entries, and lim (n + m - 1 rounded up to a
// power of two) does not exceed MAX << 2 (the size of the static buffers).
int FFT(int a[], int n, int b[], int m, int res[]) {
    static Complex A[MAX << 2], B[MAX << 2];
    if (n <= 0 || m <= 0)
        return 0;
    const int len = n + m - 1;
    int l = 0, lim = 1;
    while (lim < len) {
        lim <<= 1;
        l++;
    }
    // Load inputs and zero-pad up to lim. The static buffers still hold data from any
    // previous call, so the padding has to be cleared explicitly on every call.
    for (int i = 0; i < lim; i++)
        A[i] = i < n ? Complex(a[i]) : Complex();
    for (int i = 0; i < lim; i++)
        B[i] = i < m ? Complex(b[i]) : Complex();
    // Bit-reversal permutation. tax[0] is always 0; starting at i = 1 also avoids the
    // shift by (l - 1) == -1 when lim == 1, which would be undefined behavior.
    tax[0] = 0;
    for (int i = 1; i < lim; i++)
        tax[i] = (tax[i >> 1] >> 1) | ((i & 1) << (l - 1));
    _FFT(A, lim, 1);
    _FFT(B, lim, 1);
    // Pointwise product over exactly lim points; the 1/lim factor of the inverse
    // transform is folded in here.
    for (int i = 0; i < lim; i++)
        A[i] = A[i] * B[i] / lim;
    _FFT(A, lim, -1);
    for (int i = 0; i < len; i++)
        res[i] = round(A[i].real);
    return len;
}

NTT 快速数论变换

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
// Fast Number-Theoretic Transform (NTT) template. Time: O(n log n).
constexpr int MAX = 100005;          // sizing constant for the static buffers (MAX << 2)
constexpr int MOD = 998244353;       // NTT-friendly prime: 2^23 * 119 + 1
constexpr int MIN_ROOT = 3;          // primitive root of MOD used as the transform generator
constexpr int INV_ROOT = 332748118;  // inverse of MIN_ROOT modulo MOD (3 * 332748118 = MOD + 1)

// Commonly used NTT-friendly moduli, one {MOD, MIN_ROOT, INV_ROOT} row each; the NTT
// requires such a modulus. A prime P = 2^A * X + 1 supports transforms of length up to
// 2^A, so a larger A makes a better NTT modulus.
constexpr int NTT_MOD[3][3] = {
    {469762049, 3, 156587350},   // 2^26 * 7 + 1
    {998244353, 3, 332748118},   // 2^23 * 119 + 1
    {1004535809, 3, 334845270},  // 2^21 * 479 + 1
};

// Single-step modular corrections. `MOD` must be visible where the macro expands, and
// `x` is evaluated more than once, so only pass side-effect-free expressions.
// mm_add maps x in [0, 2 * MOD) into [0, MOD); mm_sub maps x in [-MOD, MOD) into [0, MOD).
#define mm_add(x) ((x) < MOD ? (x) : (x) - MOD)
#define mm_sub(x) ((x) >= 0 ? (x) : (x) + MOD)

// Fast modular exponentiation: returns base^exponent mod MOD, for exponent >= 0.
// Products are formed in 64-bit: with MOD close to 1e9, an int * int product overflows
// a 32-bit int (undefined behavior). Negative bases are normalized into [0, MOD).
int quick_pow(int base, int exponent) {
    long long res = 1 % MOD;
    long long cur = base % MOD;
    if (cur < 0)
        cur += MOD;
    while (exponent) {
        if (exponent & 1)
            res = res * cur % MOD;
        cur = cur * cur % MOD;
        exponent >>= 1;
    }
    return static_cast<int>(res);
}

// Modular inverse via Fermat's little theorem: primal^(MOD - 2) mod MOD.
// Only valid because MOD is prime; primal must not be a multiple of MOD.
inline int inv(int primal) {
    const int fermat_exponent = MOD - 2;
    return quick_pow(primal, fermat_exponent);
}

// Bit-reversal permutation table, filled by the caller before each transform.
// Sized MAX << 2 for headroom; the transform only reads indices [0, lim).
int tax[MAX << 2];

// In-place iterative NTT over A[0, lim) modulo MOD.
// FORWARD = true: forward transform with MIN_ROOT; false: inverse transform with
// INV_ROOT, including the final 1/lim scaling.
// Requires: lim is a power of two dividing MOD - 1, tax[] holds the bit-reversal
// permutation for lim, and every A[i] is in [0, MOD).
template <bool FORWARD>
void _NTT(int A[], int lim) {
    for (int i = 0; i < lim; i++)
        if (i < tax[i])
            swap(A[i], A[tax[i]]);
    for (int mid = 1; mid < lim; mid <<= 1) {
        // (2 * mid)-th root of unity, or its inverse for the inverse transform.
        const long long Wn = quick_pow(FORWARD ? MIN_ROOT : INV_ROOT, (MOD - 1) / (mid << 1));
        for (int j = 0, R = mid << 1; j < lim; j += R) {
            long long w = 1;
            for (int k = 0; k < mid; k++, w = w * Wn % MOD) {
                // Products in 64-bit: MOD is ~1e9, so int * int would overflow.
                // x, y < MOD, so x + y < 2 * MOD still fits in a 32-bit int.
                int x = A[j + k], y = static_cast<int>(w * A[j + mid + k] % MOD);
                A[j + k] = mm_add(x + y);
                A[j + mid + k] = mm_sub(x - y);
            }
        }
    }
    if (!FORWARD) {
        const long long inv_lim = inv(lim);
        for (int i = 0; i < lim; i++)
            A[i] = static_cast<int>(A[i] * inv_lim % MOD);
    }
}

// Number-theoretic convolution modulo MOD. Time: O((n + m) log(n + m)).
// Computes res[k] = sum_{i+j=k} a[i] * b[j] mod MOD for k in [0, n + m - 1) and returns
// n + m - 1 (0 if either input is empty). Input values are first reduced into [0, MOD).
// Requires: res has room for n + m - 1 entries, and lim (n + m - 1 rounded up to a
// power of two) does not exceed MAX << 2 (the size of the static buffers).
int NTT(int a[], int n, int b[], int m, int res[]) {
    static int A[MAX << 2], B[MAX << 2];
    if (n <= 0 || m <= 0)
        return 0;
    const int len = n + m - 1;
    int l = 0, lim = 1;
    while (lim < len) {
        lim <<= 1;
        l++;
    }
    // Load reduced inputs and zero-pad to lim. The static buffers keep data from earlier
    // calls (after the inverse transform A holds the previous product, not zeros), so the
    // padding must be cleared on every call.
    for (int i = 0; i < lim; i++) {
        A[i] = i < n ? (a[i] % MOD + MOD) % MOD : 0;
        B[i] = i < m ? (b[i] % MOD + MOD) % MOD : 0;
    }
    // Bit-reversal permutation. tax[0] is always 0; starting at i = 1 also avoids the
    // shift by (l - 1) == -1 when lim == 1, which would be undefined behavior.
    tax[0] = 0;
    for (int i = 1; i < lim; i++)
        tax[i] = tax[i >> 1] >> 1 | (i & 1) << (l - 1);
    _NTT<true>(A, lim);
    _NTT<true>(B, lim);
    for (int i = 0; i < lim; i++)
        A[i] = static_cast<int>(1LL * A[i] * B[i] % MOD);  // 64-bit product avoids int overflow
    _NTT<false>(A, lim);
    memcpy(res, A, sizeof(int) * len);
    return len;
}

MTT 任意模数快速数论变换

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
// Arbitrary-modulus NTT (MTT): three NTT-friendly primes combined by CRT. Time: O(3 n log n).
// Requires min(n, m) * A^2 < 4e26, where A bounds the input values, so that every exact
// convolution coefficient stays below the product of the three primes.
constexpr int MAX = 100005;          // sizing constant for the static buffers
constexpr int ANY_MOD = 1000000007;  // target modulus; it does not have to be NTT-friendly

// Single-step modular corrections. `MOD` (here a template parameter) must be visible where
// the macro expands, and `x` is evaluated more than once, so pass side-effect-free
// expressions. mm_add maps [0, 2 * MOD) into [0, MOD); mm_sub maps [-MOD, MOD) into [0, MOD).
#define mm_add(x) ((x) < MOD ? (x) : (x) - MOD)
#define mm_sub(x) ((x) >= 0 ? (x) : (x) + MOD)

// Modular exponentiation usable in constant expressions: base^exponent mod MOD, exponent >= 0.
// Products are formed in 64-bit: for moduli near 1e9 an int * int product overflows a
// 32-bit int, which is undefined behavior at run time and a hard error during constant
// evaluation. Negative bases are normalized into [0, MOD).
template <int MOD>
constexpr int quick_pow(int base, int exponent) {
    long long res = 1 % MOD;
    long long cur = base % MOD;
    if (cur < 0)
        cur += MOD;
    while (exponent) {
        if (exponent & 1)
            res = res * cur % MOD;
        cur = cur * cur % MOD;
        exponent >>= 1;
    }
    return static_cast<int>(res);
}

// Modular inverse via Fermat's little theorem: primal^(MOD - 2) mod MOD.
// Only valid for prime MOD; primal must not be a multiple of MOD.
template <int MOD>
constexpr inline int inv(int primal) {
    constexpr int fermat_exponent = MOD - 2;
    return quick_pow<MOD>(primal, fermat_exponent);
}

// Three-prime NTT parameters as compile-time constants, so every transform is specialized
// per modulus. 3 is used as the primitive root of all three primes.
constexpr int NTT_MOD_1 = +469762049, NTT_MIN_ROOT_1 = 3, NTT_INV_ROOT_1 = inv<NTT_MOD_1>(NTT_MIN_ROOT_1);
constexpr int NTT_MOD_2 = +998244353, NTT_MIN_ROOT_2 = 3, NTT_INV_ROOT_2 = inv<NTT_MOD_2>(NTT_MIN_ROOT_2);
constexpr int NTT_MOD_3 = 1004535809, NTT_MIN_ROOT_3 = 3, NTT_INV_ROOT_3 = inv<NTT_MOD_3>(NTT_MIN_ROOT_3);
// CRT basis: NTT_CRT_k = (M / P_k) * ((M / P_k)^{-1} mod P_k) with M = P_1 * P_2 * P_3, so that
// sum_k r_k * NTT_CRT_k is congruent to r_j modulo each P_j.
// The cofactor M / P_k is reduced modulo P_k step by step in 64-bit before inverting:
// multiplying two ~1e9 primes directly in a 32-bit int overflows inside a constant expression.
constexpr __int128_t NTT_CRT_1 = (__int128_t)NTT_MOD_2 * NTT_MOD_3 *
    inv<NTT_MOD_1>(static_cast<int>(1LL * NTT_MOD_2 % NTT_MOD_1 * NTT_MOD_3 % NTT_MOD_1));
constexpr __int128_t NTT_CRT_2 = (__int128_t)NTT_MOD_1 * NTT_MOD_3 *
    inv<NTT_MOD_2>(static_cast<int>(1LL * NTT_MOD_1 % NTT_MOD_2 * NTT_MOD_3 % NTT_MOD_2));
constexpr __int128_t NTT_CRT_3 = (__int128_t)NTT_MOD_1 * NTT_MOD_2 *
    inv<NTT_MOD_3>(static_cast<int>(1LL * NTT_MOD_1 % NTT_MOD_3 * NTT_MOD_2 % NTT_MOD_3));
constexpr __int128_t NTT_CRT_MOD = (__int128_t)NTT_MOD_1 * NTT_MOD_2 * NTT_MOD_3; // 471064322751194440790966273

// Bit-reversal permutation table, filled by the caller before each transform.
// Sized MAX << 2 for headroom; the transform only reads indices [0, lim).
int tax[MAX << 2];

// In-place iterative NTT over A[0, lim) modulo MOD, specialized at compile time.
// FORWARD = true: forward transform with MIN_ROOT; false: inverse transform with
// INV_ROOT, including the final 1/lim scaling.
// Requires: lim is a power of two dividing MOD - 1, tax[] holds the bit-reversal
// permutation for lim, and every A[i] is in [0, MOD).
template <int MOD, int MIN_ROOT, int INV_ROOT, bool FORWARD>
void _NTT(int A[], int lim) {
    for (int i = 0; i < lim; i++)
        if (i < tax[i])
            swap(A[i], A[tax[i]]);
    for (int mid = 1; mid < lim; mid <<= 1) {
        // (2 * mid)-th root of unity, or its inverse for the inverse transform.
        const long long Wn = quick_pow<MOD>(FORWARD ? MIN_ROOT : INV_ROOT, (MOD - 1) / (mid << 1));
        for (int j = 0, R = mid << 1; j < lim; j += R) {
            long long w = 1;
            for (int k = 0; k < mid; k++, w = w * Wn % MOD) {
                // Products in 64-bit: MOD is ~1e9, so int * int would overflow.
                // x, y < MOD, so x + y < 2 * MOD still fits in a 32-bit int for these primes.
                int x = A[j + k], y = static_cast<int>(w * A[j + mid + k] % MOD);
                A[j + k] = mm_add(x + y);
                A[j + mid + k] = mm_sub(x - y);
            }
        }
    }
    if (!FORWARD) {
        const long long inv_lim = inv<MOD>(lim);
        for (int i = 0; i < lim; i++)
            A[i] = static_cast<int>(A[i] * inv_lim % MOD);
    }
}

// Convolution modulo MOD, assuming the caller has already computed lim and filled tax[]
// with the bit-reversal permutation for it (lim >= n + m - 1, lim <= MAX << 2).
// Writes the n + m - 1 coefficients of a * b (mod MOD) to res. Input values are reduced
// into [0, MOD) first.
template <int MOD, int MIN_ROOT, int INV_ROOT>
void NTT(int a[], int n, int b[], int m, int lim, int res[]) {
    static int A[MAX << 2], B[MAX << 2];
    // Load reduced inputs and zero-pad to lim. The static buffers keep data from earlier
    // calls (after the inverse transform A holds the previous product, not zeros), so the
    // padding must be cleared on every call.
    for (int i = 0; i < lim; i++) {
        A[i] = i < n ? (a[i] % MOD + MOD) % MOD : 0;
        B[i] = i < m ? (b[i] % MOD + MOD) % MOD : 0;
    }
    _NTT<MOD, MIN_ROOT, INV_ROOT, true>(A, lim);
    _NTT<MOD, MIN_ROOT, INV_ROOT, true>(B, lim);
    for (int i = 0; i < lim; i++)
        A[i] = static_cast<int>(1LL * A[i] * B[i] % MOD);  // 64-bit product avoids int overflow
    _NTT<MOD, MIN_ROOT, INV_ROOT, false>(A, lim);
    memcpy(res, A, sizeof(int) * (n + m - 1));
}

// Arbitrary-modulus convolution. Time: O(3 (n + m) log(n + m)).
// Computes res[k] = sum_{i+j=k} a[i] * b[j] mod ANY_MOD for k in [0, n + m - 1) by
// convolving modulo three NTT primes and recombining with the Chinese Remainder Theorem.
// Returns n + m - 1 (0 if either input is empty).
// Requires: res has room for n + m - 1 entries; n + m - 1 <= MAX << 1; inputs are
// non-negative and small enough that every exact coefficient (at most min(n, m) * A^2
// for values bounded by A) is below NTT_CRT_MOD (~4.71e26).
int MTT(int a[], int n, int b[], int m, int res[]) {
    static int RES1[MAX << 1];
    static int RES2[MAX << 1];
    static int RES3[MAX << 1];
    if (n <= 0 || m <= 0)
        return 0;
    int l = 0, lim = 1;
    while (lim < n + m - 1) {
        lim <<= 1;
        l++;
    }
    // Bit-reversal table shared by all three transforms. tax[0] is always 0; starting at
    // i = 1 also avoids the shift by (l - 1) == -1 when lim == 1 (undefined behavior).
    tax[0] = 0;
    for (int i = 1; i < lim; i++)
        tax[i] = tax[i >> 1] >> 1 | (i & 1) << (l - 1);
    // One convolution per prime; the moduli are template arguments, so each
    // instantiation works with compile-time constants.
    NTT<NTT_MOD_1, NTT_MIN_ROOT_1, NTT_INV_ROOT_1>(a, n, b, m, lim, RES1);
    NTT<NTT_MOD_2, NTT_MIN_ROOT_2, NTT_INV_ROOT_2>(a, n, b, m, lim, RES2);
    NTT<NTT_MOD_3, NTT_MIN_ROOT_3, NTT_INV_ROOT_3>(a, n, b, m, lim, RES3);
    // CRT: the exact coefficient is (sum_k RESk[i] * NTT_CRT_k) mod NTT_CRT_MOD; each term
    // is below ~4.8e35, so the three-term sum fits comfortably in __int128.
    for (int i = n + m - 2; i >= 0; i--) {
        __int128_t tmp = 0;
        tmp += NTT_CRT_1 * RES1[i];
        tmp += NTT_CRT_2 * RES2[i];
        tmp += NTT_CRT_3 * RES3[i];
        res[i] = static_cast<int>(tmp % NTT_CRT_MOD % ANY_MOD);
    }
    return n + m - 1;
}