본문 바로가기
알고리즘/메모

고속 푸리에 변환 (FFT)

by sun__ 2021. 1. 2.

blog.myungwoo.kr/54

비재귀적으로 구현하신 페이지. 많은 분들이 이 코드를 사용하시는 듯 함

 


 

합성곱

합성 곱을 $O(nlogn)$ 에 구할 수 있다. 자세한 설명은 다른 블로그..

 

#define _USE_MATH_DEFINES
#include <math.h>
#include <complex>
#include <vector>
using namespace std;
#define sz(v) ((int)(v).size())
#define all(v) (v).begin(),(v).end()
typedef complex<double> base;

void fft(vector <base>& a, bool invert)
{
    int n = sz(a);
    for (int i = 1, j = 0; i < n; i++) {
        int bit = n >> 1;
        for (; j >= bit; bit >>= 1) j -= bit;
        j += bit;
        if (i < j) swap(a[i], a[j]);
    }
    for (int len = 2; len <= n; len <<= 1) {
        double ang = 2 * M_PI / len * (invert ? -1 : 1);
        base wlen(cos(ang), sin(ang));
        for (int i = 0; i < n; i += len) {
            base w(1);
            for (int j = 0; j < len / 2; j++) {
                base u = a[i + j], v = a[i + j + len / 2] * w;
                a[i + j] = u + v;
                a[i + j + len / 2] = u - v;
                w *= wlen;
            }
        }
    }
    if (invert) {
        for (int i = 0; i < n; i++) a[i] /= n;
    }
}

void multiply(const vector<int>& a, const vector<int>& b, vector<int>& res)
{
    vector <base> fa(all(a)), fb(all(b));
    int n = 1;
    while (n < sz(a) + sz(b)) n <<= 1;
    fa.resize(n); fb.resize(n);
    fft(fa, false); fft(fb, false);
    for (int i = 0; i < n; i++) fa[i] *= fb[i];
    fft(fa, true);
    res.resize(n);
    for (int i = 0; i < n; i++) res[i] = int(fa[i].real() + (fa[i].real() > 0 ? 0.5 : -0.5));
}

 

 


www.acmicpc.net/problem/14958

 

<문제설명>

RPS의 배열과 나의 배열이 주어질 때, RPS의 앞(prefix) 일부분을 건너 뛸 수 있다고 하자. 이 때 최대 승수를 구하면 된다.

 

<풀이>

내가 R을 낼 때, P를 낼 때, S를 낼 때 각 위치마다 낼 수 있는 점수들의 합을 저장하는 cnt배열을 유지해서 답을 구할 수 있다.

 

합성곱의 특성 때문에 내가 낼 배열을 뒤집은 후 합성곱을 수행해 줘야 한다.

 

<코드>

#define sz(v) ((int)(v).size())
#define all(v) (v).begin(),(v).end()
typedef complex<double> base;

void fft(vector <base>& a, bool invert)
{
    int n = sz(a);
    for (int i = 1, j = 0; i < n; i++) {
        int bit = n >> 1;
        for (; j >= bit; bit >>= 1) j -= bit;
        j += bit;
        if (i < j) swap(a[i], a[j]);
    }
    for (int len = 2; len <= n; len <<= 1) {
        double ang = 2 * M_PI / len * (invert ? -1 : 1);
        base wlen(cos(ang), sin(ang));
        for (int i = 0; i < n; i += len) {
            base w(1);
            for (int j = 0; j < len / 2; j++) {
                base u = a[i + j], v = a[i + j + len / 2] * w;
                a[i + j] = u + v;
                a[i + j + len / 2] = u - v;
                w *= wlen;
            }
        }
    }
    if (invert) {
        for (int i = 0; i < n; i++) a[i] /= n;
    }
}

void multiply(const vector<int>& a, const vector<int>& b, vector<int>& res)
{
    vector <base> fa(all(a)), fb(all(b));
    int n = 1;
    while (n < sz(a) + sz(b)) n <<= 1;
    fa.resize(n); fb.resize(n);
    fft(fa, false); fft(fb, false);
    for (int i = 0; i < n; i++) fa[i] *= fb[i];
    fft(fa, true);
    res.resize(n);
    for (int i = 0; i < n; i++) res[i] = int(fa[i].real() + (fa[i].real() > 0 ? 0.5 : -0.5));
}

int N, M, cnt[MAX];
string RPS, my;
vector<int> X, Y, res;

int main() {
    FAST; cin >> N >> M >> RPS >> my; 
    reverse(my.begin(), my.end());
 
    //내 선택 : R
    X.resize(N);
    Y.resize(M);
    for (int i = 0; i < N; i++) if (RPS[i] == 'S') X[i] = 1;
    for (int i = 0; i < M; i++) if (my[i] == 'R') Y[i] = 1;

    multiply(X, Y, res);
    for (int i = M - 1; i < M-1+N; i++) cnt[i] += res[i];

    res.clear();
    fill(X.begin(), X.end(), 0);
    fill(Y.begin(), Y.end(), 0);

    //내 선택 : S
    for (int i = 0; i < N; i++) if (RPS[i] == 'P') X[i] = 1;
    for (int i = 0; i < M; i++) if (my[i] == 'S') Y[i] = 1;

    multiply(X, Y, res);
    for (int i = M - 1; i < M - 1 + N; i++) cnt[i] += res[i];


    res.clear();
    fill(X.begin(), X.end(), 0);
    fill(Y.begin(), Y.end(), 0);

    //내 선택 : P
    for (int i = 0; i < N; i++) if (RPS[i] == 'R') X[i] = 1;
    for (int i = 0; i < M; i++) if (my[i] == 'P') Y[i] = 1;

    multiply(X, Y, res);
    for (int i = M - 1; i < M - 1 + N; i++) cnt[i] += res[i];

    int ans = 0;
    for (int i = M - 1; i < M - 1 + N; i++) ans = max(ans, cnt[i]);
    cout << ans << '\n';
}

 

'알고리즘 > 메모' 카테고리의 다른 글

최대 유량 - 디닉  (0) 2020.10.09
이분매칭  (0) 2020.10.07
최대 유량 - 에드몬드 카프  (0) 2020.10.07
난수, 랜덤  (0) 2020.08.28
머지소트 트리  (0) 2020.06.02