F - Two Sequence Queries

onlyblues發表於2024-06-09

F - Two Sequence Queries

Problem Statement

You are given sequences of length $N$, $A=(A_1,A_2,\ldots,A_N)$ and $B=(B_1,B_2,\ldots,B_N)$.
You are also given $Q$ queries to process in order.

There are three types of queries:

  • 1 l r x : Add $x$ to each of $A_l, A_{l+1}, \ldots, A_r$.
  • 2 l r x : Add $x$ to each of $B_l, B_{l+1}, \ldots, B_r$.
  • 3 l r : Print the remainder of $\displaystyle\sum_{i=l}^r (A_i\times B_i)$ when divided by $998244353$.

Constraints

  • $1\leq N,Q\leq 2\times 10^5$
  • $0\leq A_i,B_i\leq 10^9$
  • $1\leq l\leq r\leq N$
  • $1\leq x\leq 10^9$
  • All input values are integers.
  • There is at least one query of the third type.

Input

The input is given from Standard Input in the following format. Here, $\mathrm{query}_i$ $(1\leq i\leq Q)$ is the $i$-th query to be processed.

$N$ $Q$
$A_1$ $A_2$ $\ldots$ $A_N$
$B_1$ $B_2$ $\ldots$ $B_N$
$\mathrm{query}_1$
$\mathrm{query}_2$
$\vdots$
$\mathrm{query}_Q$

Each query is given in one of the following formats:

$1$ $l$ $r$ $x$

$2$ $l$ $r$ $x$

$3$ $l$ $r$

Output

If there are $K$ queries of the third type, print $K$ lines.
The $i$-th line ($1\leq i\leq K$) should contain the output for the $i$-th query of the third type.


Sample Input 1

5 6
1 3 5 6 8
3 1 2 1 2
3 1 3
1 2 5 3
3 1 3
1 1 3 1
2 5 5 2
3 1 5

Sample Output 1

16
25
84

Initially, $A=(1,3,5,6,8)$ and $B=(3,1,2,1,2)$. The queries are processed in the following order:

  • For the first query, print $(1\times 3)+(3\times 1)+(5\times 2)=16$ modulo $998244353$, which is $16$.
  • For the second query, add $3$ to $A_2,A_3,A_4,A_5$. Now $A=(1,6,8,9,11)$.
  • For the third query, print $(1\times 3)+(6\times 1)+(8\times 2)=25$ modulo $998244353$, which is $25$.
  • For the fourth query, add $1$ to $A_1,A_2,A_3$. Now $A=(2,7,9,9,11)$.
  • For the fifth query, add $2$ to $B_5$. Now $B=(3,1,2,1,4)$.
  • For the sixth query, print $(2\times 3)+(7\times 1)+(9\times 2)+(9\times 1)+(11\times 4)=84$ modulo $998244353$, which is $84$.

Thus, the first, second, and third lines should contain $16$, $25$, and $84$, respectively.


Sample Input 2

2 3
1000000000 1000000000
1000000000 1000000000
3 1 1
1 2 2 1000000000
3 1 2

Sample Output 2

716070898
151723988

Make sure to print the sum modulo $998244353$ for the third type of query.

解題思路

  將修改操作統一成同時對區間內的 $a_i$ 和 $b_i$ 進行修改。具體地,對於操作 $1$,對 $i \in [l,r]$ 內的每個 $a_i$ 加上 $x$,每個 $b_i$ 加上 $0$;對於操作 $2$,則每個 $a_i$ 加上 $0$,每個 $b_i$ 加上 $x$。由於涉及到區間加,顯然要用到線段樹,下面考慮需要維護哪些資訊。

  當對區間 $[l,r]$ 內的 $a_i$ 加上 $x$,$b_i$ 加上 $y$,原先的 $\sum\limits_{i=l}^{r}{a_i \cdot b_i}$ 就會變成

\begin{align*}
&\sum\limits_{i=l}^{r}{(a_i + x) \cdot (b_i + y)} \\
= &\sum\limits_{i=l}^{r}{\left( a_i \cdot b_i + a_i \cdot y + b_i \cdot x + x \cdot y \right)} \\
= &\sum\limits_{i=l}^{r}{a_i \cdot b_i} + y \cdot \sum\limits_{i=l}^{r}{a_i} + x \cdot \sum\limits_{i=l}^{r}{b_i} + x \cdot y \cdot (r - l + 1)
\end{align*}

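  因此對於每個維護著區間 $[l,r]$ 的線段樹節點,我們需要維護 $s_1$ 表示區間內 $a_i \cdot b_i$ 的和;$s_2$ 表示區間 $a_i$ 的和;$s_3$ 表示區間 $b_i$ 的和。當對整個節點的 $a_i$ 加上 $x$,$b_i$ 加上 $y$,那麼對應的更新操作就是 $\displaylines{\begin{cases} s_1 \gets s_1 + s_2 \cdot y + s_3 \cdot x + x \cdot y \cdot (r - l + 1) \\ s_2 \gets s_2 + x \cdot (r - l + 1) \\ s_3 \gets s_3 + y \cdot (r - l + 1) \end{cases}}$。注意第一個式子右邊用到的 $s_2$、$s_3$ 都是更新前的值,因此必須先更新 $s_1$,再更新 $s_2$ 和 $s_3$。另外,為了支援區間修改,每個節點還需要維護兩個懶標記(程式碼中的 sum1 和 sum2),分別表示尚未下傳給子節點的 $a_i$ 的增量與 $b_i$ 的增量;由於增量可以直接相加,多次修改的懶標記只需累加即可。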

  AC 程式碼如下,建樹的時間複雜度為 $O(n)$,每次操作為 $O(\log{n})$,總時間複雜度為 $O(n + q\log{n})$:

#include <bits/stdc++.h>
using namespace std;

// 64-bit integer alias.
using LL = long long;

// Capacity of the input arrays (maximum sequence length plus slack).
const int N = 200005;
// All stored sums are kept reduced modulo this prime.
const int mod = 998244353;

// Input sequences, filled with 1-based indices by main().
int a[N];
int b[N];

// One segment tree node covering the index interval [l, r].
struct Node {
    int l;     // left end of the covered interval
    int r;     // right end of the covered interval
    int s1;    // sum of a_i * b_i over the interval, modulo mod
    int s2;    // sum of a_i over the interval, modulo mod
    int s3;    // sum of b_i over the interval, modulo mod
    int sum1;  // lazy tag: increment of a_i not yet pushed to the children
    int sum2;  // lazy tag: increment of b_i not yet pushed to the children
};

// Tree storage; node u has children 2u and 2u + 1.
Node tr[N * 4];

void pushup(int u) {
    tr[u].s1 = (tr[u << 1].s1 + tr[u << 1 | 1].s1) % mod;
    tr[u].s2 = (tr[u << 1].s2 + tr[u << 1 | 1].s2) % mod;
    tr[u].s3 = (tr[u << 1].s3 + tr[u << 1 | 1].s3) % mod;
}

void build(int u, int l, int r) {
    tr[u] = {l, r};
    if (l == r) {
        tr[u].s1 = 1ll * a[l] * b[l] % mod;
        tr[u].s2 = a[l] % mod;
        tr[u].s3 = b[l] % mod;
    }
    else {
        int mid = l + r >> 1;
        build(u << 1, l, mid);
        build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
}

void upd(int u, int x, int y) {
    int len = tr[u].r - tr[u].l + 1;
    tr[u].s1 = (tr[u].s1 + 1ll * tr[u].s2 * y + 1ll * tr[u].s3 * x + 1ll * x * y % mod * len) % mod;
    tr[u].s2 = (tr[u].s2 + 1ll * x * len) % mod;
    tr[u].s3 = (tr[u].s3 + 1ll * y * len) % mod;
    tr[u].sum1 = (tr[u].sum1 + x) % mod;
    tr[u].sum2 = (tr[u].sum2 + y) % mod;
}

// Pushes the pending lazy increments of node u down to its two children
// (2u and 2u + 1) and clears the tags on u.
//
// When both tags are zero the function returns immediately: upd() with a
// (0, 0) increment only adds zero and re-reduces values that are already
// stored modulo mod, so the children would be left unchanged anyway.
void pushdown(int u) {
    const int dx = tr[u].sum1;
    const int dy = tr[u].sum2;
    if (dx == 0 && dy == 0) return;

    upd(u << 1, dx, dy);
    upd(u << 1 | 1, dx, dy);
    tr[u].sum1 = tr[u].sum2 = 0;
}

// Adds x to every a_i and y to every b_i for i in [l, r], starting the
// descent at node u.
void modify(int u, int l, int r, int x, int y) {
    const bool covered = l <= tr[u].l && tr[u].r <= r;
    if (covered) {
        // Node lies entirely inside the range: update it and stop here.
        upd(u, x, y);
        return;
    }

    // Partial overlap: flush pending tags before touching the children.
    pushdown(u);
    const int mid = (tr[u].l + tr[u].r) >> 1;
    if (l <= mid) {
        modify(u << 1, l, r, x, y);
    }
    if (mid < r) {
        modify(u << 1 | 1, l, r, x, y);
    }
    pushup(u);
}

int query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].s1;
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if (r <= mid) return query(u << 1, l, r);
    if (l >= mid + 1) return query(u << 1 | 1, l, r);
    return (query(u << 1, l, r) + query(u << 1 | 1, l, r)) % mod;
}

// Reads n, the number of operations, and both sequences from standard input,
// builds the tree, then processes the operations in order, printing one line
// per query operation.
int main() {
    int n, m;
    scanf("%d %d", &n, &m);

    int *const seqs[2] = {a, b};
    for (int *seq : seqs) {
        for (int i = 1; i <= n; i++) {
            scanf("%d", seq + i);
        }
    }

    build(1, 1, n);

    for (; m > 0; m--) {
        int op, l, r;
        scanf("%d %d %d", &op, &l, &r);

        if (op == 1 || op == 2) {
            // Operation 1 increments a, operation 2 increments b; the other
            // sequence receives an increment of zero.
            int x;
            scanf("%d", &x);
            const int da = (op == 1) ? x : 0;
            const int db = (op == 1) ? 0 : x;
            modify(1, l, r, da, db);
        }
        else {
            printf("%d\n", query(1, l, r));
        }
    }

    return 0;
}

參考資料

  Editorial - SuntoryProgrammingContest2024(AtCoder Beginner Contest 357):https://atcoder.jp/contests/abc357/editorial/10187

相關文章