赞
踩
题目大意:一个长为 n 的数字串,n 一定是偶数,其中只包括 k 种数位,每种数位的范围是( 0 − 9 0 - 9 0−9),允许这个数字串有前导零,如果这个数字串的前半串的数位和与后半串的数位和相等,那么称这个数字串为幸运串,在给定k种数位的情况下,问有多少种长度为 n 的 幸运串。
分析:做的时候基本上想到了正解,首先看怎么得到答案:数字串的前半串和后半串是独立的,若已知前半串长为
n
2
\frac{n}{2}
2n的数字串数字和为 x 的所有方案,记为
d
p
n
2
[
x
]
dp_\frac{n}{2}[x]
dp2n[x],那么显然答案为
∑
i
=
0
m
a
x
(
d
p
n
2
[
i
]
)
2
\quad\sum_{i = 0}^{max} (dp_\frac{n}{2}[i])^2
∑i=0max(dp2n[i])2。关键是如何计算得到
d
p
n
2
[
x
]
dp_\frac{n}{2}[x]
dp2n[x]。
显然
d
p
1
[
x
]
=
1
[
v
i
s
[
x
]
=
1
]
dp_1[x] = 1[vis[x] = 1]
dp1[x]=1[vis[x]=1],vis[x] = 1表示 x 是一个可用的数位,
d
p
2
[
x
]
dp_2[x]
dp2[x] 可以由
d
p
1
[
x
]
dp_1[x]
dp1[x]与自己做卷积得到。同理
d
p
i
[
x
]
dp_i[x]
dpi[x] 可以由
d
p
i
−
1
[
x
]
dp_{i - 1}[x]
dpi−1[x] 与
d
p
i
[
x
]
dp_i[x]
dpi[x] 做卷积得到,不过这个过程太慢,注意到
d
p
i
[
x
]
dp_i[x]
dpi[x] 可以由
d
p
i
2
[
x
]
dp_{\frac{i}{2}}[x]
dp2i[x]与自己做卷积得到,于是可以得到一个分治策略(其实就是快速幂,没有反应过来显得很傻逼),可以做
l
o
g
n
logn
logn次卷积得到最后的答案(然后这里犯了一个另外一个很傻逼的错误:复杂度分析成
O
(
n
∗
l
o
g
n
∗
l
o
g
n
)
O(n*logn*logn)
O(n∗logn∗logn))。
最后大致过程就是:设 d p [ i ] [ j ] dp[i][j] dp[i][j]表示长为 i 的数字串,和为 j 的方案数,转移方程: d p [ i ] [ k ] = ∑ j = 0 m a x d p [ i − 1 ] [ j ] ∗ d p [ 1 ] [ k − j ] dp[i][k] = \sum_{j = 0}^{max} dp[i - 1][j] * dp[1][k - j] dp[i][k]=j=0∑maxdp[i−1][j]∗dp[1][k−j]首先这个过程是卷积,可以用FFT加速。然后又分析得到 d p [ i ] [ k ] = ∑ j = 0 m a x d p [ i 2 ] [ j ] ∗ d p [ i 2 ] [ k − j ] dp[i][k] = \sum_{j = 0}^{max} dp[\frac{i}{2}][j] * dp[\frac{i}{2}][k - j] dp[i][k]=j=0∑maxdp[2i][j]∗dp[2i][k−j]这个过程可以用快速幂加速。最后总复杂度为 n ∗ l o g n n* logn n∗logn
实际的复杂度为 n ∗ l o g n n*logn n∗logn,FFT加速计算卷积的原理是将表达式转成点值表达,然后两个点值表达的式子在对应点值做简单乘积就可以得到两个多项式相乘的点值表达,所以要做 log次卷积,只需要一次正变换,log次点值相乘,一次逆变换即可,复杂度还是 n l o g n nlogn nlogn(傻逼的以为每次都要正变换然后相乘然后再逆变换回来接着下一次)
在快速幂这里有两种方式:一是直接对每个点值进行快速幂,在逆变换后得到的数组仍然是正确答案,另外一种就是做log次点值相乘操作,需要一个额外数组。
由于有模数,要用NTT
代码:
#include<bits/stdc++.h> using namespace std; const int mod = 998244353; typedef long long ll; const int maxn = 4e6 + 10; int n,k; int vis[10]; ll a[maxn],b[maxn]; ll fpow(ll a,ll b) { ll r = 1; while(b) { if(b & 1) r = r * a % mod; a = a * a % mod; b >>= 1; } return r; } void change(ll t[],int len) { for(int i = 1, j = len / 2; i < len - 1; i++) { if(i < j) swap(t[i],t[j]); int k = len / 2; while(j >= k) { j -= k; k /= 2; } if(j < k) j += k; } } void NTT(ll t[],int len,int type) { change(t,len); for(int s = 2; s <= len; s <<= 1) { ll wn = fpow(3,(mod - 1) / s); if(type == -1) wn = fpow(wn,mod - 2); for(int j = 0; j < len; j += s) { ll w = 1; for(int k = 0; k < s / 2; k++) { ll u = t[j + k],v = t[j + k + s / 2] * w % mod; t[j + k] = (u + v) % mod; t[j + k + s / 2] = (u - v + mod) % mod; w = w * wn % mod; } } } if(type == -1) { ll inv = fpow(len,mod - 2); for(int i = 0; i < len; i++) t[i] = t[i] * inv % mod; } } int main() { memset(vis,0,sizeof vis); memset(b,0,sizeof b); memset(a,0,sizeof a); scanf("%d%d",&n,&k);int x,mx = -1; for(int i = 1; i <= k; i++) { scanf("%d",&x); vis[x] = 1; mx = max(mx,x); } n /= 2; mx *= n; int len = 1; while(len <= (mx << 1)) len <<= 1; for(int i = 0; i < len; i++) if(i < 10 && vis[i]) a[i] = 1; b[0] = 1;NTT(a,len,1);NTT(b,len,1); while(n) { if(n & 1) for(int i = 0; i < len; i++) b[i] = b[i] * a[i] % mod; for(int i = 0; i < len; i++) a[i] = a[i] * a[i] % mod; n >>= 1; } NTT(b,len,-1); ll ans = 0; for(int i = 0; i < len; i++) { ans += b[i] * b[i] % mod; ans %= mod; } printf("%lld\n",ans); return 0; }
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。