一个解法→
一句话题意:给出1~n的序列,一个组的定义是1或2个相邻的数字,求每个数字最多属于1个组、共1~k个组分别的答案,对998244353取膜。//我语文差你来打我啊
有一个SB的DP算法:dp[i][j]=dp[i-1][j]+dp[i-1][j-1]+dp[i-2][j-1]。其中dp[i][j]表示前i个j组的方案。//lych:这还能不用FFT(NTT)哒
把dp[i]看做多项式,dp[a+b]=dp[a]*dp[b]+dp[a-1]*dp[b-1]*x(雾
这个可以递归求解,即|a-b|<=1,和快速幂一样达到O(logn)。
每层维护2个多项式,dp[i],dp[i+1],手算dp[i+2],就能得到dp[i*2],dp[i*2+1]。没了。
复杂度O(k log k log n),反正能过。对于FFT的理解之后几天再分析吧。。
#include<iostream> #include<cstdio> #include<algorithm> #define ll long long #define P 998244353 #define N 131073 using namespace std; int n,k,tn,tl,r[N],w[2][N],rn; int a[N],b[N],c[N],d[N],e[N],f[N],d1[N],e1[N]; int pow(int a,int b,int c) { int ans=1; for (;b;a=(ll)a*a%c,b>>=1) if (b&1) ans=(ll)ans*a%c; return ans; } void pre(int x) { tl=0;tn=1; while(tn<x)tn<<=1,tl++; tn<<=1;tl++; int W=pow(3,(P-1)/tn,P); w[0][0]=w[1][0]=1; for (int i=1;i<tn;i++) w[0][i]=(ll)w[0][i-1]*W%P; for (int i=1;i<tn;i++) w[1][i]=w[0][tn-i]; for (int i=1;i<tn;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(tl-1)); rn=pow(tn,P-2,P); } void dft(int *a,int f) { for (int i=0;i<tn;i++) if(i<r[i]) swap(a[i],a[r[i]]); for(int i=1;i<tn;i<<=1) for(int j=0,t=tn/(i<<1);j<tn;j+=i<<1) for(int k=0,l=0;k<i;k++,l+=t) { int x=(ll)w[f][l]*a[j+k+i]%P; int y=a[j+k]; a[j+k]=(y+x)%P; a[j+k+i]=(y+P-x)%P; } if(f) for (int i=0;i<tn;i++) a[i]=(ll)a[i]*rn%P; } void Get(int a[],int b[],int c[]) { c[0]=1; for (int i=1;i<=k;i++) c[i]=((ll)b[i]+b[i-1]+a[i-1])%P; for (int i=k+1;i<tn;i++) c[i]=0; } void work(int x) { if (x==0) { a[0]=1;for (int i=1;i<tn;i++)a[i]=0; b[0]=1;b[1]=1;for (int i=2;i<tn;i++)b[i]=0; return; } if (x==1) { a[0]=1;a[1]=1;for (int i=2;i<tn;i++)a[i]=0; b[0]=1;b[1]=3;b[2]=1;for (int i=3;i<tn;i++)b[i]=0; return; } work(x/2-1); Get(a,b,c); for (int i=k+1;i<tn;i++) a[i]=b[i]=c[i]=0; dft(a,0); dft(b,0); dft(c,0); for (int i=0;i<tn;i++) { d[i]=(ll)b[i]*b[i]%P; d1[i]=(ll)a[i]*a[i]%P; e[i]=(ll)b[i]*c[i]%P; e1[i]=(ll)a[i]*b[i]%P; } dft(d,1); dft(e,1); dft(d1,1); dft(e1,1); for (int i=1;i<tn;i++) d[i]=(d[i]+d1[i-1])%P,e[i]=(e[i]+e1[i-1])%P; if (x&1) { Get(d,e,f); for (int i=0;i<tn;i++) { a[i]=i<=k?e[i]:0; b[i]=i<=k?f[i]:0; } } else { for (int i=0;i<tn;i++) { a[i]=i<=k?d[i]:0; b[i]=i<=k?e[i]:0; } } } int main() { scanf("%d%d",&n,&k); pre(max(k+1,4)); work(n); for (int i=1;i<=k;i++) printf("%d ",(a[i]+P)%P); puts(""); }