CF 293 E Close Vertices (樹的分治+樹狀陣列)

acm_cxlove發表於2013-08-09

轉載請註明出處,謝謝http://blog.csdn.net/ACM_cxlove?viewmode=contents    by---cxlove

題目:給出一棵樹,問有多少條路徑權值和不大於w,長度不大於l。

http://codeforces.com/contest/293/problem/E

和男人八題很相似,但是多了一個限制。

同樣 還是點分治,考慮二元組(到根的路徑權值和,到根的路徑長度)。

按第一維度排序之後,可以用two pointers查詢權值和不大於w的,然後用樹狀陣列維護路徑長度。

也就是第一個條件用two pointers,第二個條件用樹狀陣列維護。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
// Segment-tree style child indices; not referenced by the code shown here.
#define lson step << 1
#define rson step << 1 | 1
// Shorthands for push_back and make_pair.
#define pb(a) push_back(a)
#define mp(a,b) make_pair(a , b)
// Lowest set bit of x; the step used when walking the Fenwick tree.
#define lowbit(x) (x & (-x))
// MSVC-style linker directive asking for a larger stack, presumably because of
// the recursive traversals below -- other toolchains may ignore it.
#pragma comment(linker, "/STACK:1024000000,1024000000")    
using namespace std;
typedef long long LL;
// Capacity of every per-vertex array (vertex ids start at 1).
const int N = 100005;
// One directed half of an undirected tree edge, kept in a linked adjacency list.
struct Edge {
    int v;     // vertex this half-edge points to
    int w;     // weight of the edge
    int next;  // index in e[] of the next half-edge leaving the same vertex, -1 at the end
};
// Edge pool; each undirected edge occupies two slots, hence twice N.
Edge e[N << 1];
// n: number of vertices; l: maximum allowed number of edges on a path;
// w: maximum allowed weight sum of a path (all three are read in main).
// tot: next free slot in e[]; start[u]: head of u's adjacency list (-1 = empty).
int n , l , w , tot , start[N];
// del[v]: non-zero once v has been chosen as a decomposition root and removed;
// size[v]: subtree size filled in by calsize.
int del[N] = {0} , size[N];
// Running count of qualifying pairs; printed by main.
LL ans = 0LL;
void _add (int u , int v , int w) {
    e[tot].v = v ; e[tot].next = start[u];
    e[tot].w = w;
    start[u] = tot ++;
}
// Inserts the undirected edge {u, v} with weight w by storing both directed
// halves. (Overloaded with the two-argument Fenwick-tree add further down.)
void add (int u , int v , int w) {
    _add (u , v , w);
    _add (v , u , w);
}
void calsize (int u , int pre) {
    size[u] = 1;
    for (int i = start[u] ; i != -1 ; i = e[i].next) {
        int v = e[i].v;
        if (v == pre || del[v]) continue;
        calsize (v , u);
        size[u] += size[v];
    }
}
// Scratch state for search/dfs: totalsize is the vertex count of the current
// component, maxsize the best (smallest) "largest piece" found so far, and
// rootidx the vertex that achieves it.
int totalsize , maxsize , rootidx;
void dfs (int u , int pre) {
    int mx = totalsize - size[u];
    for (int i = start[u] ; i != -1 ; i = e[i].next) {
        int v = e[i].v;
        if (v == pre || del[v]) continue;
        mx = max (mx , size[v]);
        dfs (v , u);
    }
    if (mx < maxsize) maxsize = mx , rootidx = u;
}
int search (int r) {
    calsize (r , -1);
    totalsize = size[r];
    maxsize = 1 << 30;
    dfs (r , -1);
    return rootidx;
}
// (weight sum to root, edge count to root) pairs collected by gao:
// sub[k] holds the pairs of the k-th branch below the current root,
// all holds the pairs of every branch together.
vector<pair<int,int> > sub[N] , all;
// idx: branch currently being filled; dist[v]/cnt[v]: weight sum and number of
// edges on the path from the current root to v.
int idx , dist[N] , cnt[N];
void gao (int u , int pre) {
    all.pb(mp(dist[u] , cnt[u]));
    sub[idx].pb(mp(dist[u] , cnt[u]));
    for (int i = start[u] ; i != -1 ; i = e[i].next) {
        int v = e[i].v , w = e[i].w;
        if (v == pre || del[v]) continue;
        dist[v] = dist[u] + w;
        cnt[v] = cnt[u] + 1;
        gao (v , u);
    }
}
// Fenwick tree over edge counts: s[] is the tree array and up the largest
// index currently in use (set at the start of each counting pass).
int s[N] , up;
// Fenwick-tree point update: adds val at position x (1-based, x <= up).
void add (int x , int val) {
    int pos = x;
    while (pos <= up) {
        s[pos] += val;
        pos += pos & (-pos);    // jump to the next node covering this position
    }
}
// Fenwick-tree prefix query: sum of the values stored at positions 1..x.
// Returns 0 for x <= 0.
int ask (int x) {
    int total = 0;
    while (x > 0) {
        total += s[x];
        x -= x & (-x);          // strip the lowest set bit
    }
    return total;
}
// Counts index pairs (a, b), a <= b, of v whose first components sum to at
// most w and whose second components sum to at most l.
// v must already be sorted ascending. The pair (a, a) is included when it
// satisfies both limits; the caller's subtraction of the per-branch counts
// cancels those terms.
// Uses the global Fenwick tree (s, up) as scratch space.
LL fuck (vector<pair<int , int> > &v) {
    const int m = (int) v.size();

    // Size the Fenwick tree to the largest edge count present and reset it.
    up = 0;
    for (int k = 0 ; k < m ; ++k) {
        if (v[k].second > up) up = v[k].second;
    }
    for (int k = 1 ; k <= up ; ++k) s[k] = 0;

    // Start with every element inside the window.
    for (int k = 0 ; k < m ; ++k) add (v[k].second , 1);

    LL pairs = 0;
    int hi = m - 1;
    for (int lo = 0 ; lo < m ; ++lo) {
        // Shrink from the right until the weight limit holds for (lo, hi).
        while (hi >= lo && v[lo].first + v[hi].first > w) {
            add (v[hi].second , -1);
            --hi;
        }
        if (hi < lo) break;
        // Everything still in the tree satisfies the weight limit with lo;
        // the prefix query applies the edge-count limit.
        pairs += ask (min (up , l - v[lo].second));
        add (v[lo].second , -1);    // lo leaves the window
    }
    return pairs;
}
void solve (int root) {
    root = search (root);
    del[root] = 1;
    if (totalsize == 1) return ;
    idx = 0 ;all.clear();
    for (int i = start[root] ; i != -1 ; i = e[i].next) {
        int v = e[i].v , w = e[i].w;
        if (del[v]) continue;
        sub[idx].clear();
        dist[v] = w ; cnt[v] = 1;
        gao (v , -1);
        sort (sub[idx].begin() , sub[idx].end());
        idx ++;
    }
    sort (all.begin() , all.end());
    ans += fuck (all);
    for (int i = 0 ; i < idx ; i ++) {
        for (int j = 0 ; j < sub[i].size() ; j ++) {
            if (sub[i][j].first <= w && sub[i][j].second <= l) {
                ans ++;
            }
        }
        ans -= fuck (sub[i]);
    }
    for (int i = start[root] ; i != -1 ; i = e[i].next) {
        int v = e[i].v;
        if (del[v]) continue;
        solve (v);
    }
}
int main () {
    // freopen ("input.txt" , "r" , stdin);
    // freopen ("output.txt" , "w" , stdout);
    tot = 0;memset (start , -1 , sizeof(start));
    scanf ("%d %d %d" , &n , &l , &w);
    for (int i = 1 ; i < n ; i ++) {
        int p , d;
        scanf ("%d %d" , &p , &d);
        add (i + 1 , p , d);
    }
    solve (1);
    printf ("%I64d\n" , ans);
    return 0;
}


相關文章