一道虚树板子题。
虚树的作用就是:有的时候解决一些树上问题时,整棵树大小很大,但是所需要的“关键点”大小很小。为了节省时间 对有个关键点的虚树进行处理,能使得整棵树的大小变成。
虚树在这里:OIWIKI-虚树有很详细的介绍。
其实就是利用栈维护树上的一条链。然后考虑栈顶节点和下一个关键点的关系,如果等于栈顶(在同一条链)就直接入栈。否则分类讨论栈中的节点和的关系:
- 这代表和栈顶节点构成的链在树的“分叉点”下面的情况(为分叉点)。所以直接加入和栈顶之间的边,然后弹栈顶。注意 这一步显然需要用
while
执行多次。 - 这代表和栈顶节点构成的链(不包含两端,因为这里,前面又判断了)上的一个节点就是。所以没入栈,因此添加栈顶和的边,栈顶出栈,入栈即可。
- 这代表。因此加入边,然后出栈即可。
这里有个实现的小技巧:在每次入栈一个新的节点时进行临接表初始化即可。
最后不要忘了将栈中剩余的节点代表的链加进去。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<vector>
#include<queue>
#define vi vector<int>
#define pb push_back
#define mk make_pair
#define pii pair<int,int>
#define rep(i,a,b) for(int i=(a),i##end=(b);i<=i##end;i++)
#define fi first
#define se second
typedef long long ll;
using namespace std;
const int maxn=100500;
vi side[maxn],vit[maxn],kp;
int dfn[maxn],f[maxn][23],dep[maxn],dfn_cnt,stk[maxn],top;
void dfs(int u,int fa){
f[u][0]=fa;dep[u]=dep[fa]+1;dfn[u]=++dfn_cnt;
rep(i,1,22)f[u][i]=f[f[u][i-1]][i-1];
rep(i,0,(int)side[u].size()-1){
int v=side[u][i];if(v==fa)continue;
dfs(v,u);
}
}
bool iskp[maxn];
bool cmp(int a,int b){return dfn[a]<dfn[b];}
void roll_back(int k){
rep(i,0,k-1)iskp[kp[i]]=0;
}
void in(int x){
vit[x].clear();stk[++top]=x;
}
int k;
int c[maxn];
ll dp(int u,int fa){
ll tot=0,ans=0;
rep(i,0,(int)vit[u].size()-1){
int v=vit[u][i];if(v==fa)continue;
ans+=dp(v,u);tot+=c[v];
}
if(iskp[u]){
c[u]=1;ans+=tot;
}
else if(tot>1)c[u]=0,ans++;
else c[u]=tot;
return ans;
}
int lca(int u,int v){
if(dep[u]<dep[v])swap(u,v);
for(int i=22;i>=0;i--)if(dep[f[u][i]]>=dep[v])u=f[u][i];
if(u==v)return u;
for(int i=22;i>=0;i--)if(f[u][i]!=f[v][i])u=f[u][i],v=f[v][i];
return f[u][0];
}
int main(){
int n;scanf("%d",&n);
rep(i,1,n-1){
int u,v;scanf("%d%d",&u,&v);side[u].pb(v);side[v].pb(u);
}
dfs(1,0);
int q;scanf("%d",&q);
while(q--){
scanf("%d",&k);kp.clear();
rep(i,1,k){
int u;scanf("%d",&u);kp.pb(u);iskp[u]=1;
}
bool flag=1;
rep(i,0,k-1){
if(iskp[f[kp[i]][0]]){
printf("-1\n");roll_back(k);flag=0;break;
}
}
if(!flag)continue;
sort(kp.begin(),kp.end(),cmp);
stk[top=1]=1;vit[1].clear();
rep(i,0,k-1){
int x=kp[i];
if(x==1)continue;
int l=lca(x,stk[top]);
if(l==stk[top]){in(x);continue;}
while(dfn[l]<dfn[stk[top-1]]){vit[stk[top-1]].pb(stk[top]);vit[stk[top]].pb(stk[top-1]);top--;}
if(dfn[l]>dfn[stk[top-1]]){
vit[l].clear();vit[l].pb(stk[top]);vit[stk[top]].pb(l);stk[top]=l;
}
else {vit[stk[top]].pb(l);vit[l].pb(stk[top]);top--;}
in(x);
}
rep(i,1,top-1){vit[stk[i]].pb(stk[i+1]);vit[stk[i+1]].pb(stk[i]);}
printf("%lld\n",dp(1,0));
roll_back(k);
}
}