题目: Codeforces Gym 103470 H. Crystalfly (2021 ICPC 亚洲区域赛南京站)
题意: 给定一棵 n 个点的树,点 i 上有 a[i] 只蝴蝶。从点 1 出发(时刻 0),每走一条边花费 1 秒;到达点 i 时捕获其上的全部蝴蝶,同时惊动 i 的所有相邻点 j:若在 j 被惊动后的 t[j] 秒内(含第 t[j] 秒)到达 j,仍可捕获其上的蝴蝶,否则它们消散 (1 ≤ t[j] ≤ 3)。求最多可以捕获多少只蝴蝶。
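思路: 以 1 为根做树形 DP。设 f[u] 为到达 u 时 a[u] 仍在、u 子树内能获得的最大值;g[u] = a[u] + Σ(f[c] − a[c]) 为到达 u 后各儿子 c 的蝴蝶都已消散、只能继续走各儿子子树时的收益。转移: f[u] = max( g[u], g[u] + a[c], g[u] + a[c] + max_{k≠c}(g[k] − (f[k] − a[k])) [仅当 t[c] = 3] ),最后一种表示先进入儿子 k 再返回 u,在第 3 秒到达 c(代价是 k 的儿子全部消散)。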
时间复杂度: 每组数据 O(n)。代码:
// Problem: H. Crystalfly
// Contest: Codeforces - The 2021 ICPC Asia Nanjing Regional Contest
//          (XXII Open Cup, Grand Prix of Nanjing)
// URL: https://codeforces.com/gym/103470/problem/H
// Memory Limit: 256 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <vector>

using ll = long long;

// Reads one integer from stdin into x using getchar().
// A '-' immediately before the digits makes the value negative.
// Stops cleanly at EOF (x keeps whatever digits were read, 0 if none)
// instead of spinning forever. `ch` is an int so that passing it to
// std::isdigit is always well-defined.
template <typename T>
void read(T &x) {
    x = 0;
    bool negative = false;
    int ch = std::getchar();
    while (ch != EOF && !std::isdigit(ch)) {
        negative = (ch == '-');
        ch = std::getchar();
    }
    while (ch != EOF && std::isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = std::getchar();
    }
    if (negative) x = -x;
}

// Solves one test case: reads n, a[1..n], t[1..n] and n-1 edges, then prints
// the maximum number of crystalflies catchable when starting at node 1.
//
// With the tree rooted at 1 and c ranging over the children of u:
//   f[u] = best total from u's subtree when we arrive at u while a[u] is
//          still there (a[u] included).
//   g[u] = a[u] + sum_c (f[c] - a[c]): catch a[u], then explore every child
//          subtree after that child's own flies have already vanished.
// Transitions:
//   f[u] = g[u]                          (every a[c] is lost)
//   f[u] = g[u] + a[c]                   (walk straight into child c)
//   f[u] = g[u] + a[c] + gain(k), if t[c] == 3 and k != c
//          (step into k, come back, reach c at time 3;
//           gain(k) = g[k] - (f[k] - a[k]) since k's children are lost).
// Nodes are processed in reverse BFS order rather than recursively, so a
// path-shaped tree with ~1e5 nodes cannot overflow the call stack.
void solve() {
    int n = 0;
    read(n);
    if (n < 1) {  // defensive: no tree, nothing to catch
        std::puts("0");
        return;
    }

    std::vector<ll> a(n + 1), f(n + 1), g(n + 1);
    std::vector<int> t(n + 1);
    std::vector<std::vector<int>> adj(n + 1);
    for (int i = 1; i <= n; ++i) read(a[i]);
    for (int i = 1; i <= n; ++i) read(t[i]);
    for (int i = 0; i < n - 1; ++i) {
        int x = 0;
        int y = 0;
        read(x);
        read(y);
        adj[x].push_back(y);
        adj[y].push_back(x);
    }

    // BFS from the root; parent[1] = 0 is not a valid node id.
    std::vector<int> parent(n + 1, 0);
    std::vector<int> order;
    order.reserve(n);
    order.push_back(1);
    for (std::size_t head = 0; head < order.size(); ++head) {
        const int u = order[head];
        for (const int v : adj[u]) {
            if (v != parent[u]) {
                parent[v] = u;
                order.push_back(v);
            }
        }
    }

    // Reverse BFS order guarantees every child is finished before its parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int u = *it;

        // First pass: g[u] and the two largest detour gains among children.
        // Gains are floored at 0: a detour worth <= 0 is never better than
        // the plain "walk straight into c" option.
        ll base = a[u];
        ll best1 = 0;
        ll best2 = 0;
        int best1_child = 0;  // 0 = no child has a positive gain
        for (const int v : adj[u]) {
            if (v == parent[u]) continue;
            base += f[v] - a[v];
            const ll gain = g[v] - (f[v] - a[v]);
            if (gain > best1) {
                best2 = best1;
                best1 = gain;
                best1_child = v;
            } else if (gain > best2) {
                best2 = gain;
            }
        }
        g[u] = base;

        // Second pass: choose which child to reach while its flies remain.
        ll best = base;
        for (const int v : adj[u]) {
            if (v == parent[u]) continue;
            best = std::max(best, base + a[v]);
            if (t[v] == 3) {
                // The detour child must differ from v itself.
                const ll detour = (v == best1_child) ? best2 : best1;
                best = std::max(best, base + a[v] + detour);
            }
        }
        f[u] = best;
    }

    std::printf("%lld\n", f[1]);
}

int main() {
    int T = 0;
    read(T);
    while (T-- > 0) solve();
    return 0;
}
