P8765 [藍橋杯 2021 國 AB] 翻轉括號序列

ltign發表於2024-05-24

本文參考部落格 [藍橋杯 2021 國 AB] 翻轉括號序列(線段樹上二分)

一、問題簡析

線段樹 + 二分

初步分析

( 的值為 1) 的值為 -1,則對於序列 \(a_La_{L+1}a_{L+2}...a_R\),其為合法序列的條件為

\[\begin{cases} \sum_{n=L}^R{a_n}=0 \\ \forall ~k\in [L,R],\sum_{n=L}^k{a_n} \ge 0 \end{cases} \]

用字首和 presum 來表示,則為

\[\begin{cases} \text{presum[R]}=\text{presum[L - 1]} \\ \forall ~k\in [L,R],\text{presum[k]} \ge \text{presum[L - 1]} \end{cases} \]

因此,我們需要維護區間字首和的最小值。

建立線段樹

節點資訊

struct node
{
	int l, r;
	int mmax, mmin;         // [l, r]中字首和的最大值為mmax,最小值為mmin 
	int tag_rev, tag_add;   // tag_rev -- 是否需要翻轉; tag_add -- 待增加的值 
} tree[N << 2];

操作一

操作一要翻轉 \([L,R]\) 中的括號,我們可以將該操作分成兩部分:

\[reverse(L,R)=reverse(1,L-1)+reverse(1,R) \]

為什麼要這樣做呢?我們來看翻轉序列 \([1,R]\)。翻轉後,會對整個數列的字首和產生影響,進而影響維護的 mminmmax 產生影響。

  • 對於 \([1,R]\),相當於 \(\text{presum[1,2, ...,R]}\) 取相反數。因此 mminmmax 交換,並取相反數即可。
  • 對於 \([R+1,n]\),因為只取反了 \([1,R]\),所以在 \(-\text{presum[R]}\) 的基礎上加上原來的數,相當於在原來的字首和上減去兩倍的 \(\text{presum[R]}\)。因此, mminmmax 也要減去兩倍的 \(\text{presum[R]}\)

所以,要翻轉序列 \([1,R]\),要先後進行兩種更新:

  • 更新一,區間 \([1,R]\) 的字首和取反,即 mminmmax 交換,並取相反數。
  • 更新二,區間 \([R+1,n]\) 的字首和減去兩倍的 \(\text{presum[R]}\),即 mminmmax 減去兩倍的 \(\text{presum[R]}\)

因此,我們需要兩個懶惰標記。值得注意的是,tag_rev 會對 tag_add 產生影響——令 tag_add 取相反數,反過來則不會有影響。在進行 pushdown 時,要先更新 tag_rev,再是 tag_add。(作者尚未明白原因)

操作二

操作二要用二分來實現。我們先來看二分的前提條件——單調性。區間 \([L,R]\) 字首和的最小值 mmin\(R\) 單調不增。因此,我們可以利用二分找到最大的 \(R\),滿足區間 \([L,R]\) 字首和的最小值 mmin 大於等於 \(\text{presum[L - 1]}\)。再判斷此時的 \(R\) 是否滿足 \(\text{presum[R]}=\text{presum[L - 1]}\)
因此,query 的作用是查詢區間 \([L,R]\) 字首和的最小值 mmin

需要注意,二分的條件是:

  • \([L,M]\)mmin 小於 \(\text{presum[L - 1]}\),令 \(R=M-1\)
  • \([L,M]\)mmin 大於等於 \(\text{presum[L - 1]}\)\([M,R]\)mmin 大於 \(\text{presum[L - 1]}\),也要令 \(R=M-1\)
  • 否則,令 \(L=M+1\)

二、Code

P8765 [藍橋杯 2021 國 AB] 翻轉括號序列

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;

ll quickin(void)
{
	ll ret = 0;
	bool flag = false;
	char ch = getchar();
	while (ch < '0' || ch > '9')
	{
		if (ch == '-')    flag = true;
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9' && ch != EOF)
	{
		ret = ret * 10 + ch - '0';
		ch = getchar();
	}
	if (flag)    ret = -ret;
	return ret;
}

#define lc(p)    p << 1
#define rc(p)    p << 1 | 1
const int N = 1e6 + 5;
int n, m;
int presum[N];
char s[N];
struct node
{
	int l, r;
	int mmax, mmin;         // [l, r]中字首和的最大值為mmax,最小值為mmin 
	int tag_rev, tag_add;   // tag_rev -- 是否需要翻轉; tag_add -- 待增加的值 
} tree[N << 2];

void amend_rev(node &p)
{
	int tmp1 = p.mmax, tmp2 = p.mmin;
	p.mmax = -tmp2;
	p.mmin = -tmp1;
	p.tag_rev ^= 1;
	p.tag_add *= -1;
}

void amend_add(node &p, int val)
{
	p.mmax += val;
	p.mmin += val;
	p.tag_add += val;
}

void pushup(int p)
{
	tree[p].mmax = max(tree[lc(p)].mmax, tree[rc(p)].mmax);
	tree[p].mmin = min(tree[lc(p)].mmin, tree[rc(p)].mmin);
}

void pushdown(int p)
{
	if (tree[p].tag_rev)
	{
		amend_rev(tree[lc(p)]);
		amend_rev(tree[rc(p)]);
		
		tree[p].tag_rev = 0;
	}
	if (tree[p].tag_add)
	{
		amend_add(tree[lc(p)], tree[p].tag_add);
		amend_add(tree[rc(p)], tree[p].tag_add);
		
		tree[p].tag_add = 0;
	}
}

void build(int p, int l, int r)
{
	tree[p] = {l, r, presum[l], presum[l], 0, 0};
	if (l == r)    return;
	
	int m = (l + r) >> 1;
	build(lc(p), l, m);
	build(rc(p), m + 1, r);
	
	pushup(p);
}

void update_rev(int p, int x, int y)
{
	if (x > y)    return;
	
	if (x <= tree[p].l && tree[p].r <= y)
	{
		amend_rev(tree[p]);
		return;
	}
	
	pushdown(p);
	
	int m = (tree[p].l + tree[p].r) >> 1;
	if (x <= m)    update_rev(lc(p), x, y);
	if (y > m)    update_rev(rc(p), x, y);
	
	pushup(p);
}

void update_add(int p, int x, int y, int val)
{
	if (x > y)   return;
	
	if (x <= tree[p].l && tree[p].r <= y)
	{
		amend_add(tree[p], val);
		return;
	}
	
	pushdown(p);
	
	int m = (tree[p].l + tree[p].r) >> 1;
	if (x <= m)    update_add(lc(p), x, y, val);
	if (y > m)    update_add(rc(p), x, y, val);
	
	pushup(p);
}

int query(int p, int x, int y)
{	
	if (x == 0 && y == 0)    return 0;
	if (x <= tree[p].l && tree[p].r <= y)
		return tree[p].mmin;
	
	pushdown(p);
	
	int m = (tree[p].l + tree[p].r) >> 1;
	int ans = 1e8;
	if (x <= m)    ans = min(ans, query(lc(p), x, y));
	if (y > m)    ans = min(ans, query(rc(p), x, y));
	
	return ans;
}

void update(int x, int y)
{
	int val = query(1, x - 1, x - 1);
	update_rev(1, 1, x - 1);
	update_add(1, x, n, -2 * val);
	
	val = query(1, y, y);
	update_rev(1, x, y);
	update_add(1, y + 1, n, -2 * val);
}

int main()
{
	#ifdef LOCAL
	freopen("test.in", "r", stdin);
	#endif
	
	n = quickin(), m = quickin();
	scanf("%s", s + 1);
	for (int i = 1; i <= n; ++i)
	{
		presum[i] = s[i] == '(' ? 1 : -1;
		presum[i] += presum[i - 1];
	}
	
	build(1, 1, n);
	
	for (int i = 0; i < m; ++i)
	{
		int a, b, c;
		a = quickin();
		
		if (a == 1)
		{
			b = quickin(), c = quickin();
			update(b, c);
		}
		else if (a == 2)
		{
			b = quickin();
			int key = query(1, b - 1, b - 1);
			
			int l = b, r = n;
			
			while (l <= r)
			{
				int m = (l + r) >> 1;
				int mmin = query(1, l, m);
				if (mmin < key || mmin >= key && query(1, m, n) > key)
					r = m - 1;
				else
					l = m + 1;
			}
			
			if (r >= b && query(1, r, r) == key)    printf("%d\n", r);
			else    printf("0\n");
		}
	}
	
	return 0;
}