F - Two Sequence Queries
Problem Statement
You are given sequences of length $N$, $A=(A_1,A_2,\ldots,A_N)$ and $B=(B_1,B_2,\ldots,B_N)$.
You are also given $Q$ queries to process in order.
There are three types of queries:
- `1 l r x`: Add $x$ to each of $A_l, A_{l+1}, \ldots, A_r$.
- `2 l r x`: Add $x$ to each of $B_l, B_{l+1}, \ldots, B_r$.
- `3 l r`: Print the remainder of $\displaystyle\sum_{i=l}^r (A_i\times B_i)$ when divided by $998244353$.
Constraints
- $1\leq N,Q\leq 2\times 10^5$
- $0\leq A_i,B_i\leq 10^9$
- $1\leq l\leq r\leq N$
- $1\leq x\leq 10^9$
- All input values are integers.
- There is at least one query of the third type.
Input
The input is given from Standard Input in the following format. Here, $\mathrm{query}_i$ $(1\leq i\leq Q)$ is the $i$-th query to be processed.
$N$ $Q$
$A_1$ $A_2$ $\ldots$ $A_N$
$B_1$ $B_2$ $\ldots$ $B_N$
$\mathrm{query}_1$
$\mathrm{query}_2$
$\vdots$
$\mathrm{query}_Q$
Each query is given in one of the following formats:
$1$ $l$ $r$ $x$
$2$ $l$ $r$ $x$
$3$ $l$ $r$
Output
If there are $K$ queries of the third type, print $K$ lines.
The $i$-th line ($1\leq i\leq K$) should contain the output for the $i$-th query of the third type.
Sample Input 1
5 6
1 3 5 6 8
3 1 2 1 2
3 1 3
1 2 5 3
3 1 3
1 1 3 1
2 5 5 2
3 1 5
Sample Output 1
16
25
84
Initially, $A=(1,3,5,6,8)$ and $B=(3,1,2,1,2)$. The queries are processed in the following order:
- For the first query, print $(1\times 3)+(3\times 1)+(5\times 2)=16$ modulo $998244353$, which is $16$.
- For the second query, add $3$ to $A_2,A_3,A_4,A_5$. Now $A=(1,6,8,9,11)$.
- For the third query, print $(1\times 3)+(6\times 1)+(8\times 2)=25$ modulo $998244353$, which is $25$.
- For the fourth query, add $1$ to $A_1,A_2,A_3$. Now $A=(2,7,9,9,11)$.
- For the fifth query, add $2$ to $B_5$. Now $B=(3,1,2,1,4)$.
- For the sixth query, print $(2\times 3)+(7\times 1)+(9\times 2)+(9\times 1)+(11\times 4)=84$ modulo $998244353$, which is $84$.
Thus, the first, second, and third lines should contain $16$, $25$, and $84$, respectively.
Sample Input 2
2 3
1000000000 1000000000
1000000000 1000000000
3 1 1
1 2 2 1000000000
3 1 2
Sample Output 2
716070898
151723988
Make sure to print the sum modulo $998244353$ for the third type of query.
解題思路
將修改操作統一成同時對區間內的 $a_i$ 和 $b_i$ 進行修改。具體地,對於操作 $1$,對 $i \in [l,r]$ 內的每個 $a_i$ 加上 $x$,每個 $b_i$ 加上 $0$;對於操作 $2$,則每個 $a_i$ 加上 $0$,每個 $b_i$ 加上 $x$。由於涉及區間加法,自然想到使用帶懶標記的線段樹,下面考慮需要維護哪些資訊。
當對區間 $[l,r]$ 內的 $a_i$ 加上 $x$,$b_i$ 加上 $y$,原先的 $\sum\limits_{i=l}^{r}{a_i \cdot b_i}$ 就會變成
\begin{align*}
&\sum\limits_{i=l}^{r}{(a_i + x) \cdot (b_i + y)} \\
= &\sum\limits_{i=l}^{r}{a_i \cdot b_i + a_i \cdot y + b_i \cdot x + x \cdot y} \\
= &\sum\limits_{i=l}^{r}{a_i \cdot b_i} + y \cdot \sum\limits_{i=l}^{r}{a_i} + x \cdot \sum\limits_{i=l}^{r}{b_i} + x \cdot y \cdot (r - l + 1)
\end{align*}
因此對於每個維護著區間 $[l,r]$ 的線段樹節點,我們需要維護 $s_1$ 表示區間內 $a_i \cdot b_i$ 的和;$s_2$ 表示區間內 $a_i$ 的和;$s_3$ 表示區間內 $b_i$ 的和。當對整個節點的 $a_i$ 加上 $x$,$b_i$ 加上 $y$,那麼對應的更新操作就是 $\displaylines{\begin{cases} s_1 \gets s_1 + s_2 \cdot y + s_3 \cdot x + x \cdot y \cdot (r - l + 1) \\ s_2 \gets s_2 + x \cdot (r - l + 1) \\ s_3 \gets s_3 + y \cdot (r - l + 1) \end{cases}}$。注意 $s_1$ 的更新必須使用更新前的 $s_2$ 與 $s_3$。此外,每個節點還需要兩個懶標記,分別記錄尚未下傳給子節點的 $a_i$ 增量與 $b_i$ 增量;所有數值都對 $998244353$ 取模。
AC 程式碼如下,時間複雜度為 $O(n + q\log{n})$(建樹 $O(n)$,每次修改或查詢 $O(\log{n})$):
#include <bits/stdc++.h>
using namespace std;
using LL = long long;          // 64-bit type for intermediate products
constexpr int N = 2e5 + 5;     // max N (2e5) plus slack for 1-based indexing
constexpr int mod = 998244353; // modulus required by the problem
int a[N], b[N];                // initial sequences A and B, 1-indexed (zero-initialized globals)
// One segment-tree node covering the index range [l, r].
// Member order matters: build() initializes nodes with `{l, r}`,
// which value-initializes every remaining member to 0.
struct Node {
    int l, r; // covered index range
    int s1;   // sum of a_i * b_i over [l, r], reduced mod `mod`
    int s2;   // sum of a_i over [l, r], reduced mod `mod`
    int s3;   // sum of b_i over [l, r], reduced mod `mod`
    int sum1; // lazy tag: amount still to be added to every a_i in the children
    int sum2; // lazy tag: amount still to be added to every b_i in the children
} tr[N * 4];
// Recompute node u's three aggregates from its two children.
// Each child value is already < mod, so the int sums cannot overflow
// (2 * 998244353 < INT_MAX).
void pushup(int u) {
    const auto &lhs = tr[2 * u];
    const auto &rhs = tr[2 * u + 1];
    auto &cur = tr[u];
    cur.s1 = (lhs.s1 + rhs.s1) % mod;
    cur.s2 = (lhs.s2 + rhs.s2) % mod;
    cur.s3 = (lhs.s3 + rhs.s3) % mod;
}
// Build the subtree rooted at node u over positions [l, r] from the
// global arrays a and b. Lazy tags start at 0 via `{l, r}` initialization.
void build(int u, int l, int r) {
    tr[u] = {l, r};
    if (l != r) {
        const int mid = (l + r) / 2;
        build(2 * u, l, mid);
        build(2 * u + 1, mid + 1, r);
        pushup(u);
        return;
    }
    // Leaf: a single position l.
    tr[u].s1 = 1ll * a[l] * b[l] % mod;
    tr[u].s2 = a[l] % mod;
    tr[u].s3 = b[l] % mod;
}
// Apply "a_i += x and b_i += y for every i in node u's range" to node u.
// The node's own aggregates are updated immediately, and (x, y) is also
// accumulated into the lazy tags so pushdown() can forward it later.
// s1 is computed first because the expansion
//   sum (a + x)(b + y) = sum ab + y * sum a + x * sum b + x * y * len
// needs the OLD sums of a and b (s2, s3).
void upd(int u, int x, int y) {
    auto &nd = tr[u];
    const long long len = nd.r - nd.l + 1;
    long long prod = nd.s1;
    prod += 1ll * nd.s2 * y;         // y * (old sum of a_i)
    prod += 1ll * nd.s3 * x;         // x * (old sum of b_i)
    prod += 1ll * x * y % mod * len; // x * y contributed once per element
    nd.s1 = prod % mod;
    nd.s2 = (nd.s2 + x * len) % mod;
    nd.s3 = (nd.s3 + y * len) % mod;
    nd.sum1 = (nd.sum1 + x) % mod;
    nd.sum2 = (nd.sum2 + y) % mod;
}
// Forward node u's pending lazy tags (sum1 for a, sum2 for b) to both
// children, then clear them.
// When both tags are 0 there is nothing to forward: upd(child, 0, 0)
// would leave every (already mod-reduced) field unchanged. Returning early
// skips two redundant modular updates on every descent through a clean node.
void pushdown(int u) {
    if (tr[u].sum1 == 0 && tr[u].sum2 == 0) return;
    upd(u << 1, tr[u].sum1, tr[u].sum2);
    upd(u << 1 | 1, tr[u].sum1, tr[u].sum2);
    tr[u].sum1 = tr[u].sum2 = 0;
}
// Add x to every a_i and y to every b_i for i in [l, r], within the
// subtree rooted at node u. Callers only reach nodes whose range
// intersects [l, r].
void modify(int u, int l, int r, int x, int y) {
    // Node range lies entirely inside [l, r]: tag it and stop descending.
    if (l <= tr[u].l && tr[u].r <= r) {
        upd(u, x, y);
        return;
    }
    pushdown(u);
    const int mid = (tr[u].l + tr[u].r) / 2;
    if (l <= mid) modify(2 * u, l, r, x, y);
    if (mid < r) modify(2 * u + 1, l, r, x, y);
    pushup(u);
}
// Return the sum of a_i * b_i over [l, r] (mod `mod`), restricted to the
// subtree rooted at node u. Callers only reach nodes whose range
// intersects [l, r].
int query(int u, int l, int r) {
    const auto &cur = tr[u];
    if (l <= cur.l && cur.r <= r) return cur.s1;
    pushdown(u);
    const int mid = (cur.l + cur.r) / 2;
    // Each child result is < mod, so the int accumulator cannot overflow.
    int res = 0;
    if (l <= mid) res += query(2 * u, l, r);
    if (mid < r) res += query(2 * u + 1, l, r);
    return res % mod;
}
// Read N, Q, the sequences A and B, then process Q queries in order.
// Types 1 and 2 are range additions on A and B respectively. Any other
// type is answered as a type-3 range query, printed one result per line.
int main() {
    int n, m;
    scanf("%d %d", &n, &m);
    for (int i = 1; i <= n; ++i) scanf("%d", &a[i]);
    for (int i = 1; i <= n; ++i) scanf("%d", &b[i]);
    build(1, 1, n);
    while (m--) {
        int op, l, r;
        scanf("%d %d %d", &op, &l, &r);
        switch (op) {
        case 1:
        case 2: {
            int x;
            scanf("%d", &x);
            // Type 1 shifts only A; type 2 shifts only B.
            if (op == 1) modify(1, l, r, x, 0);
            else modify(1, l, r, 0, x);
            break;
        }
        default:
            printf("%d\n", query(1, l, r));
            break;
        }
    }
    return 0;
}
參考資料
Editorial - SuntoryProgrammingContest2024(AtCoder Beginner Contest 357):https://atcoder.jp/contests/abc357/editorial/10187