一开始所有人都以为,答案就是统计每种颜色所有点之间的路径上有多少种不同的颜色。后来才发现这是假的。因为如果两种颜色路径上有其他颜色,那么不仅仅是这个点,颜色的所有点都要连入联通块。
对于这种“选了一个颜色,就必须选其他颜色”的问题,可以想到把关系连边之后用强连通分量解决。我们把得到的关系图用tarjan缩成DAG。然后没有出度的那些点里选最小的就好了(因为选又出度的必定不优秀)。
那么问题来了,如何在所有颜色之间依靠关系连边呢?一种方法是建虚树,然后把虚树上每条边代表的原树上的链用倍增优化连边(或者直接树剖也行?)连一下关系即可。但是这样搞有些麻烦,我们发现这张图连一些重边也不会有啥影响,所以我们每次可以把要加入虚树的所有点和这些点的的LCA之间的链连一条边,这样就很方便了。
有点细节。
/*
_|_| _| _| _|
_| _| _| _|_| _|_|_|_| _| _| _|
_| _| _|_| _| _| _|_|
_| _| _| _| _| _| _| _|
_|_| _| _|_|_|_| _|_| _| _|
*/
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<vector>
#include<queue>
#include<set>
//#define ls (rt<<1)
//#define rs (rt<<1|1)
#define vi vector<int>
#define pb push_back
#define mk make_pair
#define pii pair<int,int>
#define rep(i,a,b) for(int i=(a),i##end=(b);i<=i##end;i++)
#define fi first
#define se second
typedef long long ll;
using namespace std;
const int maxn=(2e5+10);
const int po=maxn*20;
int f[maxn][20],vpointid[maxn][20],cnt;
vi edge[maxn],side[po];
int color[maxn],dep[maxn],dfn[po],dfn_cnt;
int ans=1e9;
void dfs1(int u,int fa){
dep[u]=dep[fa]+1;dfn[u]=++dfn_cnt;
f[u][0]=fa;rep(i,1,18)f[u][i]=f[f[u][i-1]][i-1];
vpointid[u][0]=++cnt;if(fa)side[cnt].pb(color[fa]);
rep(i,1,18){
vpointid[u][i]=++cnt;
if(vpointid[u][i-1])side[cnt].pb(vpointid[u][i-1]);
if(vpointid[f[u][i-1]][i-1])side[cnt].pb(vpointid[f[u][i-1]][i-1]);
}
rep(i,0,edge[u].size()-1){
int v=edge[u][i];if(v==fa)continue;
dfs1(v,u);
}
}
int lca(int u,int v){
if(dep[u]<dep[v])swap(u,v);
for(int i=18;i>=0;i--)if(dep[f[u][i]]>=dep[v])u=f[u][i];
if(u==v)return u;
for(int i=18;i>=0;i--)if(f[u][i]!=f[v][i])u=f[u][i],v=f[v][i];
return f[u][0];
}
int stk[po];
vi town[maxn],tmp;
bool ontree[po];
bool cmp1(int a,int b){return dfn[a]<dfn[b];}
void add_chain(int u,int fa,int c){
for(int i=18;i>=0;i--){
if(dep[f[u][i]]>=dep[fa]){
side[c].pb(vpointid[u][i]);
u=f[u][i];
}
}
}
int low[po],scc[po],scc_cnt,top;
int n,k;
vi buc;
void tarjan(int u){
dfn[u]=low[u]=++dfn_cnt;stk[++top]=u;
rep(i,0,(int)(side[u].size())-1){
int v=side[u][i];
if(!dfn[v])tarjan(v),low[u]=min(low[u],low[v]);
else if(!scc[v])low[u]=min(low[u],dfn[v]);
}
if(dfn[u]==low[u]){
buc.clear();
scc_cnt++;int x=0;int sz=0;
do{
x=stk[top--];
scc[x]=scc_cnt;
if(x<=k)sz++;
buc.pb(x);
}while(x!=u);bool flag=0;
rep(j,0,buc.size()-1)if(!flag){
int idx=buc[j];
rep(k,0,(int)(side[idx].size())-1){
int v=side[idx][k];
if(scc[v]!=scc_cnt){
flag=1;break;
}
}
}
if(!flag)ans=min(ans,sz);
}
}
int main(){
scanf("%d%d",&n,&k);cnt=k;
rep(i,1,n-1){
int u,v;scanf("%d%d",&u,&v);
edge[u].pb(v);edge[v].pb(u);
}
rep(i,1,n)scanf("%d",&color[i]),town[color[i]].pb(i);
dfs1(1,0);
rep(i,1,k){
tmp=town[i];int sz=tmp.size();
int l=tmp[0];
rep(j,1,sz-1){
l=lca(tmp[j],l);
}
rep(j,0,(int)(tmp.size())-1){
add_chain(tmp[j],l,i);
}
}
memset(dfn,0,sizeof(dfn));dfn_cnt=0;top=0;
rep(i,1,k)if(!dfn[i])tarjan(i);
cout<<ans-1;
return 0;
}