CF494D Birthday

先简述一下题意:定义f(u,v)f(u,v)为以1号点为根时,uu点到vv子树内的所有点的距离的平方和减去uu点到vv点子树外的所有点的平方和。多次询问,可以离线。

我们先试着考虑离线询问然后换根。同时维护以当前点为根,每个点到其子树内所有点的距离的和sum1sum1,距离的平方和sum2sum2以及子树大小sizesize

考虑如何维护。我们定义函数trans(u,val)trans(u,val)表示让sum2(u)sum2(u)从子树内每个点距离的平方和变成子树内每个点距离加上valval的平方和。根据初中数学知识就有trans(u,val)=sum2(u)+size×val2+2×val×sum1(u)trans(u,val)=sum2(u)+size\times val^2+2\times val\times sum1(u)。知道了这个以后,如何维护就很简单了,大家可以自己推推看。

然后考虑维护了这个东西之后如何求解答案。对于询问f(u,v)f(u,v)来说,当以1为根时,若uu不在vv子树内,那当uu为根时,vv换根前的子树一定等于换根后的子树。答案就是ans=trans(v,dis(u,v))(sum2(u)trans(v,dis(u,v)))ans=trans(v,dis(u,v))-(sum2(u)-trans(v,dis(u,v)))

然后我们考虑换根前,uuvv子树内的情况。令vv换根前的父亲为fafa,答案就是ans=(sum2(u)trans(fa,dis(fa,u)))trans(fa,dis(fa,u))ans=(sum2(u)-trans(fa,dis(fa,u)))-trans(fa,dis(fa,u))

不需要任何高级数据结构,只需要一个倍增即可。(卓为啥要把这道题放数据结构专题啊喂)

/*
                                                  
  _|_|                              _|  _|    _|  
_|    _|  _|  _|_|  _|_|_|_|        _|  _|  _|    
_|    _|  _|_|          _|          _|  _|_|      
_|    _|  _|          _|      _|    _|  _|  _|    
  _|_|    _|        _|_|_|_|    _|_|    _|    _|  
                                                                                                    
*/ 
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<vector>
#include<queue>
#include<set>
//#define ls (rt<<1)
//#define rs (rt<<1|1)
#define vi vector<int>
#define pb push_back
#define mk make_pair
#define pii pair<int,int>
#define rep(i,a,b) for(int i=(a),i##end=(b);i<=i##end;i++)
#define fi first
#define se second
typedef long long ll;
using namespace std;
const int maxn=1e5+100;
const int mod=1e9+7;
ll sum1[maxn],sum2[maxn],size[maxn];
int dfn[maxn],dfncnt,f[maxn][20],d[maxn][20],n,dep[maxn];
vi side[maxn],w[maxn];
int trans(int u,int val){
	return (sum2[u]+size[u]*val%mod*val%mod+2*val*sum1[u]%mod)%mod;
}
void dfs1(int u,int fa){
	f[u][0]=fa;dfn[u]=++dfncnt;size[u]=1;dep[u]=dep[fa]+1;
	for(int i=0;i<side[u].size();i++){
		int v=side[u][i];if(v==fa)continue;
		dfs1(v,u);size[u]+=size[v];d[v][0]=w[u][i];
		sum1[u]=(sum1[u]+sum1[v]+w[u][i]*size[v]%mod)%mod;
		sum2[u]=(sum2[u]+trans(v,w[u][i]))%mod;
	}
}
void init(){
	rep(i,1,19)rep(j,1,n)f[j][i]=f[f[j][i-1]][i-1];
	rep(i,1,19)rep(j,1,n)d[j][i]=(d[j][i-1]+d[f[j][i-1]][i-1])%mod;
}
int get_dis(int u,int v){
	ll ans=0;
	if(dep[u]<dep[v])swap(u,v);
	for(int i=19;i>=0;i--)if(dep[f[u][i]]>=dep[v])ans=(ans+d[u][i])%mod,u=f[u][i];
	if(u==v)return ans;
	for(int i=19;i>=0;i--)if(f[u][i]!=f[v][i])ans=(1ll*ans+d[u][i]+d[v][i])%mod,u=f[u][i],v=f[v][i];
	return (1ll*ans+d[u][0]+d[v][0])%mod;
}
struct Query{
	int u,v,idx;
	bool insubtree;
}q[maxn];
int ans[maxn];
bool cmp(Query a,Query b){return dfn[a.u]<dfn[b.u];}
int now=1,t;
void dfs2(int u,int fa){
	while(now<=t&&q[now].u==u){
		int v=q[now].v,idx=q[now].idx;
		if(q[now].insubtree){
			ans[idx]=(sum2[u]-2*trans(f[v][0],get_dis(f[v][0],u)))%mod;
			ans[idx]=(ans[idx]+mod)%mod;
		}
		else{
			ans[idx]=(2*trans(v,get_dis(u,v))-sum2[u])%mod;
			ans[idx]=(ans[idx]+mod)%mod;
		}
		now++;
	}
	for(int i=0;i<side[u].size();i++){
		int v=side[u][i];if(v==fa)continue;
		ll bku1=sum1[u],bku2=sum2[u],bku3=size[u];
		ll bkv1=sum1[v],bkv2=sum2[v],bkv3=size[v];
		int tmp1=trans(v,w[u][i]),tmp2=trans(u,w[u][i])-trans(v,2*w[u][i]%mod);tmp2=(tmp2%mod+mod)%mod;
		sum1[u]=(bku1-(bkv1+w[u][i]*bkv3)%mod)%mod;sum1[u]=(sum1[u]%mod+mod)%mod;
		sum1[v]=((bku1-(bkv1+w[u][i]*bkv3)%mod)+(bku3-bkv3)*w[u][i]%mod+bkv1)%mod;sum1[v]=(sum1[v]%mod+mod)%mod;
		
		sum2[u]=(bku2-tmp1)%mod;sum2[u]=(sum2[u]%mod+mod)%mod;
		sum2[v]=tmp2+bkv2;sum2[v]=(sum2[v]%mod+mod)%mod;
		
		size[u]=bku3-bkv3;size[v]=bku3;
		dfs2(v,u);
		sum1[u]=bku1,sum2[u]=bku2,size[u]=bku3;
		sum1[v]=bkv1,sum2[v]=bkv2,size[v]=bkv3;
	}
}
int main(){
	ios::sync_with_stdio(0);
	cin>>n;
	rep(i,1,n-1){
		int u,v,val;cin>>u>>v>>val;
		side[u].pb(v);w[u].pb(val);
		side[v].pb(u);w[v].pb(val);
	}
	dfs1(1,0);
	init();cin>>t;
	rep(i,1,t){
		cin>>q[i].u>>q[i].v;q[i].idx=i;
		if(dfn[q[i].v]<=dfn[q[i].u]&&dfn[q[i].u]<=dfn[q[i].v]+size[q[i].v]-1)q[i].insubtree=1;
	}
	sort(q+1,q+1+t,cmp);
	dfs2(1,0);
	rep(i,1,t){
		cout<<(ans[i]%mod+mod)%mod<<endl;
	}
	return 0;
}