알고리즘/백준 & swacademy
BOJ 20176 - Needle (FFT)
sun__
2021. 1. 2. 16:15
참고
<문제설명>
세 수열 $a,b,c$ 가 주어질 때, $ a_i+c_j=2 * b_k $ 를 만족하는 쌍 $ (i,j,k) $의 개수를 세는 것
<풀이>
x좌표의 범위를 0-60000으로 두고 답을 구하는 식을 만들면 다음과 같다.
\[\sum_{x = 0}^{60000} count_b[x] \times \sum_{y = 0}^{2x} (count_a[y] \times count_c[2x-y])\]
fft를 이용해서 합성곱을 구할 수 있다.
\[conv[i] = \sum_{x = 0}^{i} (count_a[x] \times count_c[i - x])\]
다시 식을 정리하면..
\[\sum_{x = 0}^{60000} (count_b[x] \times conv[2x])\]
<코드>
#define _USE_MATH_DEFINES
#include <math.h>
#include <complex>
#define FAST ios_base::sync_with_stdio(false); cin.tie(nullptr); cout.tie(nullptr);
using namespace std;
#define sz(v) ((int)(v).size())
#define all(v) (v).begin(),(v).end()
typedef complex<double> base;
void fft(vector <base>& a, bool invert)
{
int n = sz(a);
for (int i = 1, j = 0; i < n; i++) {
int bit = n >> 1;
for (; j >= bit; bit >>= 1) j -= bit;
j += bit;
if (i < j) swap(a[i], a[j]);
}
for (int len = 2; len <= n; len <<= 1) {
double ang = 2 * M_PI / len * (invert ? -1 : 1);
base wlen(cos(ang), sin(ang));
for (int i = 0; i < n; i += len) {
base w(1);
for (int j = 0; j < len / 2; j++) {
base u = a[i + j], v = a[i + j + len / 2] * w;
a[i + j] = u + v;
a[i + j + len / 2] = u - v;
w *= wlen;
}
}
}
if (invert) {
for (int i = 0; i < n; i++) a[i] /= n;
}
}
void multiply(const vector<int>& a, const vector<int>& b, vector<int>& res)
{
vector <base> fa(all(a)), fb(all(b));
int n = 1;
while (n < sz(a) + sz(b)) n <<= 1;
fa.resize(n); fb.resize(n);
fft(fa, false); fft(fb, false);
for (int i = 0; i < n; i++) fa[i] *= fb[i];
fft(fa, true);
res.resize(n);
for (int i = 0; i < n; i++) res[i] = int(fa[i].real() + (fa[i].real() > 0 ? 0.5 : -0.5));
}
typedef long long ll;
typedef pair<int, int> P;
const int MAX = 6e4 + 1;
int N;
vector<int> ua(MAX), ma(MAX), la(MAX);
vector<int> res;
int main() {
FAST;
cin >> N;
for (int i = 0, x; i < N; i++) {
cin >> x;
ua[x + 3e4] = 1;
}
cin >> N;
for (int i = 0, x; i < N; i++) {
cin >> x;
ma[x + 3e4] = 1;
}
cin >> N;
for (int i = 0, x; i < N; i++) {
cin >> x;
la[x + 3e4] = 1;
}
multiply(ua, la, res);
ll ans = 0;
for (int i = 0; i < MAX; i++) {
ans += ma[i] * res[2 * i];
}
cout << ans << '\n';
}