樹上的等差數列 [樹形dp]

Vocanda發表於2020-08-14

樹上的等差數列

題目描述

給定一棵包含 \(N\) 個節點的無根樹,節點編號 \(1\to N\) 。其中每個節點都具有一個權值,第 \(i\) 個節點的權值是 \(A_i\)

\(Hi\) 希望你能找到樹上的一條最長路徑,滿足沿著路徑經過的節點的權值序列恰好構成等差數列。

輸入格式

第一行包含一個整數 \(N\)

第二行包含 \(N\) 個整數 \(A_1, A_2, ... A_N\)

以下 \(N-1\) 行,每行包含兩個整數 \(U\)\(V\) ,代表節點 \(U\)\(V\) 之間有一條邊相連。

輸出格式

最長等差數列路徑的長度

樣例

樣例輸入

7  
3 2 4 5 6 7 5  
1 2  
1 3  
2 7  
3 4  
3 5  
3 6

樣例輸出

4

資料範圍與提示

對於 \(50\%\) 的資料,\(1 \leqslant N \leqslant 1000\)

對於 \(100\%\) 的資料,\(1 \leqslant N \leqslant 100000, 0 \leqslant A_i \leqslant 100000, 1 \leqslant U, V \leqslant N\)

分析

樹形 \(dp\) 好題。

因為要求的是最長的等差序列,根節點不同,答案也可能不同,所以 \(dp\) 的狀態轉移就定義為 \(f[i][j]\) 表示 \(i\) 節點為根,公差為 \(j\) 時的最長的等差數列,不包括自己。那麼我們就可以愉快的 \(dfs\) 來進行轉移了。

我們記錄一下他自己和他的父親,避免出現死迴圈,每一次先 \(dfs\) 到兒子,遞迴上來,然後就處理出來了公差為 \(\Delta\) 的以兒子為根的所有長度,這時候我們只需要判斷一下此時的 \(\Delta\) 值是否為 \(0\)。如果是,那麼 \(ans\) 的轉移應該是:

\[ans = max(ans,f[x][0] + f[son[x]][0] + 2) \]

因為此時 \(f[x][0]\) 儲存的是其他兒子上最長鏈,所以需要加上當前兒子的最長鏈,因為我們的陣列不儲存自己,所以要加 \(2\)

其他情況就是直接更新 \(ans\) ,他的答案應該是 \(f[x][d] + f[x][-d] + 1\) ,因為他的父親那裡也可能會有鏈,公差為 \(-d\) 就是那個鏈,由於負數下標的問題,我們利用 \(map\) 來儲存,然後輕鬆解決此題。

程式碼

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<map>
#define re register
using namespace std;
const int maxn = 1e5+10;
map <int,int> mp[maxn];
struct Node{
	int v,next;
}e[maxn<<1];
int w[maxn];
int ans = 0;
int head[maxn],tot;
void Add(int x,int y){//建邊
	e[++tot].v = y;
	e[tot].next = head[x];
	head[x] = tot;
}
inline int read(){//快讀
	int s = 0,f = 1;
	char ch = getchar();
	while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
	while(isdigit(ch)){s=s*10+ch-'0';ch=getchar();}
	return s * f;
}
inline void DP(int x,int fa){
	for(int i=head[x];i;i=e[i].next){
		int v = e[i].v;
		if(v == fa)continue;//避免死迴圈
		int d = w[v] - w[x];//計算公差
		DP(v,x);
		if(!d){//公差為0的情況
			ans = max(ans,mp[x][0] + mp[v][0] + 2);
			mp[x][0] = max(mp[x][0],mp[v][0] + 1);
		}
		else{//公差不為0
			mp[x][d] = max(mp[x][d],mp[v][d] + 1);
			ans = max(ans,mp[x][d] + mp[x][-d] + 1);
		}
	}
}

int main(){
	freopen("C.in","r",stdin);
	freopen("C.out","w",stdout);
	int n =read();
	for(re int i = 1;i<=n;++i){w[i]=read();}
	for(re int i = 1;i< n;++i){
		int x = read(),y = read();
		Add(x,y);
		Add(y,x);
	}
	DP(1,0);
	printf("%d\n",ans);
}

相關文章