[CSP-S 2023] 種樹

pipipipipi43發表於2024-06-27

#include<bits/stdc++.h>
#define ll long long
#define pb push_back
#define mxn 100003
#define rep(i,a,b) for(int i=a;i<=b;++i)
#define rept(i,a,b) for(int i=a;i<b;++i)
using namespace std;

int n;               // number of tree nodes
int p[mxn];          // latest admissible planting day per node (filled by check)
int d[mxn];          // not referenced anywhere else in this listing
int ct[mxn];         // per-day counters, turned into prefix sums by check
ll a[mxn];           // target total growth per node
ll b[mxn];           // base term of the daily growth b + c*x
ll c[mxn];           // slope term of the daily growth b + c*x
vector<int> g[mxn];  // adjacency lists of the tree
bool v[mxn];         // not referenced anywhere else in this listing

// Total growth accumulated over days i..n (inclusive) when the growth on day x
// is max(b + a*x, 1). check() calls it as get(day, mx, c[u], b[u]), so inside
// this function `a` is the slope and `b` the base. The parameters `n`, `a`, `b`
// shadow the globals of the same name.
// The non-negative-slope branch applies no clamp, which is exact only when
// b + a*i >= 1 (always true when b >= 1).
inline __int128 get(ll i, ll n, __int128 a, __int128 b)
{
    const long long days = n - i + 1;  // number of days in [i, n]
    if (a >= 0)
        return (n + i) * days / 2 * a + b * days;

    // Negative slope: `cut` is the first day on which b + a*x falls below 1,
    // capped at n + 1 (the range never reaches the floor).
    const __int128 firstFlat = (b - a - 1) / (-a);
    const __int128 cap = static_cast<__int128>(n) + 1;
    const long long cut = firstFlat < cap ? static_cast<long long>(firstFlat)
                                          : static_cast<long long>(cap);
    if (cut <= i)
        return days;  // every day in [i, n] contributes exactly 1

    // Days i..cut-1 grow linearly; days cut..n contribute 1 each.
    const long long linearDays = cut - i;
    return (n - cut + 1) + linearDays * b + (cut - 1 + i) * linearDays / 2 * a;
}
// Propagates deadlines upward. A node must be planted at least one day before
// any of its children, so after x's subtree is processed,
// p[x] <= p[child] - 1 holds for every child. `fa` is the DFS parent, skipped
// when walking the undirected adjacency list g[x].
void dfs(int x,int fa)
{
    for (const int child : g[x])
    {
        if (child == fa)
            continue;
        dfs(child, x);
        const int bound = p[child] - 1;
        if (bound < p[x])
            p[x] = bound;
    }
}
bool check(int mx)
{
    rep(i,1,n)
    //算出每棵樹,最晚應該在哪個時間點來種 
	{
        if(a[i]>get(1,mx,c[i],b[i]))
		    return 0;
        int l=1,r=n;
        while(l<r)
		{
            int mid=(l+r+1)>>1;
            if(a[i]<=get(mid,mx,c[i],b[i]))
			    l=mid;
            else 
			    r=mid-1;
        }
        if(i==1)
		   l=1;
        p[i]=l;
    }
    dfs(1,0);  //修正每個點最晚的種樹時間 
    rep(i,1,n)
	   ct[i]=0;
    rep(i,1,n)
	{
    	if(p[i]<1)
		    return 0;
    	ct[p[i]]++; //在第p[i]個時間要種樹的行為,要執行多少次 
	}
	rep(i,1,n)
	//列舉時間 
	{
		ct[i]+=ct[i-1];  //統計一共要種多少棵樹 
		if(ct[i]>i) //前i個時間只能種i棵樹 
		   return 0;
	}
    return 1;
}
signed main(){
    scanf("%d",&n);
    rep(i,1,n)scanf("%lld%lld%lld",&a[i],&b[i],&c[i]);
    for(int i=1,x,y;i<n;++i)
	{
        scanf("%d%d",&x,&y);
        g[x].pb(y),g[y].pb(x);
    }
    int l=n,r=1e9;
    while(l<r)
	{
        int mid=(l+r)>>1;
        if(check(mid))r=mid;
        else l=mid+1;
    }
    cout<<l;
    return 0;
}

  

#include <cstdio>
#include <iostream>
#include <algorithm>
#include <vector>
#include <map>
#include <queue>
using namespace std;

// Array capacities: N bounds the node arrays (presumably the problem's maximum
// n; confirm against the statement). E = 2N because add() stores every
// undirected edge as two directed edges.
const int N = 1e5, E = N << 1;
// Upper end of the answer range binary-searched in main().
const long long Max = 1e9;

// (d value, node) pairs. check() keeps them in a default (max-)priority_queue,
// so the node with the largest d[] is popped first.
typedef pair<long long, int> pir;

int n;  // number of nodes
// Per-node input a, b, c. zero[i] is filled in main() only when c[i] < 0, and
// calc_d() reads it as the last day whose raw growth b + c*x is still >= 1
// (valid when b[i] >= 1; confirm against the problem constraints).
long long a[N + 5], b[N + 5], c[N + 5], zero[N + 5];

// Forward-star adjacency storage. head[u] is the id of u's most recently added
// edge, and nxt[] links to the previous one; id 0 (the zero-initialised head)
// ends each list. Because `tot` starts at 1 and is pre-incremented, edge ids
// start at 2.
int head[N + 5];
int to[E + 5];
int nxt[E + 5];
int tot = 1;

// Prepends the directed edge u -> v to u's list.
void add_edge(int u, int v)
{
	const int id = ++tot;
	to[id] = v;
	nxt[id] = head[u];
	head[u] = id;
}

// Stores an undirected edge as the two directed edges u -> v and v -> u.
void add(int u, int v)
{
	add_edge(u, v);
	add_edge(v, u);
}

long long d[N + 5];  // latest admissible planting day per node, set by calc_d (-1 = none)
int sz[N + 5];       // not referenced anywhere else in this listing

// For node u and candidate finishing day `ans`, stores in d[u] the largest day
// in [1, n] such that planting u on that day gives total growth >= a[u] by day
// `ans`. Stores -1 when no day qualifies.
// Daily growth is b[u] + c[u]*x. For c[u] < 0 it is replaced by 1 after day
// zero[u] (precomputed in main()).
void calc_d(int u, long long ans)
{
	const __int128 one = 1;

	// Unclamped sum of b[u] + c[u]*x for x in [first, last].
	auto linear = [&](long long first, long long last) -> __int128 {
		return one * (last - first + 1) * b[u]
			 + one * (first + last) * (last - first + 1) / 2 * c[u];
	};

	long long best = -1;
	long long lo = 1, hi = n;
	while (lo <= hi)
	{
		const long long mid = lo + (hi - lo) / 2;

		// Growth accumulated when planted on day `mid`. Branch order matters:
		// for c[u] < 0, the `mid > zero[u]` test must come first.
		__int128 grown;
		if (c[u] >= 0)
			grown = linear(mid, ans);
		else if (mid > zero[u])
			grown = ans - mid + 1;  // only 1-per-day growth remains
		else if (ans > zero[u])
			grown = linear(mid, zero[u]) + (ans - zero[u]);
		else
			grown = linear(mid, ans);

		if (grown >= a[u])
		{
			best = mid;
			lo = mid + 1;
		}
		else
		{
			hi = mid - 1;
		}
	}
	d[u] = best;
}

int fa[N + 5];  // DFS parent of each node (0 for the root)
int in[N + 5];  // children of each node not yet scheduled by check()

// Records the parent of every node reachable from u and adds one to in[father]
// per visited node, so in[x] grows by x's child count (in[0] by 1 for the root).
// NOTE: accumulates rather than resets. check() relies on its scheduling loop
// having driven every in[i] (i >= 1) back to 0 on the previous call.
void dfs(int u, int father)
{
	++in[father];
	fa[u] = father;
	for (int e = head[u]; e != 0; e = nxt[e])
		if (to[e] != father)
			dfs(to[e], u);
}

priority_queue<pir> q;  // (d value, node) for nodes whose children are all scheduled
int seq[N + 5];         // seq[t] = node assigned to planting day t

// Feasibility test for finishing day `ans`. Days are filled from n down to 1.
// At each step, the node with the largest d[] among those whose children are
// all placed takes the current day. Succeeds iff every node's assigned day is
// within its d[] bound. The loop pops exactly as many entries as it pushes, so
// q is empty again on return.
bool check(long long ans)
{
	for (int u = 1; u <= n; ++u)
	{
		calc_d(u, ans);
		if (d[u] < 0)  // no day in [1, n] is early enough for u
			return false;
	}

	dfs(1, 0);  // refresh fa[] and add child counts into in[]
	for (int u = 1; u <= n; ++u)
		if (in[u] == 0)
			q.emplace(d[u], u);  // leaves can take the latest days

	for (int day = n; day >= 1; --day)
	{
		const int u = q.top().second;
		q.pop();
		seq[day] = u;

		// A parent becomes schedulable once its last pending child is placed.
		const int parent = fa[u];
		if (parent != 0 && --in[parent] == 0)
			q.emplace(d[parent], parent);
	}

	for (int day = 1; day <= n; ++day)
		if (d[seq[day]] < day)
			return false;
	return true;
}

// Reads the instance, precomputes zero[] for negative slopes, builds the tree,
// then binary-searches the smallest finishing day in [1, Max] accepted by
// check(). Prints 0 if none is accepted.
int main()
{
	// File redirection, left disabled:
	// freopen("tree.in", "r", stdin);
	// freopen("tree.out", "w", stdout);

	scanf("%d", &n);
	for (int node = 1; node <= n; ++node)
	{
		scanf("%lld%lld%lld", &a[node], &b[node], &c[node]);
		// With a negative slope this is the last day whose growth b + c*x is
		// still >= 1 (when b >= 1).
		if (c[node] < 0)
			zero[node] = (1ll - b[node]) / c[node];
	}

	int x = 0, y = 0;
	for (int edge = 1; edge < n; ++edge)
	{
		scanf("%d%d", &x, &y);
		add(x, y);
	}

	long long answer = 0;
	for (long long lo = 1, hi = Max; lo <= hi;)
	{
		const long long mid = lo + (hi - lo) / 2;
		if (check(mid))
		{
			answer = mid;
			hi = mid - 1;
		}
		else
		{
			lo = mid + 1;
		}
	}
	printf("%lld\n", answer);
	return 0;
}