問題描述:
在未排序的陣列中找到第 k 個最大的元素。請注意,你需要找的是陣列排序後的第 k 個最大的元素,而不是第 k 個不同的元素。
面試中常考的問題之一,同時這道題由於解法眾多,也是考察時間複雜度計算的一個不錯的問題。
1,選擇排序
利用選擇排序,將陣列中最大的元素放置在陣列的最前端,然後第k次選擇的最大元素就是第K大個元素,直接根據索引返回結果即可。
public class Select { public static void main(String[] args) { int[] arr = new int[]{5,3,2,1,4,7,8,10,6,9}; System.out.println(findKthLargest(arr, 3)); } private static int findKthLargest(int[] arr, int k){ if(k <= 0 || k > arr.length) throw new IllegalArgumentException("k error"); for(int i = 0; i < k; ++i){ int maxNum = Integer.MIN_VALUE; int maxIndex = -1; for(int j = i; j < arr.length; ++j){ if(arr[j] > maxNum){ maxNum = arr[j]; maxIndex = j; } } swap(arr, maxIndex, i); } System.out.println(Arrays.toString(arr)); return arr[k-1]; } private static void swap(int[] arr, int i, int j){ int temp = arr[i]; arr[i] = arr[j]; arr[j] = temp; } }
結果:
[10, 9, 8, 1, 4, 7, 2, 5, 6, 3]
8
我們可以看到陣列經過選擇排序後,前三個元素分別是三趟選擇中最大的元素,直接返回k-1索引位置的元素,即是第K大的元素。
時間複雜度O(n*K),經過K次選擇,每次選擇都要遍歷n個元素。
2,排序優化
上一個方法的本質實際上是將整個陣列進行一個排序,然後根據索引位置得到答案,基於這個情況我們可以使用一些更快速的排序方法,例如選擇排序或歸併排序,以達到平局時間複雜度為O(nlogn)
public class Sort { public static void main(String[] args) { int[] arr = new int[]{5,3,2,1,4,7,8,10,6,9}; System.out.println(findKthLargest(arr, 2)); } private static int findKthLargest(int[] arr, int k){ if(k <= 0 || k > arr.length) throw new IllegalArgumentException("k error"); Arrays.sort(arr); System.out.println(Arrays.toString(arr)); return arr[arr.length-k]; } }
結果:
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
9
時間複雜度O(nlogn),最壞時間複雜度根據不同的排序方法而不一樣,快排的話就是O(n^2),歸併排序是O(nlogn)。
3,堆(優先佇列)
思路是建立一個最小堆,將所有陣列中的元素加入堆中,並保持堆的大小小於等於 k
。這樣,堆中就保留了前 k
個最大的元素。這樣,堆頂的元素就是正確答案。
public class Heap { public static void main(String[] args) { int[] arr = new int[]{5,3,2,1,4,7,8,10,6,9}; System.out.println(findKthLargest(arr, 3)); } private static int findKthLargest(int[] arr, int k){ if(k <= 0 || k > arr.length) throw new IllegalArgumentException("k error"); PriorityQueue<Integer> queue = new PriorityQueue<>((a,b)->{ return a-b; }); for(int num:arr){ queue.offer(num); if(queue.size() > k) queue.poll(); } return queue.peek(); } }
時間複雜度是O(nlogk),向大小為 k 的堆中新增或刪除元素的時間複雜度為O(logk),遍歷n個元素,故總時間複雜度為 O(nlogk)
4,快速選擇
基於快排的思想,選出一個基準元素,將陣列劃分成兩部分,左側的元素都比基準元素大,右側的都比基準元素小,如果基準元素的索引恰好等於k-1,也就是說這個基準元素就是第k大的元素,否則根據基準元素的位置再去左邊或者右邊去選擇。
import java.util.PriorityQueue; import java.util.Random; public class QuickSelect { public static void main(String[] args) { int[] arr = new int[]{5,3,2,1,4,7,8,10,6,9}; System.out.println(findKthLargest(arr, 10)); } private static int findKthLargest(int[] arr, int k){ if(k <= 0 || k > arr.length) throw new IllegalArgumentException("k error"); return quickSelect(arr, 0, arr.length-1, k); } private static int quickSelect(int[] arr, int left, int right, int k){ if(left == right) return arr[left]; Random random_num = new Random(); int pivotIndex = left + random_num.nextInt(right - left); pivotIndex = partition(arr, left, right, pivotIndex); if(pivotIndex == k-1){ return arr[pivotIndex]; }else if(pivotIndex < k-1){ return quickSelect(arr, pivotIndex+1, right, k); }else{ return quickSelect(arr, left, pivotIndex-1, k); } } private static int partition(int[] arr, int left, int right, int pivotIndex){ int pivot = arr[pivotIndex]; swap(arr, pivotIndex, right); int l = left, r = right; while(l < r){ while(l < r && arr[l] >= pivot) l++; if(arr[l] < pivot) swap(arr, l, r); while(l < r && arr[r] <= pivot) r--; if(arr[r] > pivot) swap(arr, l, r); } return l; } private static void swap(int[] arr, int i, int j){ int temp = arr[i]; arr[i] = arr[j]; arr[j] = temp; } }
這裡我們選擇一個陣列中的隨機值作為基準值,如果每次恰好都劃分一半的元素的話,則T(n) = n + n/2 + n/4 + n/8 + n/16 + ... = 2n,也就是O(n)的時間複雜度。
但如果每一次選擇的元素恰好是最小值的話,時間複雜度則退化到了O(n^2)
但是平均時間複雜度是O(n),演算法導論上有嚴格的證明。
5,BFPRT
在BFPRT演算法中,僅僅是改變了快速排序Partion中的pivot值的選取,在快速排序中,我們始終選擇第一個元素或者最後一個元素作為pivot,而在BFPTR演算法中,每次選擇五分中位數的中位數作為pivot,這樣做的目的就是使得劃分比較合理,從而避免最壞情況的發生。演算法步驟如下:
- 將輸入陣列的n個元素劃分為n/5組,每組5個元素,且至多隻有一個組由剩下的n%5個元素組成。
- 尋找n/5個組中每一個組的中位數,首先對每組的元素進行插入排序,然後從排序過的序列中選出中位數。
- 對於2中找出的n/5箇中位數,遞迴進行步驟1和2,直到只剩下一個數即為這n/5個元素的中位數,找到中位數後並找到對應的下標p。
- 進行Partion劃分過程,Partion劃分中的pivot元素下標為p。
- 進行高低區判斷即可
本演算法的最壞時間複雜度為O(n),值得注意的是通過BFPTR演算法將陣列按第K小(大)的元素劃分為兩部分,而這高低兩部分不一定是有序的,通常我們也不需要求出順序,而只需要求出前K大的或者前K小的。
public class BFPRT { public static void main(String[] args) { int[] arr = new int[]{3,2,3,1,2,4,5,5,6}; System.out.println(findKthLargest(arr, 4)); } private static int findKthLargest(int[] arr, int k){ if(k <= 0 || k > arr.length) throw new IllegalArgumentException("k error"); return quickSelect(arr, 0, arr.length-1, k); } private static int findMedian(int[] arr, int l, int r){ int i = l, index = 0; for(; i + 4 <= r; i += 5, index++){ sort(arr, i, i + 4); swap(arr, l + index, i + 2); } if(i <= r){ sort(arr, i, r); swap(arr, l+index, i + (r-i+1) / 2); //如果是最後陣列元素是偶數選擇較小的一個 index++; } if(index == 1) return l; else return findMedian(arr, l, l+index-1); } private static int quickSelect(int[] arr, int left, int right, int k){ if(left == right) return arr[left]; // Random random = new Random(); // int pivotIndex = left + random.nextInt(right - left); int pivotIndex = findMedian(arr, left, right); pivotIndex = partition(arr, left, right, pivotIndex); if(pivotIndex == k-1){ return arr[pivotIndex]; }else if(pivotIndex < k-1){ return quickSelect(arr, pivotIndex+1, right, k); }else{ return quickSelect(arr, left, pivotIndex-1, k); } } private static int partition(int[] arr, int left, int right, int pivotIndex){ int pivot = arr[pivotIndex]; swap(arr, pivotIndex, right); int l = left, r = right; while(l < r){ while(l < r && arr[l] >= pivot) l++; if(arr[l] < pivot) swap(arr, l, r); while(l < r && arr[r] <= pivot) r--; if(arr[r] > pivot) swap(arr, l, r); } return l; } private static void swap(int[] arr, int i, int j){ int temp = arr[i]; arr[i] = arr[j]; arr[j] = temp; } public static void sort(int[] arr, int l, int r){ for(int i = l; i <= r; i++){ for(int j = i+1; j <= r; j++){ if(arr[j] < arr[i]) swap(arr, i, j); } } } }