Solution - Codeforces 1970G3 Min-Fund Prison (Hard)

rizynvu發表於2024-08-22

時間 \(\mathcal{O}(\frac{n\sqrt{n}\log n}{\omega})\) 空間 \(\mathcal{O}(\frac{n\log n}{w})\) 的爆標做法。

首先無解當且僅當圖聯通且無割邊。

首先考慮加邊的貢獻。
一個比較直觀的感受就是隻會盡可能少的加邊,即加邊到整個圖連通。

證明考慮刪掉的邊。
如果加多了邊導致刪除後不會有兩個連通塊一定不行;如果加多了邊後最後分出的還是兩個連通塊顯然加多的邊是不必要的。

於是加邊的貢獻就好算了,因為加的邊的數量一定是連通塊數 \(-1\)

接下來考慮刪邊後的貢獻。
顯然的是肯定兩部分大小越接近 \(\frac{n}{2}\) 越優。

考慮兩部分是如何組成的。
能發現除了有一個連通塊會被割邊拆開,其他的連通塊都會整體留下,並且這些連通塊可以任意組合(還有一種情況是選的割邊是加的邊,即就是連通塊任意組合,可以與上一種情況一同處理)。

於是就可以想到用揹包的形式,得到去掉某一個連通塊其他連通塊大小的揹包,這顯然是可以 bitset 最佳化的。

於是一個想法就是缺一分治,即分治時把一邊貢獻到遞迴下去的另一邊處理,但複雜度為 \(\mathcal{O}(\frac{n^2\log n}{\omega})\),還是不太行。

進一步的,考慮到除掉去除的連通塊,對於其他的連通塊其實只關心其大小。
於是在缺一分治時,實際上只需要考慮對出現過的連通塊的大小分治。

具體的,若一種連通塊的大小的出現次數 \(c \ge 1\),就可以先將 \(c - 1\) 個用二進位制分組的形式放入揹包。
那麼缺一分治時,實際上涉及到的元素就是連通塊的大小了。
大小的種類數是 \(\mathcal{O}(\sqrt{n})\) 的,所以這部分的複雜度是 \(\mathcal{O}(\frac{n\sqrt{n}\log n}{\omega})\)

接下來最佳化求解答案。
考慮到因為連通塊大小的和 \(= n\),所以可以直接暴力列舉劃分情況,令其中一個塊劃分到的大小為 \(x\),並欽定 \(x\) 所在塊大小 \(\ge \lceil\frac{n}{2}\rceil\),再去找這個大小實際是多少。

那麼就是找到最靠前的 \(\ge \lceil\frac{n}{2}\rceil - x\) 的位置。

一個想法是判定一下 \(\lceil\frac{n}{2}\rceil - x\) 是否存在,不存在就用 _Find_next\(\mathcal{O}(\frac{n^2}{\omega})\)
但是考慮到當 \(x\) 遞增時 \(\lceil\frac{n}{2}\rceil - x\) 是遞減的,這個最優的位置是可以繼承的,就可以先 _Find_next 問出 \(x = 0\) 的情況,然後掃一遍動態維護。

於是這部分複雜度就是 \(\mathcal{O}(n + \frac{n\sqrt{n}}{w})\)

最後總時間複雜度 \(\mathcal{O}(\frac{n\sqrt n\log n}{\omega})\)
因為分治的每一層都需要開一個 bitset,空間複雜度 \(\mathcal{O}(\frac{n\log n}{\omega})\)

#include<bits/stdc++.h>
using ll = long long;
inline ll pw2(int n) {return 1ll * n * n;}
const int maxn = 1e5 + 10;
int n;
std::vector<int> to[maxn];
int dfn[maxn], low[maxn], ins[maxn], stk[maxn], top, dtot;
std::vector<int> S;
std::vector<int> vis[maxn];
int tot[maxn];
std::vector<int> val;
int dfs(int u, int fa) {
   dfn[u] = low[u] = ++dtot, ins[u] = 1, stk[++top] = u;
   int siz = 1;
   for (int v : to[u]) {
      if (v == fa) continue;
      if (! dfn[v]) {
         int siz_ = dfs(v, u); low[u] = std::min(low[u], low[v]);
         if (low[v] > dfn[u])
            S.push_back(siz_);
         siz += siz_;
      } else if (ins[v])
         low[u] = std::min(low[u], dfn[v]);
   }
   if (low[u] == dfn[u]) {
      int t;
      do {
         t = stk[top--], ins[t] = 0;
      } while (t != u);
   }
   return siz;
}
std::bitset<maxn> B[20];
ll ans;
void solve(int l, int r, int dep = 0) {
   if (l == r) {
      int x = val[l], n2 = n + 1 >> 1;
      int w = B[dep]._Find_next(n2);
      for (int i = 0; i <= x; i++) {
         if (n2 >= i && B[dep][n2 - i])
            w = n2 - i;
         if (vis[x][i] && w < maxn)
            ans = std::min(ans, pw2(w + i) + pw2(n - w - i));
      }
      return ;
   }
   int mid = (l + r) >> 1;
   B[dep + 1] = B[dep];
   for (int i = mid + 1; i <= r; i++)
      B[dep + 1] |= B[dep + 1] << val[i];
   solve(l, mid, dep + 1);
   B[dep + 1] = B[dep];
   for (int i = l; i <= mid; i++)
      B[dep + 1] |= B[dep + 1] << val[i];
   solve(mid + 1, r, dep + 1);
}
inline void Main() {
   int m, co; scanf("%d%d%d", &n, &m, &co);
   for (int i = 1; i <= n; i++)
      to[i].clear();
   for (int i = 1, x, y; i <= m; i++)
      scanf("%d%d", &x, &y), to[x].push_back(y), to[y].push_back(x);
   int cnt = 0;
   dtot = 0, memset(dfn, 0, sizeof(int) * (n + 1));
   for (int i = 1; i <= n; i++)
      tot[i] = 0, vis[i].clear();
   val.clear();
   for (int i = 1; i <= n; i++)
      if (! dfn[i]) {
         cnt++; int siz = dfs(i, 0);
         if (! tot[siz]++)
            vis[siz].resize(siz + 1), val.push_back(siz);
         for (int x : S)
            vis[siz][x] = vis[siz][siz - x] = 1;
         vis[siz][0] = vis[siz][siz] = 1;
         S.clear();
      }
   if (cnt == 1) {
      bool fl = 0;
      for (int i = 1; i < n; i++)
         fl |= vis[n][i];
      if (! fl)
         return puts("-1"), void();
   }
   B[0].reset(), B[0].set(0);
   ans = 1e18;
   for (int x = 1; x <= n; x++) {
      if (tot[x]) {
         val.push_back(x);
         int v = tot[x]; v--;
         for (int i = 1; i <= v; v -= i, i <<= 1)
            B[0] |= B[0] << i * x;
         B[0] |= B[0] << v * x;
      }
   }
   solve(0, val.size() - 1);
   printf("%lld\n", 1ll * co * (cnt - 1) + ans);
}
int main() {
   int T; scanf("%d", &T);
   while (T--) Main();
   return 0;
}

相關文章