3-Partition 問題

sinkinben發表於2021-06-26

這是演算法考試的最後一題,當時匆匆寫了個基於 Subset Sum 的解法,也沒有考慮是否可行。

問題描述如下:

給定 \(n\) 個正整數 \(a_1 \dots a_n\) ,設下標的整數集合 \(V=\{1,2,3,\dots,n\}\) , 確定是否有三個不相交的子集 \(I,J,K \sub V\) ,滿足:

\[\sum_{i \in I} a_i = \sum_{j \in J}a_j = \sum_{k \in K} a_k = \frac{sum}{3} \]

其中, \(sum\) 是所有元素之和,要求複雜度是關於 \(n\)\(sum\) 的多項式時間解法。

基於 Subset Sum 的解法

Subset Sum 問題:給定一個整數陣列 nums 和整數 target ,問是否存在 nums 的子集,它的和為 target ,每個元素只能使用一次。

顯然 Subset Sum 問題是揹包問題的特殊情況,當揹包問題中所有物品的價值等於體積,那麼就是 Subset Sum 問題。

思路:

  • 假設 subsetSum(nums, target) 能夠在 nums 找到所有和為 target 的子集。
  • 問題等價於找到 2 個子集 \(I, J \sub V\) ,並且 \(\sum{a_i} = \sum{a_j} = sum/3\) ,那麼剩下的元素必然能保證 \(\sum{a_k} = sum/3\) .
  • 通過 Subset Sum 找到所有滿足目標和為 sum/3 的所有子集 \(\mathcal{I}\) ,即 subsetSum(V, target = sum/3)
  • 對於每一個的 \(I \in \mathcal{I}\) ,令 \(V' = V - I\) ,執行 subsetSum(V', target = sum/3) ,如果返回值不為空,那麼說明存在這樣的 \(I,J,K\) 滿足 3-Partition 的條件。

Subset Sum 問題的判定形式通過動態規劃是十分容易解決的,難點在於找出所有這樣的子集。

建議先完成下面的「輸出所有 LCS 的練習」,再進行後文的閱讀。

輸出所有 Subset Sum

定義 dp[i, j] 表示在 nums[0, ..., i] 中不超過 j 的最大和。

轉移方程為:

\[dp[i, j] = \left\{ \begin{aligned} & dp[i-1, j] & \text{ if } j < a_i \\ & \max(dp[i-1, j], dp[i-1, j-a_i]+a_i), & \text{ if } j \ge a_i \end{aligned} \right. \]

最後結果為 dp[n, target] == target .

程式碼實現

vector<vector<int>> subsetSum(vector<int> &nums, int target)
{
    int n = nums.size();
    vector<vector<int>> dp(n + 1, vector<int>(target + 1, 0));
    for (int i = 1; i <= n; i++)
    {
        for (int j = 0; j <= target; j++)
        {
            int x = nums[i - 1];
            if (j >= x) dp[i][j] = max(dp[i - 1][j], dp[i - 1][j - x] + x);
            else dp[i][j] = dp[i - 1][j];
        }
    }

    // print all subsets
    vector<vector<int>> result;
    function<bool(int, int, vector<int>)> getSubsets = [&](int i, int j, vector<int> subset)
    {
        while (i >= 1 && j >= 0)
        {
            int x = nums[i - 1];
            if (j >= x)
            {
                int t = dp[i - 1][j - x] + x;
                if (dp[i - 1][j] > t) i--;
                else if (dp[i - 1][j] < t) j -= x, i--, subset.emplace_back(x);
                else
                {
                    getSubsets(i - 1, j, subset);
                    subset.emplace_back(x), getSubsets(i - 1, j - x, subset);
                    return true;
                }
            }
            else i--;
        }
        result.emplace_back(subset);
        return true;
    };
    if (dp[n][target] == target)
    {
        getSubsets(n, target, vector<int>{});
        return result;
    }
    return {};
}
int main()
{
    vector<int> nums = {1, 3, 7, 8, 9};
    int t = 16;
    auto result = subsetSum(nums, t);
    for (auto &v : result)
    {
        for (int x : v) cout << x << ' ';
        cout << endl;
    }
}

3-Partition

基於上述的 Subset Sum ,我們可以寫出 3-Partition 的程式碼。

bool threePartition(vector<int> &nums)
{
    int sum = accumulate(nums.begin(), nums.end(), 0);
    if (sum % 3 != 0) return false;
    int target = sum / 3;
    auto subsets = subsetSum(nums, target);
    for (auto &I : subsets)
    {
        vector<int> buf(nums.size() - I.size());
        // buf = nums - I
        auto itor = set_symmetric_difference(nums.begin(), nums.end(), I.begin(), I.end(), buf.begin());
        buf.resize(itor - buf.begin());
        if (subsetSum(buf, target).size() != 0) return true;
    }
    return false;
}
int main()
{
    vector<int> nums = {1, 2, 3, 4, 4, 5, 8};
    cout << threePartition(nums) << endl;
}

Subset Sum 的時間複雜度為 \(O(nt)\),而此處 t = sum/3 ,因此 3-Partition 的時間複雜度為 \(O(kn \cdot sum)\)\(k\) 是 集合 \(I\) 的個數,顯然,\(k\) 有可能是指數級別的。

顯然,基於上述操作,我們同樣能找到所有滿足條件的 \(I,J,K\) .

動態規劃

上面是比較容易想到的思路,但這裡的 3-Partition 是一個判定問題,我們只需要給出 YES or NO,而不需要給出具體的 \(I,J,K\) ,因此用動態規劃可以使問題變得簡單。

2-Partition

先考慮 Subset Sum 的一種特殊情況:給定 nums ,問是否存在一個子集 I , 使得 sum(I) = sum(nums) / 2 .

其實換湯不換藥。

int twoPartition(vector<int> &nums)
{
    int sum = accumulate(nums.begin(), nums.end(), 0);
    int n = nums.size();
    if (sum % 2 != 0) return false;
    // dp[i, j] 表示前 j 個數字中,是否存在一個和為 i 的子集(允許為空集)
    vector<vector<int>> dp(sum / 2 + 1, vector<int>(n + 1, false));
    for (int i = 0; i <= n; i++) dp[0][i] = true;
    for (int i = 1; i <= sum / 2; i++) dp[i][0] = false;
    for (int i = 1; i <= sum / 2; i++)
    {
        for (int j = 1; j <= n; j++)
        {
            int x = nums[j - 1];
            if (i >= x) dp[i][j] = dp[i][j - 1] || dp[i - x][j - 1];
            else dp[i][j] = dp[i][j - 1];
        }
    }
    return dp[sum / 2][n];
}

3-Partition

定義 dp[j, k] 表示:nums[1, ..., n],是否存在一個子集,使得它的和為 j ;同時存在另外一個不相交子集,它的和為 k .

注意這裡的前提條件是,在 [1, ..., n] 這個範圍,而且 dp[j, k] = true 當且僅當 2 個子集同時存在。

那麼最後的答案是 dp[sum / 3, sum / 3] .

轉移方程為:

\[dp[j, k] = dp[j-a_i][k] \text{ or } dp[j][k-a_i], \text{ for any } a_i \]

類似於自頂向下的填表順序,轉移方程可以改寫為:

\[dp[j, k] = \text{true} \quad \Rightarrow \quad dp[j+a_i, k] = dp[j, k+a_i] = \text{true}, \quad \text{for any } a_i \]

程式碼實現

int threePartition(vector<int> &A)
{
    int sum = accumulate(A.begin(), A.end(), 0);
    int size = A.size();
    if (sum % 3 != 0) return false;
    vector<vector<int>> dp(sum + 1, vector<int>(sum + 1, 0));
    dp[0][0] = true;
    // process the numbers one by one
    for (int i = 0; i < size; i++)
    {
        for (int j = sum; j >= 0; j--)
        {
            for (int k = sum; k >= 0; k--)
            {
                if (dp[j][k])
                {
                    dp[j + A[i]][k] = true;
                    dp[j][k + A[i]] = true;
                }
            }
        }
    }
    return dp[sum / 3][sum / 3];
}

輸出所有 LCS

輸出一個 LCS 可以參考:https://www.cnblogs.com/sinkinben/p/14536604.html

思路很簡單,在填 dp 表的過程中,轉移路徑實際上就是記錄了 LCS 的結果,我們只需要通過回溯法,找到所有 (alen, blen) => (1, 1) 的路徑即可。

程式碼實現

int lcs(const string &a, const string &b)
{
    int alen = a.length(), blen = b.length();
    vector<vector<int>> dp(alen + 1, vector<int>(blen + 1, 0));
    for (int i = 1; i <= alen; i++)
    {
        for (int j = 1; j <= blen; j++)
        {
            if (a[i - 1] == b[j - 1]) dp[i][j] = dp[i - 1][j - 1] + 1;
            else dp[i][j] = max(dp[i - 1][j], dp[i][j - 1]);
        }
    }
    // print all lcs
    function<void(int, int, string)> printlcs = [&](int i, int j, string str)
    {
        while (i >= 1 && j >= 1)
        {
            if (a[i - 1] == b[j - 1])
                str.push_back(a[i - 1]), i--, j--;
            else
            {
                if (dp[i - 1][j] > dp[i][j - 1]) i--;
                else if (dp[i - 1][j] < dp[i][j - 1]) j--;
                else
                {
                    printlcs(i - 1, j, str);
                    printlcs(i, j - 1, str);
                    return;
                }
            }
        }
        reverse(str.begin(), str.end());
        result.insert(str);
    };
    printlcs(alen, blen, "");
    for (auto &x : result) cout << x << endl;
    return dp[alen][blen];
}
int main()
{
    // string a = "cnblog", b = "belong";
    string a = "xyxxzxyzxy", b = "zxzyyzxxyxxz";
    // string a = "ABCBDAB", b = "BDCABA";
    cout << lcs(a, b) << endl;
}