二分答案法

n1ce2cv發表於2024-10-09

二分答案法

  • 估計最終答案的大概範圍

  • 分析問題的答案和給定條件之間的單調性

  • 建立一個 f 函式,當答案固定的情況下,判斷給定的條件是否達標

  • 在最終答案可能的範圍上不斷二分搜尋,每次用 f 函式判斷,直到二分結束,找到最合適的答案

875. 愛吃香蕉的珂珂

#include <vector>
#include <algorithm>

using namespace std;

class Solution {
public:
    // 返回要消耗的時間
    long timeConsuming(vector<int> &piles, int k) {
        long res = 0;
        for (const auto &item: piles)
            // item / k 向上取整,前提都是非負數
            res += (item + k - 1) / k;
        return res;
    }

    // 時間複雜度 O(n * log(max)),額外空間複雜度 O(1)
    int minEatingSpeed(vector<int> &piles, int h) {
        int left = 1;
        int right = 0;
        for (const auto &item: piles)
            right = max(right, item);
        int mid;

        while (left <= right) {
            mid = left + ((right - left) >> 1);
            if (timeConsuming(piles, mid) <= h) {
                right = mid - 1;
            } else {
                left = mid + 1;
            }
        }
        return left;
    }
};

410. 分割陣列的最大值

畫匠問題:

  • 一維陣列表示每個位置的畫完成需要的時間,k 表示畫匠人數
  • 每個畫匠可以畫連續的幾幅畫,畫匠可以並行工作,求最小耗時
  • 其實就是把陣列分成連續的 k 個子陣列,使得所有子陣列中和最大的那個的和儘量小
#include <vector>
#include <algorithm>

using namespace std;

class Solution {
public:
    // 每個連續部分的和不超過 limit 的情況下,需要多少個畫匠完成全部畫作
    int painterNeeded(vector<int> &nums, int limit) {
        int count = 1;
        int sum = 0;
        // 時間複雜度 O(n)
        for (const auto &num: nums) {
            // 表示完成不了
            if (num > limit) return INT_MAX;
            if (sum + num > limit) {
                count++;
                sum = num;
            } else {
                sum += num;
            }
        }
        return count;
    }

    // 時間複雜度 O(n * log(sum)),額外空間複雜度 O(1)
    int splitArray(vector<int> &nums, int k) {
        long left = 0;
        long right = 0;
        for (const auto &item: nums)
            right += item;
        long mid;

        while (left <= right) {
            mid = left + ((right - left) >> 1);
            if (painterNeeded(nums, mid) <= k) {
                right = mid - 1;
            } else {
                left = mid + 1;
            }
        }
        return left;
    }
};

機器人跳躍問題

#include <vector>
#include <iostream>
#include <algorithm>

using namespace std;

// 以初始能量 energy 能否走完陣列
bool finished(vector<int> &nums, int energy, int maxH) {
    for (const auto &item: nums) {
        energy += (energy - item);
        // 如果超過高度最大值,後面肯定通關了,可以提前返回
        if (energy >= maxH) return true;
        if (energy < 0) return false;
    }
    return true;
}

// 時間複雜度 O(n * log(maxH)),額外空間複雜度 O(1)
int main() {
    int n;
    cin >> n;
    vector<int> nums(n);
    int maxH = 0;
    for (int i = 0; i < n; ++i) {
        cin >> nums[i];
        maxH = max(maxH, nums[i]);
    }
    int left = 0;
    int right = maxH;
    int mid;

    while (left <= right) {
        mid = left + ((right - left) >> 1);
        if (finished(nums, mid, maxH)) {
            right = mid - 1;
        } else {
            left = mid + 1;
        }
    }
    cout << left;
}

719. 找出第 K 小的數對距離

#include <vector>
#include <algorithm>

using namespace std;

class Solution {
public:
    // 返回任意兩數差值小於等於 limit 的數對個數
    int countLower(vector<int> &nums, int limit) {
        int count = 0;
        for (int l = 0, r = 0; l < nums.size(); ++l) {
            while (r + 1 < nums.size() && nums[r + 1] - nums[l] <= limit)
                r++;
            count += r - l;
        }
        return count;
    }

    // 時間複雜度 O(n * log(n) + n * log(max-min)),額外空間複雜度 O(1)
    int smallestDistancePair(vector<int> &nums, int k) {
        sort(nums.begin(), nums.end());
        int left = 0;
        int right = nums.back() - nums.front();
        int mid;

        while (left <= right) {
            mid = left + ((right - left) >> 1);
            if (countLower(nums, mid) >= k) {
                right = mid - 1;
            } else {
                left = mid + 1;
            }
        }
        return left;
    }
};

2141. 同時執行 N 臺電腦的最長時間

#include <vector>

using namespace std;

class Solution {
public:
    // 能否讓 computers 臺電腦共同執行 time 分鐘
    bool finished(vector<int> &batteries, int computers, long time) {
        // 碎片電量總和
        long fragmentCharge = 0;
        for (const auto &charge: batteries) {
            if (charge > time) {
                // time 時間內全都給這臺電腦供電,沒有提供碎片電量
                computers--;
            } else {
                // 碎片電量
                fragmentCharge += charge;
            }
            // 碎片電量 >= 臺數 * 要求
            if (fragmentCharge >= (long) computers * time) return true;
        }
        return false;
    }

    // 時間複雜度 O(n * log(sum)),額外空間複雜度 O(1)
    long long maxRunTime(int n, vector<int> &batteries) {
        long sum = 0;
        for (const auto &item: batteries)
            sum += item;
        long left = 0;
        long right = sum;
        long mid;

        while (left <= right) {
            mid = left + ((right - left) >> 1);
            if (finished(batteries, n, mid)) {
                left = mid + 1;
            } else {
                right = mid - 1;
            }
        }
        return right;
    }
};
  • 貪心最佳化
#include <vector>

using namespace std;

class Solution {
public:
    // 能否讓 computers 臺電腦共同執行 time 分鐘
    bool finished(vector<int> &batteries, int computers, long time) {
        // 碎片電量總和
        long fragmentCharge = 0;
        for (const auto &charge: batteries) {
            if (charge > time) {
                // time 時間內全都給這臺電腦供電,沒有提供碎片電量
                computers--;
            } else {
                // 碎片電量
                fragmentCharge += charge;
            }
            // 碎片電量 >= 臺數 * 要求
            if (fragmentCharge >= (long) computers * time) return true;
        }
        return false;
    }

    // 時間複雜度 O(n * log(_max)),額外空間複雜度 O(1)
    long long maxRunTime(int n, vector<int> &batteries) {
        long sum = 0;
        int _max = 0;
        for (const auto &item: batteries) {
            sum += item;
            _max = max(_max, item);
        }

        // 最佳化
        if (sum > (long) _max * n) {
            // 所有電池的最大電量是 _max
            // 如果此時 sum > (long) _max * num,
            // 說明: 最終的供電時間一定在 >= max,而如果最終的供電時間 >= max
            // 說明: 對於最終的答案 X 來說,所有電池都是碎片電池
            // 那麼尋找 ? * num <= sum 的情況中,儘量大的 ? 即可
            // 即 sum / num
            return sum / n;
        }
        // 最終的供電時間一定在 < _max 範圍上

        long left = 0;
        long right = _max;
        long mid;

        while (left <= right) {
            mid = left + ((right - left) >> 1);
            if (finished(batteries, n, mid)) {
                left = mid + 1;
            } else {
                right = mid - 1;
            }
        }
        return right;
    }
};

計算等位時間

  • 給定一個陣列 arr 長度為 n,表示 n 個服務員,每服務一個人的時間
  • 給定一個正數 m,表示有 m 個人等位,如果你是剛來的人,每個客人都遵循有空位就上的原則,請問你需要等多久?
  • 假設 m 遠遠大於 n,比如 n <= 10^3, m <= 10^9,該怎麼做是最優解?
package class051;

import java.util.PriorityQueue;

// 計算等位時間
// 給定一個陣列arr長度為n,表示n個服務員,每服務一個人的時間
// 給定一個正數m,表示有m個人等位,如果你是剛來的人,請問你需要等多久?
// 假設m遠遠大於n,比如n <= 10^3, m <= 10^9,該怎麼做是最優解?
// 谷歌的面試,這個題連考了2個月
// 找不到測試連結,所以用對數器驗證
public class Code06_WaitingTime {

    // 堆模擬
    // 驗證方法,不是重點
    // 如果m很大,該方法會超時
    // 時間複雜度O(m * log(n)),額外空間複雜度O(n)
    public static int waitingTime1(int[] arr, int m) {
        // 一個一個物件int[]
        // [醒來時間,服務一個客人要多久]
        PriorityQueue<int[]> heap = new PriorityQueue<>((a, b) -> (a[0] - b[0]));
        int n = arr.length;
        for (int i = 0; i < n; i++) {
            heap.add(new int[]{0, arr[i]});
        }
        for (int i = 0; i < m; i++) {
            int[] cur = heap.poll();
            cur[0] += cur[1];
            heap.add(cur);
        }
        return heap.peek()[0];
    }

    // 二分答案法
    // 最優解
    // 時間複雜度O(n * log(min * w)),額外空間複雜度O(1)
    public static int waitingTime2(int[] arr, int w) {
        int min = Integer.MAX_VALUE;
        for (int x : arr) {
            min = Math.min(min, x);
        }
        int ans = 0;
        for (int l = 0, r = min * w, m; l <= r; ) {
            // m中點,表示一定要讓服務員工作的時間!
            m = l + ((r - l) >> 1);
            // 能夠給幾個客人提供服務
            if (f(arr, m) >= w + 1) {
                ans = m;
                r = m - 1;
            } else {
                l = m + 1;
            }
        }
        return ans;
    }

    // 如果每個服務員工作time,可以接待幾位客人(結束的、開始的客人都算)
    public static int f(int[] arr, int time) {
        int ans = 0;
        for (int num : arr) {
            ans += (time / num) + 1;
        }
        return ans;
    }

    // 對數器測試
    public static void main(String[] args) {
        System.out.println("測試開始");
        int N = 50;
        int V = 30;
        int M = 3000;
        int testTime = 20000;
        for (int i = 0; i < testTime; i++) {
            int n = (int) (Math.random() * N) + 1;
            int[] arr = randomArray(n, V);
            int m = (int) (Math.random() * M);
            int ans1 = waitingTime1(arr, m);
            int ans2 = waitingTime2(arr, m);
            if (ans1 != ans2) {
                System.out.println("出錯了!");
            }
        }
        System.out.println("測試結束");
    }

    // 對數器測試
    public static int[] randomArray(int n, int v) {
        int[] arr = new int[n];
        for (int i = 0; i < n; i++) {
            arr[i] = (int) (Math.random() * v) + 1;
        }
        return arr;
    }

}

刀砍毒殺怪獸問題

package class051;

// 刀砍毒殺怪獸問題
// 怪獸的初始血量是一個整數hp,給出每一回合刀砍和毒殺的數值cuts和poisons
// 第i回合如果用刀砍,怪獸在這回合會直接損失cuts[i]的血,不再有後續效果
// 第i回合如果用毒殺,怪獸在這回合不會損失血量,但是之後每回合都損失poisons[i]的血量
// 並且你選擇的所有毒殺效果,在之後的回合都會疊加
// 兩個陣列cuts、poisons,長度都是n,代表你一共可以進行n回合
// 每一回合你只能選擇刀砍或者毒殺中的一個動作
// 如果你在n個回合內沒有直接殺死怪獸,意味著你已經無法有新的行動了
// 但是怪獸如果有中毒效果的話,那麼怪獸依然會在血量耗盡的那回合死掉
// 返回至少多少回合,怪獸會死掉
// 資料範圍 : 
// 1 <= n <= 10^5
// 1 <= hp <= 10^9
// 1 <= cuts[i]、poisons[i] <= 10^9
// 本題來自真實大廠筆試,找不到測試連結,所以用對數器驗證
public class Code07_CutOrPoison {

    // 動態規劃方法(只是為了驗證)
    // 目前沒有講動態規劃,所以不需要理解這個函式
    // 這個函式只是為了驗證二分答案的方法是否正確的
    // 純粹為了寫對數器驗證才設計的方法,血量比較大的時候會超時
    // 這個方法不做要求,此時並不需要理解,可以在學習完動態規劃章節之後來看看這個函式
    public static int fast1(int[] cuts, int[] poisons, int hp) {
       int sum = 0;
       for (int num : poisons) {
          sum += num;
       }
       int[][][] dp = new int[cuts.length][hp + 1][sum + 1];
       return f1(cuts, poisons, 0, hp, 0, dp);
    }

    // 不做要求
    public static int f1(int[] cuts, int[] poisons, int i, int r, int p, int[][][] dp) {
       r -= p;
       if (r <= 0) {
          return i + 1;
       }
       if (i == cuts.length) {
          if (p == 0) {
             return Integer.MAX_VALUE;
          } else {
             return cuts.length + 1 + (r + p - 1) / p;
          }
       }
       if (dp[i][r][p] != 0) {
          return dp[i][r][p];
       }
       int p1 = r <= cuts[i] ? (i + 1) : f1(cuts, poisons, i + 1, r - cuts[i], p, dp);
       int p2 = f1(cuts, poisons, i + 1, r, p + poisons[i], dp);
       int ans = Math.min(p1, p2);
       dp[i][r][p] = ans;
       return ans;
    }

    // 二分答案法
    // 最優解
    // 時間複雜度O(n * log(hp)),額外空間複雜度O(1)
    public static int fast2(int[] cuts, int[] poisons, int hp) {
       int ans = Integer.MAX_VALUE;
       for (int l = 1, r = hp + 1, m; l <= r;) {
          // m中點,一定要讓怪獸在m回合內死掉,更多回合無意義
          m = l + ((r - l) >> 1);
          if (f(cuts, poisons, hp, m)) {
             ans = m;
             r = m - 1;
          } else {
             l = m + 1;
          }
       }
       return ans;
    }

    // cuts、posions,每一回合刀砍、毒殺的效果
    // hp:怪獸血量
    // limit:回合的限制
    public static boolean f(int[] cuts, int[] posions, long hp, int limit) {
       int n = Math.min(cuts.length, limit);
       for (int i = 0, j = 1; i < n; i++, j++) {
          hp -= Math.max((long) cuts[i], (long) (limit - j) * (long) posions[i]);
          if (hp <= 0) {
             return true;
          }
       }
       return false;
    }

    // 對數器測試
    public static void main(String[] args) {
       // 隨機測試的資料量不大
       // 因為資料量大了,fast1方法會超時
       // 所以在資料量不大的情況下,驗證fast2方法功能正確即可
       // fast2方法在大資料量的情況下一定也能透過
       // 因為時間複雜度就是最優的
       System.out.println("測試開始");
       int N = 30;
       int V = 20;
       int H = 300;
       int testTimes = 10000;
       for (int i = 0; i < testTimes; i++) {
          int n = (int) (Math.random() * N) + 1;
          int[] cuts = randomArray(n, V);
          int[] posions = randomArray(n, V);
          int hp = (int) (Math.random() * H) + 1;
          int ans1 = fast1(cuts, posions, hp);
          int ans2 = fast2(cuts, posions, hp);
          if (ans1 != ans2) {
             System.out.println("出錯了!");
          }
       }
       System.out.println("測試結束");
    }

    // 對數器測試
    public static int[] randomArray(int n, int v) {
       int[] ans = new int[n];
       for (int i = 0; i < n; i++) {
          ans[i] = (int) (Math.random() * v) + 1;
       }
       return ans;
    }

}

相關文章