的做法有很多,这里说下官方题解的做法(可持久化边分树)。
考虑先把整棵树的边分树建出来,注意到边分树只有叶子节点有原树上的点,其他的点都是原树上的边。然后可持久化边分树,按照将所有点依次插入到对应版本的边分树。这样我们把两个版本差分一下就是一棵只包含了点集的边分树。
考虑如何查询。因为边分树上的一个节点就代表了边分树上对应节点的子树的信息,这里我们维护子树到这条边的距离。那么当我们查询就在遍历差分后的边分树上的所有祖先,然后用没有的那个子树的(就是那个子树的深度和),然后加上即可计算出路径经过这条边的贡献。
考虑如何修改。这里有一个重要的性质:只会交换排列相邻的两项。因为交换排列的项只会使第个版本的边分树多包含了修改后的的,少包含了修改后的(第个版本的边分树是只包含这些点的边分树)。所以我们对第个版本的边分树进行修改,删除新的,插入新的即可。
分析一波复杂度。因为每次边分,联通块大小至少变成原来的,所以树高至多为,因为边分树上每个点至多被树高个祖先计算到,所以初始化时间复杂度。又因为可持久化数据结构的插入和查询复杂度均为。所以插入和查询的总复杂度就是。因此时间复杂度是还有个常数。空间复杂度,初始化的阶段是的,可持久化的部分和插入复杂度相当,为。大概在31
左右,开到33
就很稳了。
/*
_|_| _| _| _|
_| _| _| _|_| _|_|_|_| _| _| _|
_| _| _|_| _| _| _|_|
_| _| _| _| _| _| _| _|
_|_| _| _|_|_|_| _|_| _| _|
*/
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<vector>
#include<queue>
#include<set>
//#define ls (rt<<1)
//#define rs (rt<<1|1)
#define vi vector<int>
#define pb push_back
#define mk make_pair
#define pii pair<int,int>
#define rep(i,a,b) for(int i=(a),i##end=(b);i<=i##end;i++)
#define fi first
#define se second
typedef long long ll;
const int maxn=2*2e5+3;
const int inf=1e9+1000;
const int lim=33;
using namespace std;
vi edge[maxn],edgew[maxn];
int p[maxn],head[maxn];
int n,q,cnt=1,bk;
struct gg{
int u,v,w,next;
}side[maxn*2];
void ins(int u,int v,int w){
//cout<<u<<' '<<v<<' '<<w<<endl;
side[++cnt]=(gg){u,v,w,head[u]};head[u]=cnt;
side[++cnt]=(gg){v,u,w,head[v]};head[v]=cnt;
}
void rebuild(int u,int fa){
int last=-1;
for(int i=0;i<edge[u].size();i++){
int v=edge[u][i];if(v==fa)continue;
if(last==-1)ins(u,v,edgew[u][i]),last=u;
else ins(last,++n,0),ins(n,v,edgew[u][i]),last=n;
rebuild(v,u);
}
}
int ls[lim*maxn],rs[lim*maxn],rt[maxn],idx_cnt,si,mysize[maxn],mi;
bool vis[maxn*2];
void get_size(int u,int fa){
mysize[u]=1;
for(int i=head[u];i;i=side[i].next){
int v=side[i].v;if(v==fa||vis[i])continue;
get_size(v,u);mysize[u]+=mysize[v];
}
}
void find_cen(int u,int fa,int n){
for(int i=head[u];i;i=side[i].next){
int v=side[i].v;if(v==fa||vis[i])continue;
find_cen(v,u,n);
if(abs((n-mysize[v])-mysize[v])<mi)
mi=abs((n-mysize[v])-mysize[v]),si=i;
}
}
int cen(int x){
get_size(x,0);
mi=inf;find_cen(x,0,mysize[x]);
if(mi==inf)return -1;
return si;
}
vi route[maxn];
vector<ll> fadis[maxn];
ll lsum[lim*maxn],rsum[lim*maxn];
int lsz[lim*maxn],rsz[lim*maxn],eval[lim*maxn];
void dfs(int u,int fa,int dir,ll dis,int floor){
route[u].pb(dir);fadis[u].pb(dis);
for(int i=head[u];i;i=side[i].next){
int v=side[i].v;if(v==fa||vis[i])continue;
dfs(v,u,dir,dis+side[i].w,floor);
}
}
void divide(int &u,int idx,int floor){
vis[idx]=1; vis[idx^1]=1;
if(!u)u=++idx_cnt;eval[u]=side[idx].w;
int v1=side[idx].u,v2=side[idx].v;
dfs(v1,0,0,0,floor);
dfs(v2,0,1,0,floor);
int l=cen(v1);
int r=cen(v2);
if(l!=-1)divide(ls[u],l,floor+1);
if(r!=-1)divide(rs[u],r,floor+1);
}
void insert(int &now,int last,int floor,int tar){
now=++idx_cnt;
ls[now]=ls[last],rs[now]=rs[last],lsz[now]=lsz[last],rsz[now]=rsz[last],lsum[now]=lsum[last],rsum[now]=rsum[last],eval[now]=eval[last];
if(route[tar].size()<floor)return;
int dir=route[tar][floor-1];
if(!dir)insert(ls[now],ls[last],floor+1,tar),
lsum[now]+=fadis[tar][floor-1],lsz[now]++;
else insert(rs[now],rs[last],floor+1,tar),rsum[now]+=fadis[tar][floor-1],rsz[now]++;
}
void init(){
rebuild(1,0);
divide(rt[0],cen(1),1);
rep(i,1,bk){
insert(rt[i],rt[i-1],1,p[i]);
}
}
int mod(ll x){
return x%(1<<30);
}
ll query(int now,int last,int tar,int floor){
if(route[tar].size()<floor)return 0;
int dir=route[tar][floor-1];
ll ans=0;
if(!dir)ans+=rsum[now]-rsum[last]+(1ll*(fadis[tar][floor-1]+eval[now])*(rsz[now]-rsz[last])),ans+=query(ls[now],ls[last],tar,floor+1);
else ans+=lsum[now]-lsum[last]+(1ll*(fadis[tar][floor-1]+eval[now])*(lsz[now]-lsz[last])),ans+=query(rs[now],rs[last],tar,floor+1);
return ans;
}
int main(){
// freopen("in","r",stdin);
scanf("%d%d",&n,&q);bk=n;
ll last_ans=0;
rep(i,1,n)scanf("%d",&p[i]);
rep(i,1,n-1){
int u,v,w;scanf("%d%d%d",&u,&v,&w);
edge[u].pb(v);edge[v].pb(u);edgew[u].pb(w);edgew[v].pb(w);
}
init();
rep(i,1,q){
int ty;scanf("%d",&ty);
if(ty==1){
int a,b,c;scanf("%d%d%d",&a,&b,&c);
int l=mod(last_ans)^a,r=mod(last_ans)^b,v=mod(last_ans)^c;
last_ans=query(rt[r],rt[l-1],v,1);
printf("%lld\n",last_ans);
}
else{
int a;scanf("%d",&a);
int x=mod(last_ans)^a;
swap(p[x],p[x+1]);
insert(rt[x],rt[x-1],1,p[x]);
}
}
return 0;
}