TimSort原始碼詳解

he_jia發表於2020-12-11

Python的排序演算法由Peter Tim提出,因此稱為TimSort。它最先被使用於Python語言,後被多種語言作為預設的排序演算法。TimSort實際上可以看作是mergeSort+binarySort,它主要是針對歸併排序做了一系列優化。如果想看Python的TimSort原始碼,在Cpython的Github倉庫能找到,這裡面還包含一個List物件的PyList_Sort函式。這篇文章為了方便借用JAVA對TimSort的實現原始碼來說明其原理。

一.binarySort函式

TimSort非常適合大量資料的排序,對於少量資料的排序,TimSort選擇使用binarySort來實現,因此我想先介紹一下binarySort的過程。

我們知道插入排序的思路是通過交換元素位置的方式依次插入元素(如果不太瞭解插入排序可以先去熟悉一下),當要插入元素時,從已排序的部分的最後一位開始,依次比較其與待插入的元素的值,這樣來找到待插入元素的位置。顯然,在插入排序的過程中,始終是有一個在增長的有序部分和在縮短的無序部分。排序過程見下圖(圖源自部落格):

 

 但是插入排序有個很明顯的問題,在找當前元素的位置時它是一步一步地在有序部分往前推進的,而有序列表的插入可以通過二分法來減少比較次數,這和二分查詢的目的不同但是思路相同(可以自己嘗試一下實現它),我們稱其為二分插入,通過二分插入實現的排序就是二分排序(binarySort)。我們可以看一下它的Java原始碼:

//a是陣列,lo是待排序部分(有序部分+無序部分)的最低位(包含),hi是最高位(不包含),start是無序部分的最低位,c是比較函式即排序的依據
private static <T> void binarySort(T[] a, int lo, int hi, int start, Comparator<? super T> c) {
    assert lo <= start && start <= hi;
    if (start == lo)
        start++;
    for ( ; start < hi; start++) {//接下來就是二分插入的過程
        T pivot = a[start];
        int left = lo;
        int right = start;
        assert left <= right;
        while (left < right) {
            int mid = (left + right) >>> 1;
            if (c.compare(pivot, a[mid]) < 0)
                right = mid;
            else
                left = mid + 1;
        }
        assert left == right;
        int n = start - left;//n表示要移動的元素數量
        //優化插入過程,當要移動的元素數量為1或2時,可以直接交換元素位置;
        //否則將left後的元素往後挪一位再插入,方式是通過arraycopy函式複製
        switch (n) {
            case 2:  a[left + 2] = a[left + 1];
            case 1:  a[left + 1] = a[left];
                break;
            default: System.arraycopy(a, left, a, left + 1, n);
        }
        a[left] = pivot;
    }
}

二.run

這是TimSort中最重要的一個概念,實在找不到合適的翻譯(無奈臉)。run實際上就是一個連續上升(包含相等)或者下降(不包含相等)的子串。比如對於陣列[1,3,2,4,6,4,7,7,3,2],其中有四個run,第一個是[1,3],第二個是[2,4,6],第三個是[4,7,7],第四個是[3,2],在函式中對於單調遞減的run會被反轉成遞增的序列。原始碼中通過countRunAndMakeAscending()函式來得到run:

private static <T> int countRunAndMakeAscending(T[] a, int lo, int hi, Comparator<? super T> c) {
    assert lo < hi;
    int runHi = lo + 1;
    if (runHi == hi)
        return 1;

    //找到run的結束位置,如果是下降的序列將其反轉
    if (c.compare(a[runHi++], a[lo]) < 0) {
        while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) < 0)
            runHi++;
        reverseRange(a, lo, runHi);
    } else {
        while (runHi < hi && c.compare(a[runHi], a[runHi - 1]) >= 0)
            runHi++;
    }

    return runHi - lo;//返回值為run的長度
}

 三.TimSort排序過程

直接上原始碼分析,可以參考程式碼註釋和下面的解釋來閱讀:

static <T> void sort(T[] a, int lo, int hi, Comparator<? super T> c, T[] work, int workBase, int workLen) {
    assert c != null && a != null && lo >= 0 && lo <= hi && hi <= a.length;

    int nRemaining  = hi - lo;//待排序的陣列長度
    if (nRemaining < 2)
        return;  //長度為0或1的陣列無需排序

    // 如果陣列長度小於32(即MIN_MERGE,TimSort的Python版本里這個值為64),直接用binarySort排序
    if (nRemaining < MIN_MERGE) {
        int initRunLen = countRunAndMakeAscending(a, lo, hi, c);//找到第一個run,返回其長度
        binarySort(a, lo, hi, lo + initRunLen, c);//第一個run已排好序,因此binarySort的引數start=lo+initRunLen
        return;
    }

    TimSort<T> ts = new TimSort<>(a, c, work, workBase, workLen);
    int minRun = minRunLength(nRemaining);//最小run長度,見解釋A
    do {
        // 找run
        int runLen = countRunAndMakeAscending(a, lo, hi, c);

        // 如果run長度小於minRun,將其擴充套件為min(nRemaining,minRun)
        if (runLen < minRun) {
            int force = nRemaining <= minRun ? nRemaining : minRun;
            binarySort(a, lo, lo + force, lo + runLen, c);//擴充套件run到長度force
            runLen = force;
        }

        ts.pushRun(lo, runLen);// 將run儲存到棧中,見解釋B
        ts.mergeCollapse();// 根據規則合併相鄰的run,見解釋C

        // 繼續尋找run
        lo += runLen;
        nRemaining -= runLen;
    } while (nRemaining != 0);

    // Merge all remaining runs to complete sort
    assert lo == hi;
    ts.mergeForceCollapse();//最後收尾,將棧中所有run從棧頂開始依次鄰近合併,得到一個run
    assert ts.stackSize == 1;
}

 

解釋A:在執行排序演算法之前,會計算minRun的值,minRun會從[16,32]區間中選擇一個數字,使得陣列的長度除以minRun等於或者略小於2的冪次方。比如長度是65,那麼minrun的值就是17;如果長度是174minrun就是22。minRunLength()函式程式碼如下:

private static int minRunLength(int n) {
    assert n >= 0;
    int r = 0;      // 如果n的低位有任何一位為1,r就會置1
    while (n >= 32) {
        r |= (n & 1);
        n >>= 1;
    }
    return n + r;
}

解釋B:存run是通過兩個棧,分別儲存run的起始位置和長度,可以看pushRun()函式程式碼:

private int stackSize = 0;  // 棧中run的數量
private final int[] runBase;
private final int[] runLen;

private void pushRun(int runBase, int runLen) {
     this.runBase[stackSize] = runBase;
     this.runLen[stackSize] = runLen;
     stackSize++;
}

解釋C:這裡的合併規則如下:假設棧頂三個run依次為X,Y,Z,X為棧頂run,要求它們的長度滿足X+Y<Z及X<Y兩個條件。其實這就是TimSort演算法的精髓所在了,它通過這樣的方式盡力保證合併的平衡性,即讓待合併的兩個陣列儘可能長度接近,從而提高合併的效率。通過這兩個條件限制,保證了棧中的run從棧底到棧頂是從大到小排列的,並且合併的收斂速度與斐波那契數列一樣。可以看mergeCollapse()函式程式碼:

private void mergeCollapse() {
    while (stackSize > 1) {
        int n = stackSize - 2;
        if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {//條件一不滿足的話,Y就會和X、Z中較小的run合併
            if (runLen[n - 1] < runLen[n + 1])
                n--;
            mergeAt(n);
        } else if (runLen[n] <= runLen[n + 1]) {//條件二不滿足的話,Y就和X合併
            mergeAt(n);
        } else {
            break; // Invariant is established
        }
    }
}

四.合併的方式

到這裡我們就把整個流程講完了,還有最後一個問題沒有講--如何合併run?合併兩個run需要額外空間(可以不用,但是效率太低),額外空間大小我們可以設為較小的run的長度。假設我們有前後X、Y兩個run需要合併,X較小,那麼X可以放入臨時記憶體中,然後從小到大合併;如果Y較小,那麼把Y放入臨時記憶體,然後從大到小排序。這個流程其實也比較簡單(圖源自佛西先森部落格):

 

 

 並且,由於兩個run都是已經排好序的序列,我們可以在run合併之前計算A中最後一個元素在B中的位置i,那麼B中i之後的元素都不需要參與合併;同理,我們也可以計算B中第一個元素在A中位置j,A中j之前的元素都不需要參與合併。

在歸併排序演算法中合併兩個陣列就是一一比較每個元素,把較小的放到相應的位置,然後比較下一個,這樣有一個缺點就是如果A中如果有大量的元素A[i...j]是小於B中某一個元素B[k]的,程式仍然會持續的比較A[i...j]中的每一個元素和B[k],增加合併過程中的時間消耗。

為了優化合並的過程,TimSort設定了一個閾值MIN_GALLOP,如果A中連續MIN_GALLOP個元素比B中某一個元素要小,則通過二分搜尋找到A[0]B中的位置i0,把Bi0之前的元素直接放入合併的空間中,然後再在A中找到B[i0]所在的位置j0,把Aj0之前的元素直接放入合併空間中,如此迴圈直至在AB中每次找到的新的位置和原位置的差值是小於MIN_GALLOP的,這才停止然後繼續進行一對一的比較。

五.總結

總結一下上面的排序的過程:

  1. 如果長度小於32直接進行二分插入排序
  2. 遍歷陣列組成一個run
  3. 得到一個run之後會把他放入棧中
  4. 如果棧頂部幾個的run符合合併條件,就會合並相鄰的兩個run
  5. 合併會使用盡量小的記憶體空間和GALLOP模式來加速合併

參考資料:1.世界上最快的排序演算法——Timsort

      2.JDK8官方原始碼

相關文章