题目大意:
一棵树初始只有一个编号为$1$的权值为$w_1$的根。$q(q\le2\times10^5)$次操作,每次可以给出$v,w(w<10^9)$,新建一个结点作为$v$的子结点,权值为$w$;或者给出$u$,求出$f(u)$。定义$f(u)=|S|\cdot\sum_{d\in S}d$,其中$S$为$w_u$与其所有子结点$v$的$f(v)$构成的可重集合。
思路:
首先将树全部建好,用线段树维护DFS序。每次新加结点相当于将这一结点的权值由$0$变为$w$,用线段树计算加上这个点以后的贡献。题目就变成了线段树上单点加、区间乘、区间求和问题。时间复杂度$O(q(\log n+\log w))$,其中$O(\log w)$是求逆元的复杂度。
1 #include<cstdio> 2 #include<cctype> 3 #include<forward_list> 4 using int64=long long; 5 inline int getint() { 6 register char ch; 7 while(!isdigit(ch=getchar())); 8 register int x=ch^'0'; 9 while(isdigit(ch=getchar())) x=(((x<<2)+x)<<1)+(ch^'0'); 10 return x; 11 } 12 constexpr int N=2e5+1,Q=2e5,mod=1e9+7; 13 int w[N],n,dfn[N],par[N],size[N],s[N]; 14 std::forward_list<int> e[N]; 15 inline void add_edge(const int &u,const int &v) { 16 e[u].push_front(v); 17 } 18 struct Operation { 19 int type,v,u; 20 }; 21 Operation o[N]; 22 void exgcd(const int &a,const int &b,int &x,int &y) { 23 if(!b) { 24 x=1,y=0; 25 return; 26 } 27 exgcd(b,a%b,y,x); 28 y-=a/b*x; 29 } 30 int inv(const int &x) { 31 int ret,tmp; 32 exgcd(x,mod,ret,tmp); 33 return (ret%mod+mod)%mod; 34 } 35 void dfs(const int &x) { 36 size[x]=1; 37 dfn[x]=++dfn[0]; 38 for(auto &y:e[x]) { 39 dfs(y); 40 size[x]+=size[y]; 41 } 42 } 43 class SegmentTree { 44 #define _left <<1 45 #define _right <<1|1 46 private: 47 int val[N<<2],tag[N<<2]; 48 void push_up(const int &p) { 49 val[p]=(val[p _left]+val[p _right])%mod; 50 } 51 void push_down(const int &p) { 52 if(tag[p]==1) return; 53 tag[p _left]=(int64)tag[p _left]*tag[p]%mod; 54 tag[p _right]=(int64)tag[p _right]*tag[p]%mod; 55 val[p _left]=(int64)val[p _left]*tag[p]%mod; 56 val[p _right]=(int64)val[p _right]*tag[p]%mod; 57 tag[p]=1; 58 } 59 public: 60 void build(const int &p,const int &b,const int &e) { 61 tag[p]=1; 62 if(b==e) return; 63 const int mid=(b+e)>>1; 64 build(p _left,b,mid); 65 build(p _right,mid+1,e); 66 } 67 void add(const int &p,const int &b,const int &e,const int &x,const int &v) { 68 (val[p]+=v)%=mod; 69 if(b==e) return; 70 push_down(p); 71 const int mid=(b+e)>>1; 72 if(x<=mid) add(p _left,b,mid,x,v); 73 if(x>mid) add(p _right,mid+1,e,x,v); 74 } 75 void mul(const int &p,const int &b,const int &e,const int &l,const int &r,const int &v) { 76 if(b==l&&e==r) { 77 val[p]=(int64)val[p]*v%mod; 78 tag[p]=(int64)tag[p]*v%mod; 79 return; 80 } 81 push_down(p); 82 const int mid=(b+e)>>1; 83 if(l<=mid) mul(p _left,b,mid,l,std::min(mid,r),v); 84 if(r>mid) mul(p _right,mid+1,e,std::max(mid+1,l),r,v); 85 push_up(p); 86 } 87 int query(const int &p,const int &b,const int &e,const int &l,const int &r) { 88 if(b==l&&e==r) { 89 return val[p]; 90 } 91 push_down(p); 92 int ret=0; 93 const int mid=(b+e)>>1; 94 if(l<=mid) (ret+=query(p _left,b,mid,l,std::min(mid,r)))%=mod; 95 if(r>mid) (ret+=query(p _right,mid+1,e,std::max(mid+1,l),r))%=mod; 96 return ret; 97 } 98 #undef _left 99 #undef _right 100 }; 101 SegmentTree t; 102 int main() { 103 w[n=1]=getint(); 104 const int q=getint(); 105 for(register int i=0;i<q;i++) { 106 o[i].type=getint(); 107 if(o[i].type==1) { 108 o[i].v=par[++n]=getint(); 109 add_edge(par[n],n); 110 w[o[i].u=n]=getint(); 111 } 112 if(o[i].type==2) { 113 o[i].v=getint(); 114 } 115 } 116 dfs(1); 117 s[1]=1; 118 t.build(1,1,n); 119 t.add(1,1,n,1,w[1]); 120 for(register int i=0;i<q;i++) { 121 const int &y=o[i].u,&x=o[i].v; 122 if(o[i].type==1) { 123 t.add(1,1,n,dfn[y],(int64)t.query(1,1,n,dfn[x],dfn[x])*inv(w[x])%mod*w[y]%mod); 124 t.mul(1,1,n,dfn[x],dfn[x]+size[x]-1,(int64)inv(s[x])*(s[x]+1)%mod); 125 s[x]+=s[y]=1; 126 } 127 if(o[i].type==2) { 128 const int ans=t.query(1,1,n,dfn[x],dfn[x]+size[x]-1); 129 if(x==1) { 130 printf("%d\n",ans); 131 continue; 132 } 133 printf("%d\n",int((int64)ans*w[par[x]]%mod*inv(t.query(1,1,n,dfn[par[x]],dfn[par[x]]))%mod)); 134 } 135 } 136 return 0; 137 }