当前位置:   article > 正文

FFT代码详解

fft代码

关于FFT原理部分的介绍,在网上已经有很多了,所以在此只讲代码实现部分的内容。

原理可以参考https://www.cnblogs.com/RabbitHu/p/FFT.html

推荐看完它的原理解释再来看这里的代码解释

废话不多说,上代码(多项式乘法)

  1. #include <iostream>
  2. #include <cstdio>
  3. #include <cmath>
  4. #define N 4000001
  5. using namespace std;
  6. struct cp//手写复数类可以卡常
  7. {
  8. double real,imag;
  9. };
  10. cp operator +(cp a,cp b)
  11. {
  12. return (cp){a.real+b.real,a.imag+b.imag};
  13. }
  14. cp operator -(cp a,cp b)
  15. {
  16. return (cp){a.real-b.real,a.imag-b.imag};
  17. }

复数乘法:设$R_{a}$表示$a$的实部系数,$I_{a}$表示$a$的虚部系数

则$a*b$

$=(R_{a}+I_{a})*(R_{b}+I_{b})$

$=R_{a}*R_{b}+R_{a}*I_{b}+R_{b}*I_{a}+I_{a}*I_{b}$

因为$i^2=-1$

所以结果的实部为$R_{a}*R_{b}-I_{a}*I_{b}$

虚部为$R_{a}*I_{b}+R_{b}*I_{a}$

  1. cp operator *(cp a,cp b)
  2. {
  3.   return (cp){a.real*b.real-a.imag*b.imag,a.real*b.imag+a.imag*b.real};
  4. }
  5. double pi=acos(-1.0);
  6. int lim,rev[N],len;
  7. cp w[N],inv[N],a[N],b[N];
  8. void get_w()
  9. {
  10. for(int i=0;i<=lim;i++)
  11. {
  12. double angle=(double)i*2*pi/lim;
  13. w[i].imag=sin(angle);
  14. w[i].real=cos(angle);
  15. inv[i]=(cp){w[i].real,-w[i].imag};
  16. }
  17. }

fft参数解释

$arr:$系数数组,在$fft$后变为点值数组,$arr_{i}$表示将$w^i_n$带入多项式后求得的值

$w:$预处理好的w单位根,在$fft$的时候正常带入即可,在$idft$的时候带入单位根的倒数(具体参见$idft$)void fft(cp *arr,cp *w)

  1. {
  2. for(int i=0;i<lim;i++)
  3. {
  4. //处理每一个系数在分治过程中的实际位置;
  5. //if是因为只需交换一次,所以选择由小的一方来执行
  6. if(i<rev[i]) swap(arr[i],arr[rev[i]]);
  7. }
  8. for(int i=2;i<=lim;i*=2)//枚举区间长度
  9. {
  10. int l=i/2;
  11. for(int j=0;j<lim;j+=i)//枚举区间位置,这些区间是互不相交的
  12. {
  13. //枚举带入的单位根w(k,l),k>=l的单位根也可以在这里一并求出
  14. for(int k=0;k<l;k++)
  15. {

  

意义变更

在这里$arr$的意义从系数变为$w^k_i$的点值,$a_{j,j+i-1}$分别表示将$w^{0,i-1}_i$的点值

下面的的t相当于文首博客中的$w^k_n * A_2(w^k_{n \over 2})$

  1. cp t=arr[j+k+l]*w[lim/i*k];//w(k,i)=w(k/i,1)=w(n*k/i,n)
  2. arr[j+k+l]=arr[j+k]-t;
  3. arr[j+k]=arr[j+k]+t;
  4. }
  5. }
  6. }
  7. }
  8. int main()
  9. {
  10. int n,m;
  11. cin>>n>>m;
  12. for(int i=0;i<=n;i++) scanf("%lf",&a[i].real);
  13. for(int i=0;i<=m;i++) scanf("%lf",&b[i].real);
  14. lim=1;
  15. while(lim<=n+m) len++,lim<<=1;//这样会比用cmath的log要快?
  16. for(int i=0;i<lim;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(len-1));
  17. get_w();
  18. fft(a,w);
  19. fft(b,w);
  20. for(int i=0;i<=lim;i++) a[i]=a[i]*b[i];
  21. fft(a,inv);
  22. for(int i=0;i<=n+m;i++) printf("%d ", (int)(a[i].real/lim+0.5));
  23.   //除以lim的原因具体参见idft,0.5是为了四舍五入
  24. }

转载于:https://www.cnblogs.com/linzhuohang/p/11418932.html

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号