尋找陣列中第K大的元素

zhong0316發表於2019-03-01

問題

Find the kth largest element in an unsorted array. Note that it is the kth largest element in the sorted order, not the kth distinct element.

Example 1:

Input: [3,2,1,5,6,4] and k = 2
Output: 5
Example 2:

Input: [3,2,3,1,2,4,5,5,6] and k = 4
Output: 4
Note: 
You may assume k is always valid, 1 ≤ k ≤ array's length.
複製程式碼

tag: Medium

分析

這題最簡單的做法是將陣列排序,然後直接返回第K大的元素。複雜度為:O(NlogN)。但是,很明顯,出題者並不想讓我們這麼做。

如果對陣列排序,演算法的複雜度起碼是 O(NlogN)。那麼如果我們不排序,能不能求出第K大元素呢?答案是可以的,我們知道快速排序中有一個步驟是 partition。它選擇一個元素作為樞紐(pivot),將所有小於樞紐的元素放到樞紐的左邊,將所有大於樞紐的元素放到樞紐的右邊。然後返回樞紐在陣列中的位置。那麼,關鍵就在這裡了。如果此時返回的樞紐元素在陣列中的位置剛好是我們所要求的位置,問題就能得到解決了。

我們先選取一個樞紐元素,從陣列的第一個元素開始到最後一個元素結束。將大於樞紐的左邊元素和小於樞紐的右邊元素交換。然後判斷當前樞紐元素的 index 是否為 n - k,如果是則直接返回這個樞紐元素。如果 index 大於 n - k,則我們從樞紐的左邊繼續遞迴尋找,如果 index 小於 n - k,則我們從樞紐的右邊繼續遞迴尋找。

程式碼

public class KthLargestElementInArray {

    public int findKthLargest(int[] nums, int k) {
        return findKth(nums, nums.length - k /* 第k大,也就是排序後陣列中 index 為 n - k 的元素*/, 0, nums.length - 1);
    }

    public int findKth(int[] nums, int k, int low, int high) {
        if (low >= high) {
            return nums[low];
        }
        int pivot = partition(nums, low, high); // 樞紐元素的 index
        if (pivot == k) {
            return nums[pivot];
        } else if (pivot < k) { // 樞紐元素的 index 小於 k,繼續從樞紐的右邊部分找
            return findKth(nums, k, pivot + 1, high);
        } else { // 樞紐元素的 index 大於 k,繼續從樞紐的左邊部分找
            return findKth(nums, k, low, pivot - 1);
        }
    }

    int partition(int[] nums, int low, int high) {
        int i = low;
        int j = high + 1;
        int pivot = nums[low];
        while (true) {
            while (less(nums[++i], pivot)) {
                if (i == high) {
                    break;
                }
            }
            while (less(pivot, nums[--j])) {
                if (j == low) {
                    break;
                }
            }
            if (i >= j) {
                break;
            }
            swap(nums, i, j);
        }
        swap(nums, low, j);
        return j;
    }

    boolean less(int i, int j) {
        return i < j;
    }

    void swap(int[] nums, int i, int j) {
        if (i == j) {
            return;
        }
        int temp = nums[i];
        nums[i] = nums[j];
        nums[j] = temp;
    }
}
複製程式碼

複雜度分析

複雜度為:O(N) + O(N / 2) + O(N / 4) + ... 最終演算法複雜度為 O(N)

相關文章