「 題解 」P2487 [SDOI2011]攔截導彈

沐離發表於2022-02-12

簡單題意

給定 \(n\) 個數對 \((h_i, v_i)\)

求:

  1. 最長不上升子序列的長度。
  2. 對於每個 \(i\),分別求出包含數對 \((h_i, v_i)\) 的最長上升子序列的個數和最長不上升子序列的個數和的比。

資料範圍\(1 \leq n \leq 5 \times {10} ^ 4\)\(\forall 1 \leq i \leq n, 1 \leq h_i, v_i \leq {10}^9\)

分析

問題 \(1\)

先考慮 \(O(n^2)\) 做法,本質與一維的相同。

定義 \(f_i\) 表示以數對 \((h_i, v_i)\) 結尾的最長不上升子序列的長度。

那麼有

\[\begin{aligned} f_i & = \max(1, \max\{ f_j + 1 \}), h_i \leq h_j \wedge v_i \leq v_j \\ \end{aligned} \]

推出了式子後考慮優化,可以使用 \(\texttt{cdq}\) 分治。

對於一個區間 \(\texttt{[l, r]}\) 的轉移。

\(\texttt{int mid = (l + r) >> 1}\)

先遞迴 \(\texttt{cdq(l, mid)}\)

假設 \(\texttt{[l, mid]}\) 已經求出正確答案,即 \(f_{\texttt{l} \sim \texttt{r}}\) 都是正確的。

考慮如何轉移,即 \(\texttt{[l, mid]}\)\(\texttt{[mid + 1, r]}\) 的貢獻。

與三維偏序一樣,合併 \(h\),以 \(v\) 為下標把 \(f\) 存放在樹狀陣列中。

只不過這裡的 \(f\) 需要取最大值。

最後遞迴 \(\texttt{cdq(mid + 1, r)}\)

因為是最後遞迴 \(\texttt{cdq(mid + 1, r)}\),所以詢問不能真正合並,在結束時需要還原成原來的順序。

問題 \(2\)

問題 \(1\) 解決後問題 \(2\) 就簡單了。

只需要在樹狀陣列中再維護一個統計個數陣列即可。

  • 修改的值大於當前最大值就修改。
  • 修改的值等於當前最大值就累加。

因為求的是 包含 數對 \((h_i, v_i)\) 的最長不上升子序列的個數,所以還需要反著求一遍最長不上升子序列的個數。

兩個數相乘再除以總數就是答案,前提是這個數對被包含在至少一個最長不上升子序列中。

溫馨提示,個數可能會超過 \(\texttt{long long}\) 的範圍,建議使用 \(\texttt{double}\)

\(\texttt{code}\)

#include <cstdio>
#include <vector>
#include <utility>
#include <iostream>
#include <algorithm>

int rint() {
	int x = 0, fx = 1;
	char c = getchar();
	
	while (c < '0' || c > '9') {
		fx ^= (c == '-');
		c = getchar();
	}
	
	while ('0' <= c && c <= '9') {
		x = (x << 3) + (x << 1) + (c ^ 48);
		c = getchar();
	}
	
	if (!fx) {
		return -x;
	}
	
	return x;
}

void read(int &x) {
	x = rint();
}

template<typename... Ts>
void read(int &x, Ts &...rest) {
	read(x);
	read(rest...);
}

int Max(int u, int v) {
	return (u > v) ? u : v;
}

int Min(int u, int v) {
	return (u < v) ? u : v;
}

const int MAX_n = 5e4;

int n, Time; // Time 是時間戳優化樹狀陣列,可以不用清空 
int dp1[MAX_n + 5]; // 正 
int dp2[MAX_n + 5]; // 反 
int vis[MAX_n + 5]; // 樹狀陣列時間戳 
int Bit[MAX_n + 5];
double bit[MAX_n + 5];
double num1[MAX_n + 5]; // 正 
double num2[MAX_n + 5]; // 反 

std::vector<int> lsh; // 離散化 v 

struct Missile {
	int idx, h, v;
} q[MAX_n + 5];
Missile tmp[MAX_n + 5];

bool cmph1(Missile x, Missile y) {
	return x.h > y.h;
}

bool cmph2(Missile x, Missile y) {
	return x.h < y.h;
}

int lowbit(int x) {
	return x & (-x);
}

void add(int k, int x, double y) {
	while (k <= n) {
		if (vis[k] != Time) {
			vis[k] = Time;
			Bit[k] = x;
			bit[k] = y;
		} else {
			if (x > Bit[k]) {
				Bit[k] = x;
				bit[k] = y;
			} else if (x == Bit[k]) {
				bit[k] += y;
			}
		}
		
		k += lowbit(k);
	}
}

std::pair<int, double> ask(int k) {
	int resmax = 0;
	double ressum = 0.0;
	
	while (k > 0) {
		if (vis[k] == Time) {
			if (Bit[k] > resmax) {
				resmax = Bit[k];
				ressum = bit[k];
			} else if (Bit[k] == resmax) {
				ressum += bit[k];
			}
		}
		
		k -= lowbit(k);
	}
	
	return std::make_pair(resmax, ressum);
}

void merge1(int L1, int R1, int L2, int R2) {
	++Time;
	int i = L1, j = L2;
	
	for (int k = L1; k <= R2; k++) {
		tmp[k] = q[k];
	}
	
	std::sort(q + L1, q + R1 + 1, cmph1);
	std::sort(q + L2, q + R2 + 1, cmph1);
	
	while (i <= R1 || j <= R2) {
		if (i <= R1 && (j > R2 || q[i].h >= q[j].h)) {
			add((int)lsh.size() + 1 - q[i].v, dp1[q[i].idx], num1[q[i].idx]);
			i++;
		} else {
			std::pair<int, double> now = ask((int)lsh.size() + 1 - q[j].v);
			
			if (now.first + 1 > dp1[q[j].idx]) {
				dp1[q[j].idx] = now.first + 1;
				num1[q[j].idx] = now.second;
			} else if (now.first + 1 == dp1[q[j].idx]) {
				num1[q[j].idx] += now.second;
			}
			
			j++;
		}
	}
	
	for (int k = L1; k <= R2; k++) {
		q[k] = tmp[k];
	}
}

void cdq1(int L, int R) {
	if (L == R) {
		return ;
	}
	
	int Mid = (L + R) >> 1;
	cdq1(L, Mid);
	merge1(L, Mid, Mid + 1, R);
	cdq1(Mid + 1, R);
}

void merge2(int L1, int R1, int L2, int R2) {
	++Time;
	int i = L1, j = L2;
	
	for (int k = L1; k <= R2; k++) {
		tmp[k] = q[k];
	}
	
	std::sort(q + L1, q + R1 + 1, cmph2);
	std::sort(q + L2, q + R2 + 1, cmph2);
	
	while (i <= R1 || j <= R2) {
		if (i <= R1 && (j > R2 || q[i].h <= q[j].h)) {
			add(q[i].v, dp2[q[i].idx], num2[q[i].idx]);
			i++;
		} else {
			std::pair<int, double> now = ask(q[j].v);
			
			if (now.first + 1 > dp2[q[j].idx]) {
				dp2[q[j].idx] = now.first + 1;
				num2[q[j].idx] = now.second;
			} else if (now.first + 1 == dp2[q[j].idx]) {
				num2[q[j].idx] += now.second;
			}
			
			j++;
		}
	}
	
	for (int k = L1; k <= R2; k++) {
		q[k] = tmp[k];
	}
}

void cdq2(int L, int R) {
	if (L == R) {
		return ;
	}
	
	int Mid = (L + R) >> 1;
	cdq2(L, Mid);
	merge2(L, Mid, Mid + 1, R);
	cdq2(Mid + 1, R);
}

signed main() {
	n = rint();
	
	for (int i = 1; i <= n; i++) {
		read(q[i].h, q[i].v);
		q[i].idx = i;
		lsh.push_back(q[i].v);
	}
	
	std::sort(lsh.begin(), lsh.end());
	lsh.resize(std::unique(lsh.begin(), lsh.end()) - lsh.begin());
	
	for (int i = 1; i <= n; i++) {
		q[i].v = std::lower_bound(lsh.begin(), lsh.end(), q[i].v) - lsh.begin() + 1;
		dp1[i] = dp2[i] = 1;
		num1[i] = num2[i] = 1.0;
	}
	
	cdq1(1, n);
	
	for (int i = n / 2; i >= 1; i--) {
		std::swap(q[i], q[n + 1 - i]);
	}
	
	cdq2(1, n);
	int res = 0;
	double sum = 0;
	
	for (int i = 1; i <= n; i++) {
		if (dp1[i] > res) {
			res = dp1[i];
			sum = num1[i];
		} else if (dp1[i] == res) {
			sum += num1[i];
		}
	}
	
	printf("%d\n", res);
	
	for (int i = 1; i <= n; i++) {
		if (dp1[i] + dp2[i] - 1 != res) {
			printf("0.00000");
		} else {
			printf("%.5f", 1.0 * num1[i] * num2[i] / sum);
		}
		
		putchar((i == n) ? '\n' : ' ');
	}
	
	return 0;
}

相關文章