poj 3764 最長異或路徑(二進位制trie樹)

細雨欣然發表於2017-03-09

【問題描述】

  給你一棵樹,n個節點,n-1條邊每條邊i都有一個權值wi。定義任意兩點間的權值為:這兩點間的路徑上的所有邊的值的異或。比如a點和b點間有i,j,k三條邊,那麼ab兩點間的權值為:wi^wj^wk。求這個最大的權值(最長異或路徑)。

【輸入格式】

  第一行為n表示節點數目(節點編號為1..n)。
  接下來的n-1行,每行三個整數:u v w,表示一條樹邊(x,y)的權值為w(0<=w<2^31)。

【輸出格式】

  輸出最長異或路徑長度。

【輸入樣例】

4
1 2 3
2 3 4
2 4 6

【輸出樣例】

【樣例解釋】

The xor-longest path is 0->1->2, which has length 7 (=3 ⊕ 4)

【資料範圍】

n<=250000

【來源】

poj 3764

這道題都能想到用dfs來生成根到每一個點的異或路徑,但生成之後的操作就是重點了。
首先我們可以很容易的想到任意2個點直接的異或路徑就是他們到跟的異或路徑的異或值,證明如下:
設2點為x,y,公共祖先是z。z到根的異或路徑是c,x到z的異或路徑是a,y到z的異或路徑是b。可得a^b=a^c^b^c。
不用二進位制trie樹的話很容易想到一個n^2時間複雜的演算法,就是每2個數進行異或。但如果有了二進位制trie樹就可以先生成樹,在再樹上貪心的進行查詢,很容易就可以得到最大值了,時間複雜度(n*log2n)。

#include<cstdlib>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<vector>
using namespace std;
const int maxn=250005;

struct edge
{
    int u,v,w,next;
}e[maxn];
int d[maxn],f[maxn]={0},ch[maxn*33][2]={0},cnt=0,tot=0,n;
int a[33];
bool vis[maxn*33]={0},usd[maxn]={0};

void in(int x)
{
    int k=30,p=0,d;
    while(k>=0)
    {
        if(x&a[k]) d=1;
        else d=0;
        if(!ch[p][d]) ch[p][d]=++cnt;
        p=ch[p][d];
        k--;
    }
    vis[p]=1;
}
void add(int u,int v,int w)
{
    e[++tot]=(edge){u,v,w,f[u]};f[u]=tot;
}
int read()
{
    int x=0;
    char ch=getchar();
    while(ch<'0'||ch>'9') ch=getchar();
    while(ch>='0'&&ch<='9')
    {
        x=x*10+ch-'0';
        ch=getchar();
    }
    return x;
}
void dfs(int i)
{
    for(int k=f[i];k;k=e[k].next)
    {
        int j=e[k].v,c=e[k].w;
        d[j]=d[i]^c;
        in(d[j]);
        dfs(j);
    }
}
int find(int x)
{
    int k=30,p=0,d,y=0;
    while(k>=0)
    {
        if(x&a[k]) d=0;
        else d=1;
        if(!ch[p][d]) d=d^1;
        if(d) x^=a[k];
        p=ch[p][d];
        k--;
    }
    return x;
}
int main()
{
    //freopen("in.txt","r",stdin);
    //freopen("out.txt","w",stdout);
    n=read();
    a[0]=1;
    for(int i=1;i<=30;i++) a[i]=a[i-1]*2;
    int x,y,w;
    for(int i=1;i<n;i++)
    {
        x=read();y=read();w=read();
        add(x,y,w);
        usd[y]=1;
    }
    in(0);
    for(int i=1;i<=n;i++)if(!usd[i])
    dfs(i);
    int ans=0;
    for(int i=1;i<=n;i++)
    ans=max(ans,find(d[i]));

    cout<<ans;
    return 0;
}

相關文章