前言
題目連結:洛谷。
題意簡述
給出長度為 \(n\)(\(n \leq 5 \times 10^5\))的字串 \(\texttt{S}\),\(q\)(\(q \leq 2 \times 10^6\))詢問某一子串的最短迴圈節。\(\texttt{A}\) 是 \(\texttt{B}\) 的迴圈節,當 \(\texttt{B}\) 可以由 \(\texttt{A}\) 重複若干次拼接成。
題目分析
聯想到 KMP 求迴圈節的過程,如果 \(len\) 是 \(\texttt{S}\) 的迴圈節長度,那麼一定有 \(\texttt{S}[1 \ldots n - len] = \texttt{S}[len + 1 \ldots n]\),以及 \(len \mid n\)。當然,只有本身為迴圈節的話,\(len = n\)。
字串相等的過程,可以用字串雜湊 \(\Theta(n) \sim \Theta(1)\) 地搞。但是列舉 \(len\) 是 \(\Theta(\sqrt{n})\) 的,時間複雜度 \(\Theta(n + q \sqrt{n})\),考慮最佳化。
發現 \(q\) 不和 \(n\) 同階,所以想到預處理出每個數的所有因子。預處理用埃氏篩,時間複雜度 \(\Theta(n \log n)\),列舉的時候取決於因子最多的個數,記為 \(k\),在本題資料範圍,應為 \(k = 200\)。時間複雜度 \(\Theta(n \log n + qk)\),能過本題。
當然還可以繼續最佳化。
發現,如果 \(len\) 是答案,那麼 \(k \cdot len\) 肯定也是答案。而 \(len = n\) 肯定是答案。所以考慮反過來計算。
即,將 \(n\) 分解成 \(\prod p_i ^ {k_i}\),答案 \(len = \prod p_i ^ {k'_i}\)。初始 \(k'_i = k_i\),那麼每次就是嘗試將一個 \(k'_i \gets k'_i - 1\),如果得到的 \(len\) 是一個合法迴圈節長度,那就減掉。
由於 \(k'_i\) 之間互不影響,從小的質因數開始嘗試。記 \(f(x)\) 表示 \(x\) 的最小質因數。設 \(ans = len\),然後迴圈判斷 \(\cfrac{ans}{f(len)}\) 能否成為新的答案,可以就讓 \(ans \gets \cfrac{ans}{f(len)}\)。然後 \(len \gets \cfrac{len}{f(len)}\)。直到 \(len = 1\)。這裡 \(len\) 和 \(f(len)\) 就是在不斷搞最小質因數的過程。
預處理可以用線性篩 \(\Theta(n)\) 地搞。查詢的時候,時間複雜度是 \(\Theta(\sum k_i) \leq \mathcal{O}(\log n)\)。總的時間複雜度 \(\mathcal{O}(n + q \log n)\)。
程式碼
略去了快讀快寫。卡卡常最優解。
#include <cstdio>
#include <algorithm>
using namespace std;
int n, q;
char str[500010];
using ull = unsigned long long;
int hav[500010], pri[500010], pcnt;
ull hsh[500010], pw[500010];
inline ull get_hash(int l, int r) {
if (l > r) return 0;
return hsh[r] - hsh[l - 1] * pw[r - l + 1];
}
signed main() {
fread(buf, 1, MAX, stdin);
read(n);
for (int i = 2; i <= n; ++i) {
if (!hav[i]) hav[i] = pri[++pcnt] = i;
for (int j = 1; j <= pcnt && i * pri[j] <= n; ++j) {
hav[i * pri[j]] = pri[j];
if (i % pri[j] == 0) break;
}
}
pw[0] = 1;
for (register int i = 1; i <= n; ++i) {
do str[i] = getchar(); while (str[i] < 'a' || str[i] > 'z');
hsh[i] = (hsh[i - 1] * 131 + str[i] - 'a' + 11);
pw[i] = pw[i - 1] * 131;
}
read(q);
for (register int i = 1, l, r; i <= q; ++i) {
read(l), read(r);
if (get_hash(l, r - 1) == get_hash(l + 1, r)) {
write(1), putchar('\n');
continue;
}
int ans = r - l + 1, len = ans;
while (len > 1) {
if (get_hash(l, r - ans / hav[len]) == get_hash(l + ans / hav[len], r))
ans /= hav[len];
len /= hav[len];
}
write(ans), putchar('\n');
}
fwrite(obuf, 1, o - obuf, stdout);
return 0;
}