題意
給定一個陣列 \(a\),每次進行以下操作。
- 選擇一個 \(1 \le x \le n\),將 \(a_x := (a_x - 2 ^ {c_x}) \times 2\),然後 \(c_x := c_x + 1\)
如果透過這個操作使得 \(a\) 嚴格遞增,則 \(a\) 是好的。
你希望找到一個長度為 \(n\) 的好的陣列,使得 \(\sum a_i\) 最小,且她的字典序最小。
你需要回答 \(\sum a_i\),同時每次詢問 \(a_{b_i}\)。
\(n \le 10 ^ 9\)
Sol
這道題看起來非常嚇人。
和是好做的。
仔細想想操作,發現一個數 \(x\) 可以操作為 \((x - k) 2 ^ k\)。
注意到我們需要使得 \(x\) 最小,考慮 \((x - k)\) 為偶數的情況。
不難發現當前選擇 \((x - k) 2 ^ k\) 還不如選擇 \((\frac{(x - k)}{2}) 2 ^ {k + 1}\)。
透過各種方式:如打表、瞪眼、yy,最終答案陣列 (設為 \(a\)) 經過若干次操作後滿足遞增的陣列 (設為 \(a'\)) 滿足一段 \(\forall 1 \le i \le A, a'_i = i\)。
我們發現這個 \(A\) 就是 \(a'\) 中出現的最大奇數。
這個性質的證明是 trivial 的。
假如當前詢問的位置為 \(x\),若 \(x \le A\) 直接輸出 \(x\) 按照上述方式得到的最優解即可。
套路地,我們考慮 刪去 前面所有的奇數對答案的影響。
刪去所有奇數,只會剩下偶數,考慮按照上述方式最佳化當前所有數字 (也就是所有數除以 \(2\))。
發現這個子問題和原問題幾乎一摸一樣!
假設我們當前刪了 \(y\) 層,得到地答案為 \(x'\),則答案明顯為 \(x' \times 2 ^ y\)。
雖然層數為 \(\sqrt n\) 的級別,我們依舊無法透過本題。
可以考慮離線下來將所有詢問統一處理,這樣就只需要列舉一遍層數即可。
也可以考慮將當前每一層刪了多少個奇數,以及當前每層可以確定的最大的右端點預處理出來。
複雜度 \(O(\sqrt n)\)。
Code
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <array>
#include <vector>
#define int long long
#define pii pair <int, int>
using namespace std;
#ifdef ONLINE_JUDGE
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 23], *p1 = buf, *p2 = buf, ubuf[1 << 23], *u = ubuf;
#endif
int read() {
int p = 0, flg = 1;
char c = getchar();
while (c < '0' || c > '9') {
if (c == '-') flg = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
p = p * 10 + c - '0';
c = getchar();
}
return p * flg;
}
void write(int x) {
if (x < 0) {
x = -x;
putchar('-');
}
if (x > 9) {
write(x / 10);
}
putchar(x % 10 + '0');
}
bool _stmer;
const int N = 2e5 + 5;
#define fi first
#define se second
int calc(int i) { return ((i / 2) * 2 + 2) * (i / 2) / 2 + (i & 1) * (i + 1) / 2; }
array <int, N> s, h, p, ans;
bool _edmer;
signed main() {
cerr << (&_stmer - &_edmer) / 1024.0 / 1024.0 << "MB\n";
#ifndef cxqghzj
// freopen("halation.in", "r", stdin);
// freopen("halation.out", "w", stdout);
#endif
int n = read(), q = read();
int pos = 0, sum = 0;
for (int i = 1; !pos && i <= 2e5; i++) {
sum += (i + 1) / 2 * i;
if (calc(i) >= n) pos = i;
}
sum -= (calc(pos) - n) * pos;
int m = pos - 1;
write(sum), puts("");
int tot = 0, len = n - calc(m);
for (int i = 1; i <= pos; i++) h[i] = (pos - i) / 2 + 1;
for (int i = 1; i <= calc(pos) - n; i++) h[i * 2 - (pos & 1)]--;
for (int i = 1; i <= pos; i++)
s[i] = tot + h[i] * 2 - 1, tot += h[i];
for (int i = 1; i <= pos; i++) h[i] += h[i - 1];
while (q--) {
int x = read();
int tp = lower_bound(s.begin() + 1, s.begin() + pos + 1, x) - s.begin() - 1;
x -= h[tp];
while (!(x & 1)) tp++, x >>= 1;
write(x + tp), puts("");
}
return 0;
}