본문 바로가기
알고리즘/백준 & swacademy

BOJ 20176 - Needle (FFT)

by sun__ 2021. 1. 2.

www.acmicpc.net/problem/20176

 

참고

koosaga.com/263

 

<문제설명>

세 수열 $a,b,c$ 가 주어질 때, $ a_i+c_j=2 * b_k $ 를 만족하는 쌍 $ (i,j,k) $의 개수를 세는 것

 

<풀이>

x좌표의 범위를 0-60000으로 두고 답을 구하는 식을 만들면 다음과 같다.

\[\sum_{x = 0}^{60000} count_b[x] \times \sum_{y = 0}^{2x} (count_a[y] \times count_c[2x-y])\]

 

fft를 이용해서 합성곱을 구할 수 있다.

\[conv[i] = \sum_{x = 0}^{i} (count_a[x] \times count_c[i - x])\]

 

다시 식을 정리하면..

\[\sum_{x = 0}^{60000} (count_b[x] \times conv[2x])\]

 

 

<코드>

#define _USE_MATH_DEFINES
#include <math.h>
#include <complex>
#define FAST ios_base::sync_with_stdio(false); cin.tie(nullptr); cout.tie(nullptr);
using namespace std;

#define sz(v) ((int)(v).size())
#define all(v) (v).begin(),(v).end()
typedef complex<double> base;

void fft(vector <base>& a, bool invert)
{
    int n = sz(a);
    for (int i = 1, j = 0; i < n; i++) {
        int bit = n >> 1;
        for (; j >= bit; bit >>= 1) j -= bit;
        j += bit;
        if (i < j) swap(a[i], a[j]);
    }
    for (int len = 2; len <= n; len <<= 1) {
        double ang = 2 * M_PI / len * (invert ? -1 : 1);
        base wlen(cos(ang), sin(ang));
        for (int i = 0; i < n; i += len) {
            base w(1);
            for (int j = 0; j < len / 2; j++) {
                base u = a[i + j], v = a[i + j + len / 2] * w;
                a[i + j] = u + v;
                a[i + j + len / 2] = u - v;
                w *= wlen;
            }
        }
    }
    if (invert) {
        for (int i = 0; i < n; i++) a[i] /= n;
    }
}

void multiply(const vector<int>& a, const vector<int>& b, vector<int>& res)
{
    vector <base> fa(all(a)), fb(all(b));
    int n = 1;
    while (n < sz(a) + sz(b)) n <<= 1;
    fa.resize(n); fb.resize(n);
    fft(fa, false); fft(fb, false);
    for (int i = 0; i < n; i++) fa[i] *= fb[i];
    fft(fa, true);
    res.resize(n);
    for (int i = 0; i < n; i++) res[i] = int(fa[i].real() + (fa[i].real() > 0 ? 0.5 : -0.5));
}

typedef long long ll;
typedef pair<int, int> P;
const int MAX = 6e4 + 1;

int N;
vector<int> ua(MAX), ma(MAX), la(MAX);
vector<int> res;


int main() {
    FAST; 
    cin >> N;
    for (int i = 0, x; i < N; i++) {
        cin >> x;
        ua[x + 3e4] = 1;
    }
    cin >> N;
    for (int i = 0, x; i < N; i++) {
        cin >> x;
        ma[x + 3e4] = 1;
    }
    cin >> N;
    for (int i = 0, x; i < N; i++) {
        cin >> x;
        la[x + 3e4] = 1;
    }

    multiply(ua, la, res);

    ll ans = 0;
    for (int i = 0; i < MAX; i++) {
        ans += ma[i] * res[2 * i];
    }

    cout << ans << '\n';
}