Description
Sample Input
5
5 1 3 2 1
1 2
2 3
2 4
1 5
Sample Output
4 1
Hint
Solution
这题一眼不可做。。。想用随机骗分,结果不会写判断答案是否合法。
正解树形DP,设计状态,我们发现其实在以 为根的子树中,最多只能有一个节点由节点 控制,所以其实我们可以设计状态 表示当前节点 以他为根的子树中有 个节点没有被控制,需要当前节点控制。
其实已经被控制的点完全可以看成一个点,这样就可以发现最多只能有一个被当前节点控制,因为如果有两个要求被控制,那么他们的值是肯定不一样的,这样上面的节点只能整体加,而这两个节点没法自我调节。
转移:
当前这个儿子的子树中空一个节点让 控制
当前这个儿子的子树中空一个节点不控制
那么第一问就很好写了。
第二问要求我们统计方案,我们继续考虑用树形DP,设计状态 为让 的子树全部被控制或者一个点不被控制的方案数,那么根据乘法原理,我们要把所有儿子的方案数乘起来得到这个节点的方案。
但是每个子树的方案是不一样的,因为我们更新时是有 空出一个节点的。
那么我们还得每个儿子遍历一下,看看他是不是空出一个节点,然后加法原理再都加起来。
最后注意一下如果 是 的话,要加上所有的方案数
Code
- #include<bits/stdc++.h>
- using namespace std;
-
- typedef long long LL;
- int n,m,MaxDepth;
- const int N=2e5+5,P=998244353;
- const LL INF=1e18;
- LL f[N][2],cost[N],g[N][2];
- int dep[N];
- vector <int> E[N];
-
- inline char gc(){
- static char buf[1<<20],*p1=buf,*p2=buf;
- if(p1==p2){
- p2=(p1=buf)+fread(buf,1,1<<20,stdin);
- if(p1==p2) return EOF;
- }
- return *p1++;
- }
-
- inline int read(){
- int x=0;char ch=gc();
- while(ch<'0' || ch>'9') ch=gc();
- while(ch<='9' && ch>='0') x=x*10+ch-'0',ch=gc();
- return x;
- }
-
- inline LL power(LL a,int b){
- LL ret=1;
- while(b){
- if(b&1) ret=(ret*a)%P;
- a=(a*a)%P;
- b>>=1;
- }
- return ret;
- }
-
- inline void dfs(int u,int fa){
- dep[u]=dep[fa]+1;
- if(E[u].size()==1 && E[u][0]==fa){
- f[u][1]=0;
- f[u][0]=cost[u];
- g[u][0]=g[u][1]=1;
- MaxDepth=dep[u];
- return;
- }
- LL t=0,All=1;
- for(int i=0;i<(int)E[u].size();++i){
- int v=E[u][i];
- if(v==fa) continue;
- dfs(v,u);
- t+=f[v][0];
- All=(All*g[v][0])%P;
- }
- f[u][0]=t;
- f[u][1]=INF;
- for(int i=0;i<(int)E[u].size();++i){
- int v=E[u][i];
- if(v==fa) continue;
- f[u][0]=min(f[u][0],t-f[v][0]+f[v][1]+cost[u]);
- f[u][1]=min(f[u][1],t-f[v][0]+f[v][1]);
- }
- if(f[u][0]==t) g[u][0]=All%P;
- for(int i=0;i<(int)E[u].size();++i){
- int v=E[u][i];
- if(v==fa) continue;
- LL ret=All*power(g[v][0],P-2)%P*g[v][1]%P;
- if(f[u][0]==t-f[v][0]+f[v][1]+cost[u]) g[u][0]=(g[u][0]+ret)%P;
- if(f[u][1]==t-f[v][0]+f[v][1]) g[u][1]=(g[u][1]+ret)%P;
- }
- }
-
- int main(){
- n=read();
- for(int i=1;i<=n;++i) cost[i]=read();
- for(int i=1;i<n;++i){
- int u=read(),v=read();
- E[u].push_back(v);
- E[v].push_back(u);
- }
- dfs(1,1);
- if(MaxDepth!=n) printf("%lld %lld\n",f[1][0],g[1][0]);
- else {
- int cnt=0;
- for(int i=1;i<=n;++i) cnt=(cnt+(cost[i]==f[1][0]))%P;
- printf("%lld %d\n",f[1][0],cnt);
- }
- return 0;
- }