【資料結構與演算法】蓄水池抽樣演算法(Reservoir Sampling)

gonghr發表於2022-01-17

問題描述

給定一個資料流,資料流長度 N 很大,且 N 直到處理完所有資料之前都不可知,請問如何在只遍歷一遍資料(O(N))的情況下,能夠隨機選取出 m 個不重複的資料。

比較直接的想法是利用隨機數演算法,求 random(N) 得到隨機數,但是題目表明資料流極大,這種大資料量是無法一次都讀到記憶體的,這就意味著不能像陣列一樣根據索引獲取元素。獲取 N 只能對所有資料進行遍歷,耗費時間較大,並且題目強調只能遍歷一遍,意味著不能先獲取到 N ,那麼採用分塊儲存資料的方法也不可取(遍歷不止一遍);如果採用估算,可能導致取樣資料不平均。

蓄水池抽樣演算法

假設資料序列的規模為 n(蓄水池大小),需要取樣的數量的為 k

首先構建一個可容納 k 個元素的陣列,將序列的前 k 個元素放入陣列中。

然後從第 k+1 個元素開始,以 k/n 的概率(n 為當前索引位置)來決定該元素是否被替換到陣列中(陣列中的元素被替換的概率是相同的)。當遍歷完所有元素之後,陣列中剩下的元素即為所需採取的樣本。

證明

  • 對於第 i 個元素(i <= k),該元素被選中的概率為 1,且索引 idx <= k 時,該元素被替換的概率為 0 ,當索引 idx 走到 k + 1 時,第 k + 1 個元素被選中進行替換的概率為 $ \frac{\mathrm{k}}{\mathrm{k}+1}$ ,第 i 個元素被選中的概率為 $ \frac{\mathrm{1}}{\mathrm{k}}$,於是第 i 個元素被第 k + 1 個元素替換的概率為 $ \frac{\mathrm{k}}{\mathrm{k}+1}$ * $ \frac{\mathrm{1}}{\mathrm{k}}$ = $ \frac{\mathrm{1}}{\mathrm{k}+1}$ ,則第 i 個元素不被替換的概率為 1 - $ \frac{\mathrm{1}}{\mathrm{k}+1}$ = $ \frac{\mathrm{k}}{\mathrm{k}+1}$ ,同理,第 k + 2 個元素被選中進行替換的概率為 $ \frac{\mathrm{k}}{\mathrm{k}+2}$ , 第 i 個元素被選中的概率為 $ \frac{\mathrm{1}}{\mathrm{k}}$ ,於是第 i 個元素被第 k + 2 個元素替換的概率為 $ \frac{\mathrm{k}}{\mathrm{k}+2}$ * $ \frac{\mathrm{1}}{\mathrm{k}}$ = $ \frac{\mathrm{1}}{\mathrm{k}+2}$ ,第 i 個元素不被替換的概率為 1 - $ \frac{\mathrm{1}}{\mathrm{k}+2}$ = $ \frac{\mathrm{k}+1}{\mathrm{k}+2} $。以此類推,執行到第 n 步時,被保留的概率 = 被選中的概率 * 不被替換的概率,即:

\[1\times \frac{\mathrm{k}}{\mathrm{k}+1}\times \frac{\mathrm{k}+1}{\mathrm{k}+2}\times \frac{\mathrm{k}+2}{\mathrm{k}+3}\times ...\times \frac{\mathrm{n}-1}{\mathrm{n}}=\frac{\mathrm{k}}{\mathrm{n}} \]

  • 對於第 j 個元素(j > k),該元素被選中的概率為 $ \frac{\mathrm{k}}{\mathrm{j}}$ ,第 j + 1 個元素被選中進行替換的概率為 $ \frac{\mathrm{k}}{\mathrm{j}+1}$,第 j 個元素被選中的概率為 $ \frac{\mathrm{1}}{\mathrm{k}}$,第 j 個元素被替換的概率為 $ \frac{\mathrm{k}}{\mathrm{j}+1}$ * $ \frac{\mathrm{1}}{\mathrm{k}}$ = $ \frac{\mathrm{1}}{\mathrm{j}+1}$,則第 j 個元素不被替換的概率為 1 - $ \frac{\mathrm{1}}{\mathrm{j}+1}$ = $ \frac{\mathrm{j}}{\mathrm{j}+1}$。則執行到第 n 步時,被保留的概率 = 被選中的概率 * 不被替換的概率,即:

\[\frac{\mathrm{k}}{\mathrm{j}}\times \frac{\mathrm{j}}{\mathrm{j}+1}\times \frac{\mathrm{j}+1}{\mathrm{j}+2}\times \frac{\mathrm{j}+2}{\mathrm{j}+3}\times ...\times \frac{\mathrm{n}-1}{\mathrm{n}}=\frac{\mathrm{k}}{\mathrm{n}} \]

所以對於其中每個元素,被保留的概率都為 $ \frac{\mathrm{k}}{\mathrm{n}}$

程式碼實現

public class ReservoirSampling {
    private int[] pool;  // 蓄水池,包含所有資料
    private int size;    // 蓄水池規格
    private Random random;

    public ReservoirSampling(int size) {
        this.size = size;
        random = new Random();
        // 初始化資料
        pool = new int[size];
        for (int i = 0; i < size; i++) {
            pool[i] = i;
        }
    }

    public int[] sampling(int K) {
        int[] result = new int[K];
        for (int i = 0; i < K; i++) { // 前 K 個元素直接放入陣列中
            result[i] = pool[i];
        }

        for (int i = K; i < size; i++) { // K + 1 個元素開始進行概率取樣
            int r = random.nextInt(i + 1); // 索引下標為 i 個資料時第 i + 1 個資料,r = [0,i]
            if (r < K) {                         //  選中概率為 k/i+1
                result[r] = pool[i];
            }
        }

        return result;
    }
}

測試

    public static void main(String[] args) {
        ReservoirSampling test = new ReservoirSampling(1000);
        int[] sampling = test.sampling(5);
        for (int i : sampling) {
            System.out.print(i + " ");
        }
    }
// 輸出 205 907 986 696 443,每次執行結果不同

題目

LeetCode 382. 連結串列隨機節點

LeetCode 382. 連結串列隨機節點

給你一個單連結串列,隨機選擇連結串列的一個節點,並返回相應的節點值。每個節點 被選中的概率一樣 。

實現 Solution 類:

Solution(ListNode head) 使用整數陣列初始化物件。
int getRandom() 從連結串列中隨機選擇一個節點並返回該節點的值。連結串列中所有節點被選中的概率相等。
 

示例:


輸入
["Solution", "getRandom", "getRandom", "getRandom", "getRandom", "getRandom"]
[[[1, 2, 3]], [], [], [], [], []]
輸出
[null, 1, 3, 2, 2, 3]

解釋
Solution solution = new Solution([1, 2, 3]);
solution.getRandom(); // 返回 1
solution.getRandom(); // 返回 3
solution.getRandom(); // 返回 2
solution.getRandom(); // 返回 2
solution.getRandom(); // 返回 3
// getRandom() 方法應隨機返回 1、2、3中的一個,每個元素被返回的概率相等。
 

提示:

連結串列中的節點數在範圍 [1, 104] 內
-104 <= Node.val <= 104
至多呼叫 getRandom 方法 104 次
 

進階:

如果連結串列非常大且長度未知,該怎麼處理?
你能否在不使用額外空間的情況下解決此問題?

解:典型的蓄水池演算法,當 k1 時的特殊情況,每次只取出一個元素。

class Solution {
    ListNode head;
    Random random = new Random();
    public Solution(ListNode _head) {
        this.head = _head;
    }
    
    // 另第 idx 個結點被選中的概率為 1/idx ,則該結點不被後面結點覆蓋的概率為 1/idx * 
    // (1 - 1/(idx+1)) * (1 - 1/(idx+2)) * ...* (1 - 1/n) = 1/n
    // 白話:對第 idx 個結點計算概率,random = [0, idx), 則random = 0 的概率為 1/idx
    // 只要第 idx 個結點的 random 為 0 則選中覆蓋原答案,直到選到最後一個結點。
    public int getRandom() {
        int idx = 1;
        ListNode node = head;
        int ans = node.val;
        while(node != null) {
            if(random.nextInt(idx) == 0) ans = node.val;
            node = node.next;
            idx++;
        }
        return ans;
    }
}

LeetCode 398. 隨機數索引

LeetCode 398. 隨機數索引

給定一個可能含有重複元素的整數陣列,要求隨機輸出給定的數字的索引。 您可以假設給定的數字一定存在於陣列中。

注意:
陣列大小可能非常大。 使用太多額外空間的解決方案將不會通過測試。

示例:

int[] nums = new int[] {1,2,3,3,3};
Solution solution = new Solution(nums);

// pick(3) 應該返回索引 2,3 或者 4。每個索引的返回概率應該相等。
solution.pick(3);

// pick(1) 應該返回 0。因為只有nums[0]等於1。
solution.pick(1);

解:只需要考慮給定數字即可,對遍歷到的給定數字進行編號(1,2,...),再按照蓄水池演算法隨機取出一個即可

class Solution {
    private int[] nums;

    public Solution(int[] nums) {
        this.nums = nums;
    }
    
    public int pick(int target) {
        int ans = 0;
        int idx = 0;
        Random random = new Random();
        for(int i = 0; i < nums.length; i++) {
            if(nums[i] == target) {
                idx++;
                if(random.nextInt(idx) == 0) ans = i;
            }
        }
        return ans;
    }   
}

參考資料

挺有意思的一個視訊

相關文章