hdu 6446 Tree and Permutation(dfs+思維)

Nero Alix發表於2018-08-26

題意:給出一顆樹,按節點進行全排列,給你一棵樹,以全排列的第一個樹為根節點,求出根節點到其他點的最短路徑之和,把這些和在相加,求最後結果

分析:整體思路是先建立領接表,然後計算各邊的權重,乘以各邊的權值,然後乘以(n-1)!再乘以2。

問題就是這個權重怎麼算,假設有一棵樹,將其中的一條邊砍斷,這課樹就分為兩部分,部分1的每個結點,都需要經過這條邊,到另外一個部分2的所有結點上走一遍,所以這條邊的權重就可以理解為,部分1的結點數之和i乘以部分2的結點數之和j。
知道了計算權重的方法之後,就很好計算了,先建立一顆以1為根部的樹,然後通過dfs遍歷整棵樹,遇到葉子返回1,父節點等於他的所有子節點之和再加1,於是每個結點就都有了一個數字,代表著他及他之下所有子節點數的和。用一個陣列a存放即可。如果將一棵樹分成兩部分之後,小的那部分的結點數之和就等於他的根部結點a[i]的數字。大的那部分的結點數就等於a[1]-a[i]。然後進行計算即可。
程式碼擼一遍很快,改錯改了幾個小時。
WA了十幾次,總結起來其實只有三個錯誤
1. 一開始忘了題目需要mod 1e9+7了,這個錯誤很快就發現了。
2. 在算好權重之後,還需要遍歷一遍整棵樹,這個遍歷方法出了問題,因為在構建樹的時候,每條路徑是雙向的,遍歷的時候每條路徑都遍歷了兩邊。一開始是加了一個簡單的條件限制遍歷已經走過的路,但是資料量稍微大一點就出錯了。找這個錯誤很簡單,隨便造了幾組資料就發現錯誤了。在修改的時候,一開始一直想著加更多的限制條件,但這很明顯不對,因為資料變得更大就又會出錯。後來想到了,既然每條路徑都遍歷了兩次,那麼直接把權重和sum除以2就行了。這個方法是對的。
3. 權重和在累加的過程中需要有可能會超過longlong的上限,需要一邊加,一邊mod。這個問題找了我至少4個小時,因為在修改了上面兩個錯誤之後,百分之99的資料都可以過了,我自己造了七八組資料,都沒有問題,然後就不知道怎麼辦了。在尋找過程中,我有一次有想到這個問題,但是因為sum先mod再除以2會還是有問題(比如sum mod之後等於1,除以2就直接等於0了),交上去還是WA,就沒有往這個方面去想了,導致後來做了幾個小時的無用功。其實只需要先mod,不需要除以2,因為後面在和階乘相乘的時候也要乘以2,就可以抵消了。
這三個問題解決了,再提交就AC了

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=100005;
const int mod=1e9+7;
ll jc[maxn];
vector<pair<ll,ll> >vis[maxn];
ll qz[maxn];
ll sum=0;
ll fl[maxn];
ll n;
void jjc()
{
    for(int i=2;i<=(n-1);i++)
    {
        sum*=i;
        if(sum>mod)sum%=mod;
    }
    sum%=mod;
}
void dfs(int i)
{
    memset(qz,0,sizeof qz);
    memset(fl,0,sizeof fl);
    sum=0;
    stack<int >s;
    s.push(i);
    fl[i]=1;
    while(!s.empty())
    {
        int flag=0;
        ll now=s.top();
        for(int i=0;i<vis[now].size();i++)
        {
            ll v=vis[now][i].first;
            if(fl[v]==1)
                continue;
            flag=1;
            fl[v]=1;
            s.push(v);
        }
        if(flag==0)
        {
            s.pop();
            for(int i=0;i<vis[now].size();i++)
            {
                ll v=vis[now][i].first;
                qz[now]+=qz[v];
            }
            qz[now]++;
        }
    }
}
int main()
{
    while(cin>>n)
    {
        memset(vis,0,sizeof vis);
        for(int i=0;i<n-1;i++)
        {
            int a,b,c;
            scanf("%d%d%d",&a,&b,&c);
            vis[a].push_back(make_pair(b,c));
            vis[b].push_back(make_pair(a,c));
        }
        dfs(1);
        for(int i=1;i<=n;i++)
        {
            for(int j=0;j<vis[i].size();j++)
            {
                ll v=vis[i][j].first;
                ll minn=min(qz[i],qz[v]);
                ll k=(qz[1]-minn)*minn;
                sum+=(k*vis[i][j].second);
            }
            sum%=mod;
        }
        jjc();
        cout<<sum<<endl;
    }
} 

相關文章