sort 包原始碼分析

broqiang發表於2019-06-26

最近自定義 sort 排序的時候, 發現了個問題, 就是根據 struct 的一個欄位排序的時候, 排序完了
相對位置被改變了, 最後只能把多個欄位都比較下, 才能保證位置。 就叫我產生了檢視下原始碼的衝動,
看看到底是怎麼回事。

這只是個人的一些觀點, 如果有不對的地方歡迎指出
issues 或直接點選修改原文。

看完一遍程式碼, 感覺思想很重要,這個排序快把我熟悉或知道的排序演算法全都用了一遍,
個人建議在看這個原始碼前要了解: 插入排序、希爾排序、堆排序、快速排序、選擇排序,
因為在原始碼中都使用了, 否則看起來會有點困難。

開始

找到原始碼中的 sort 包, 檢視下面包含了這些檔案:

example_interface_test.go
example_keys_test.go
example_multi_test.go
example_search_test.go
example_test.go
example_wrapper_test.go
export_test.go
genzfunc.go
search.go
search_test.go
slice.go
sort.go
sort_test.go
zfuncversion.go

其中 example_* 的都是示例, 也都很必要全都看一遍, 可以更容易理解這個包怎麼去用。

主要程式碼是再 sort.go 這個檔案中, 在頂部定義了一個介面, 只要實現了這個介面(包括自定義的),
就可以使用 sort 包中的排序。

type Interface interface {
    // Len is the number of elements in the collection.
    // 集合中元素的個數, 這裡一般會很簡單, 只要返回個長度就可以了
    Len() int
    // Less reports whether the element with
    // index i should sort before the element with index j.
    // Less 函式判斷下標 i 的元素是否應該放在下標 j 的前面
    Less(i, j int) bool
    // Swap swaps the elements with indexes i and j.
    // 交換下標 i j 對應的元素
    Swap(i, j int)
}

Sort 函式

排序演算法穩定性 參考這裡

它的入口有兩個函式, func Sort(data Interface)func Stable(data Interface) ,
Sort 是不穩定排序, Stable 是穩定排序, 不過從它的演算法使用上來看, Sort 的速度會比 Stable
要快, 一般優先使用 Sort 。

Sort 和 Stable 傳入的都是一個 Interface 型別的引數, 所以我們可以將準備排序資料集實現
Interface 就可以使用這兩個方法來排序。先看下 Sort , 後面再說 Stable 。

func Sort(data Interface) {
    n := data.Len() // 實現的 Interface 介面的函式 Len 返回的長度
    // 呼叫快速排序函式, 但是它不僅僅用了快排的演算法, 後面會說
    quickSort(data, 0, n, maxDepth(n))
}

上面函式中的 maxDepth 是快排遞迴的最大深度,返回的值為 2*ceil(lg(n+1)),
他是一個快速排序切換堆排序的閥值, 在檢視 quickSort 函式時候再說明。

// 這個函式是根據資料集的長度來計算遞迴的深度
func maxDepth(n int) int {
    var depth int
    // 這裡每次迴圈 i 右移一位, 相當於 i 每次迴圈都除以 2 , i /= 2
    // 個人覺得使用位運算效能會高一些吧。
    for i := n; i > 0; i >>= 1 {
        depth++
    }
    return depth * 2
}

這裡有個小疑惑, 不明白為什麼要乘以 2, 有人明白可以告訴我下

quickSort

這是一個重量級的函式, Sort 函式中的主要實現都在這個函式中。

下面的分析中假設我們是從小到大排序的, 這樣有些地方描述起來容易一些。

func quickSort(data Interface, a, b, maxDepth int) {
    // 這個函式第一次進入的時候 a = 0, b = 資料集長度, 外面的 n
    // 暫時還不知道 a 是什麼, 繼續往下看
    // 看下面註釋的說明, 如果資料集的長度大於 12, 就會進入這個迴圈,
    // 否則,就使用希爾排序, 先看下進入之後做了什麼
    for b-a > 12 { // Use ShellSort for slices <= 12 elements
        // 如果遞迴到了最大深度, 就使用堆排序
        if maxDepth == 0 {
            // 呼叫堆排序函式, 一會再看這個函式
            heapSort(data, a, b)
            return
        }
        // 迴圈一次, 最大深度 -1, 相當於又深入(遞迴)了一層
        maxDepth--
        // 這個是求中位數的函式, 看到這裡大概就明白了 a 和 b 是什麼了
        // a 是資料集的左邊, b 是資料集的右邊, 熟悉快排的應該就明白了
        // 這個就是通過資料集, 左右邊, 來求中位數, 這裡返回了兩個變數,
        // 只能暫停一下, 先看下 doPivot 的實現了
        // 看完 doPivot 再回到這裡參考一下, 就可以知道了:
        // doPivot 它取一點為軸,把不大於中位數的元素放左邊,大於軸的元素放右邊,
        // 返回小於中位數部分資料的最後一個下標,以及大於軸部分資料的第一個下標。
        mlo, mhi := doPivot(data, a, b)
        // Avoiding recursion on the larger subproblem guarantees
        // a stack depth of at most lg(b-a).
        // 因為迴圈肯定比遞迴呼叫節省時間,但是兩個子問題只能一個進行迴圈,另一個只能用遞迴。
        // 這裡是把較小規模的子問題進行遞迴,較大規模子問題進行迴圈。
        if mlo-a < b-mhi {
            quickSort(data, a, mlo, maxDepth)
            a = mhi // i.e., quickSort(data, mhi, b)
        } else {
            quickSort(data, mhi, b, maxDepth)
            b = mlo // i.e., quickSort(data, a, mlo)
        }
    }

    // 如果元素的個數小於 12 個(無論是遞迴的還是首次進入), 就先使用希爾排序
    // 然後再呼叫插入排序。
    if b-a > 1 {
        // Do ShellSort pass with gap 6
        // It could be written in this simplified form cause b-a <= 12
        for i := a + 6; i < b; i++ {
            if data.Less(i, i-6) {
                data.Swap(i, i-6)
            }
        }
        insertionSort(data, a, b)
    }
}

下面再去看看上面沒有說的堆排序和最後出現的插入排序。

doPivot 函式

這個函式看起來還有點長, 一點一點看吧, 快排的難點就在這裡了, 看長度看這裡的實現也比較複雜,
如果快速排序的演算法熟悉的話, 這個函式可以很容易看明白, 如果看不明白就先去看下快速排序演算法。

func doPivot(data Interface, lo, hi int) (midlo, midhi int) {
    // 這裡的 lo 和 hi 就是傳入的 a 和 b, 代表左右邊

    // 這裡應該是取左右邊的中間點, 下面註釋說這樣寫是為了避免整數溢位
    // 不過這個寫法挺巧妙的, 又學了一招
    m := int(uint(lo+hi) >> 1) // Written like this to avoid integer overflow.
    if hi-lo > 40 {
        // Tukey's ``Ninther,'' median of three medians of three.
        // 這裡求中位數使用的是 Tukey's ninther , 下面有專門解釋的連結
        // 看完 Tukey's ninther 就應該知道這裡做的是什麼了吧,
        // 當兩邊間的元素超過 40 個的時候, 通過 9/3/3 來找到中位數

        // s 的位置是右邊 - 左邊 處理 8 這個位置, 暫時沒明白為什麼是 8,
        // 有人明白也可以告訴下我
        s := (hi - lo) / 8

        // medianOfThree 這個函式比較簡單, 根據傳入的位置, 進行比較,
        // 然後按照位置將元素交換, 如果糾結的話可以先到下面看這個函式的分析
        // 這一次執行完, lo 已經是三個數的中位數
        medianOfThree(data, lo, lo+s, lo+2*s)
        // 這裡執行完, m 已經是這三個數的中位數
        medianOfThree(data, m, m-s, m+s)
        // 這裡執行完, hi-1 已經是三個數的中位數(因為 hi 是傳入的 n,資料集的長度,
        // 下標從 0 開始, 所以要 hi - 1
        medianOfThree(data, hi-1, hi-1-s, hi-1-2*s)
    }

    // 將三次中位數的結果再次求中位數, 相當於是用 3*3 個數來確定中位數
    medianOfThree(data, lo, m, hi-1)

    // Invariants are:
    //  data[lo] = pivot (set up by ChoosePivot)
    //  data[lo < i < a] < pivot
    //  data[a <= i < b] <= pivot
    //  data[b <= i < c] unexamined
    //  data[c <= i < hi-1] > pivot
    //  data[hi-1] >= pivot
    pivot := lo // 中位數定義為 lo
    a, c := lo+1, hi-1 // 將定義好的左邊右移一位, 右邊左移一位

    // 說實話, 這裡用 a 和 c 這種變數, 還要向上去看程式碼才能知道是什麼, 有點囧
    // 將左邊和中位數進行比較, 一直到不滿足條件為止
    for ; a < c && data.Less(a, pivot); a++ {
    }

    // 此時將 a 的位置賦值給 b(又是這樣的變數……)
    b := a
    for {
        // 感覺這個和上面的 a<c 的迴圈做的是一個事,取反比較, 做了這一步應該是
        // 更嚴謹一些吧, 沒有想到什麼情況下能進入到這個迴圈
        for ; b < c && !data.Less(pivot, b); b++ { // data[b] <= pivot
        }
        // 用右邊和中間數做比較, 不滿足 Less 的時候停止
        for ; b < c && data.Less(pivot, c-1); c-- { // data[c-1] > pivot
        }

        // 比較小, 如果左邊和右邊重合或者已經再右邊的右側,就證明中間數左側的資料
        // 全都是比右側的小, 結束迴圈, 完成關於這個中位數的排序
        if b >= c {
            break
        }

        // 如果左側的資料大於右側, 就將資料交換, 完成排序, 再各自移動一位進行下一輪比較
        // data[b] > pivot; data[c-1] <= pivot
        data.Swap(b, c-1)
        b++
        c--
    }
    // 這裡它說如果傳入進來的右邊 - 處理完了的右邊界小於 3, 會出現重複, 保守一點,將比較
    // 的邊界設定為 5 (暫時沒明白什麼意思, 繼續往下看吧)
    // If hi-c<3 then there are duplicates (by property of median of nine).
    // Let's be a bit more conservative, and set border to 5.
    protect := hi-c < 5
    // protect 取反了, 就是這個值是大於 5 的
    // 並且傳入的右邊界 - 當前的右邊 < 全部元素 / 4 (又沒明白在做什麼)
    if !protect && hi-c < (hi-lo)/4 {
        // Lets test some points for equality to pivot
        // 用一些特殊的點和中間數進行比較
        dups := 0
        // 用中位數和右邊界的值比較, 如果中位數比右邊界大, 就交換
        if !data.Less(pivot, hi-1) { // data[hi-1] = pivot
            data.Swap(c, hi-1)
            // 當前右邊界向右移動一位(猜測是因為交換過來的沒有進行過比較)
            c++
            dups++
        }
        // 如果移動後的左邊界比中間數大(此時中間數有可能已經是上面交換完了的)
        // 就將當前的左邊界向左移動一位
        if !data.Less(b-1, pivot) { // data[b-1] = pivot
            b--
            dups++
        }
        // m-lo = (hi-lo)/2 > 6
        // b-lo > (hi-lo)*3/4-1 > 8
        // ==> m < b ==> data[m] <= pivot
        // 用整個集合的中間數和求出的中間數進行比較, 如果比它大, 就交換,
        // 並且將當前的左邊界再向左移動一位
        if !data.Less(m, pivot) { // data[m] = pivot
            data.Swap(m, b-1)
            b--
            dups++
        }
        // if at least 2 points are equal to pivot, assume skewed distribution
        // 如果上面的 if 進入了兩次, 就證明現在是偏態分佈(也就是左右不平衡的)
        protect = dups > 1
    }
    // 如果現在是不平衡的, 再次處理,將資料集平衡
    if protect {
        // Protect against a lot of duplicates
        // Add invariant:
        //  data[a <= i < b] unexamined
        //  data[b <= i < c] = pivot
        for {
            for ; a < b && !data.Less(b-1, pivot); b-- { // data[b] == pivot
            }
            for ; a < b && data.Less(a, pivot); a++ { // data[a] < pivot
            }
            if a >= b {
                break
            }
            // data[a] == pivot; data[b-1] < pivot
            data.Swap(a, b-1)
            a++
            b--
        }
    }
    // Swap pivot into middle
    data.Swap(pivot, b-1)
    // 最後返回的是處理完的左邊界和右邊界移動後的位置, 相當於都是中間數吧?
    // 因為它都是向中間移動的(這個暫時也是猜測的)
    // 這個函式大體看明白了, 就是計算中位數,然後各種移動,完成排序,
    // 不過還是有好多地方暈暈的, 如果下面的 b 和 c 換成有意義的變數,
    // 可能我就能確定它是什麼了, 不是隻能猜了。
    return b - 1, c
}

上面求中位數使用的是 Tukey's ninther 演算法, 也叫 median of medians ,後面是連結,
感興趣可以直接去看下:
Tukey's ninther 演算法
中文翻譯

medianOfThree 函式

func medianOfThree(data Interface, m1, m0, m2 int) {
    // 通過呼叫函式, 我們可以清楚, 這裡面的 m1 m0 m2 分別對應的是資料集的索引(位置)
    // 這個函式比較簡單, 就是將這個三個位置對應的值通過 Less 函式進行比較, 然後排序
    // Less 函式是 Interface 介面對應的方法的實現

    // sort 3 elements
    if data.Less(m1, m0) {
        data.Swap(m1, m0)
    }
    // data[m0] <= data[m1]
    if data.Less(m2, m1) {
        data.Swap(m2, m1)
        // data[m0] <= data[m2] && data[m1] < data[m2]
        if data.Less(m1, m0) {
            data.Swap(m1, m0)
        }
    }
    // 最終處理完的結果是, 第一個傳入的數字在中間位置
    // now data[m0] <= data[m1] <= data[m2]
}

插入排序

先來看下這個程式碼的, 這個程式碼比較簡單, 這貌似也沒什麼分析的, 就是一個最基礎的插入排序。

func insertionSort(data Interface, a, b int) {
    for i := a + 1; i < b; i++ {
        for j := i; j > a && data.Less(j, j-1); j-- {
            data.Swap(j, j-1)
        }
    }
}

堆排序

它主要用了兩個函式 heapSortsiftDown, 又都不太長, 就把它們兩個都放進來了,
先看下下面的 heapSort 。 程式碼註釋中說建立一個最大堆, 這裡就不糾結是最大還是最小,
因為最終的判斷條件還是依賴 Less 函式的比較結果, 分析的時候就按照最大堆來描述。

// siftDown implements the heap property on data[lo, hi).
// first is an offset into the array where the root of the heap lies.
// 這個函式是用來建堆, first 和堆排序本身沒有關係, 因為這裡的陣列不一定是從 0 開始的,
// 所以需要有 first 來做偏移量, 比如 data[1], 就要是 data[first+1]
func siftDown(data Interface, lo, hi, first int) {
    // 這裡看 lo 是堆的根節點
    root := lo
    for {
        // 左子節點的下標(因為最大堆是一個完全二叉樹,所以可以確定出陣列對應的下標)
        child := 2*root + 1
        if child >= hi {
            break // 如果左子節點的下標超出陣列邊界, 就停止
        }

        // child+1 是右子節點
        // 如果右子節點沒有越界, 找出左右子節點中大的那一個
        // 這個稍微有點繞, 如果 child(左), 比 child(右)大, child++
        // child 本身就變成了下一個元素
        if child+1 < hi && data.Less(first+child, first+child+1) {
            child++
        }

        // 再用 child 和根節點比較, 如果根節點大於子節點, 就可以退出此次插入
        if !data.Less(first+root, first+child) {
            return
        }

        // 如果根節點比子節點小, 就將子節點和根節點互換
        data.Swap(first+root, first+child)
        // 上面三步執行完, 就挑出了父節點,左右子節點間的交換, 保證根節點是最大的

        // 資料交換過, child 不一定是它的子節點中最大的了, 將child賦給 root,
        // 和它的子節點再比較, 直到滿足最大堆的結構
        root = child
    }
}

func heapSort(data Interface, a, b int) {
    first := a // 左側邊界
    lo := 0 // 堆的根節點
    hi := b - a // 右側邊界 - 左側邊界,用來元素計數

    // Build heap with greatest element at top.
    // 這裡要先建立一個最大堆(或最小堆, 根據 Less 函式的實現)
    for i := (hi - 1) / 2; i >= 0; i-- {
        // 這裡將陣列中的資料建立成一個堆的結果
        siftDown(data, i, hi, first)
    }

    // Pop elements, largest first, into end of data.
    // 注意這個迴圈是從後向前迴圈的
    for i := hi - 1; i >= 0; i-- {
        // 將第一個元素(堆頂的元素) 和 最後一個元素交換
        // 每一輪迴圈都將最大的元素放在了最後,下一輪最大元素都是前面一個
        // 這樣就可以保證不會影響已經排序好的位置
        data.Swap(first, first+i)

        // 再次維護最大堆的結構
        siftDown(data, lo, i, first)
    }
}

到這裡 Sort 這個入口就已經分析完了, 這個函式開始的排序是不穩定的, 如果想要排序的結果是穩定排序,
就要去分析下 Stable 這個函式

Stable 函式

func Stable(data Interface) {
    stable(data, data.Len())
}

這個函式就比較簡單了, 直接只是呼叫了一下 stable 函式, 繼續檢視 stable 函式:

func stable(data Interface, n int) {
    // 初始 blockSize 設定為 20
    blockSize := 20 // must be > 0
    a, b := 0, blockSize
    // 將切片按照每個 20 分成多個塊, 然後對每個塊進行插入排序
    for b <= n {
        insertionSort(data, a, b)
        a = b
        b += blockSize
    }
    // 這一個是對不到 20 的資料進行排序
    insertionSort(data, a, n)

    for blockSize < n {
        a, b = 0, 2*blockSize
        // 每次將兩個 block 進行排序
        for b <= n {
            // 呼叫歸併排序, 一會再看
            symMerge(data, a, a+blockSize, b)
            // 這裡定義下標的偏移, 下一次迴圈就是下一組 block
            a = b
            b += 2 * blockSize
        }
        // 將剩餘的元素排序
        if m := a + blockSize; m < n {
            symMerge(data, a, m, n)
        }

        // block 每次迴圈擴大兩倍, 直到比元素的總個數大,就結束
        blockSize *= 2
    }
}

歸併排序

func symMerge(data Interface, a, m, b int) {
    // Avoid unnecessary recursions of symMerge
    // by direct insertion of data[a] into data[m:b]
    // if data[a:m] only contains one element.
    // 為了避免不必要的遞迴,當 data[a:m](第一個 block ) 或者 data[m:b](第 2 個 block )
    // 只有一個元素時,直接插入到另一個子陣列中的對應位置。
    if m-a == 1 {
        // Use binary search to find the lowest index i
        // such that data[i] >= data[a] for m <= i < b.
        // Exit the search loop with i == b in case no such index exists.
        i := m
        j := b
        // 這裡是找到一箇中位數, 進行二分查詢
        // 因為前面經過了插入排序, 所以可以保證每一個 block 中的資料都是有序的
        for i < j {
            h := int(uint(i+j) >> 1)
            if data.Less(h, a) {
                i = h + 1
            } else {
                j = h
            }
        }
        // Swap values until data[a] reaches the position before i.
        for k := a; k < i-1; k++ {
            data.Swap(k, k+1)
        }
        return
    }

    // Avoid unnecessary recursions of symMerge
    // by direct insertion of data[m] into data[a:m]
    // if data[m:b] only contains one element.
    // 這裡和上面相同, 只是後一半插入前一半
    if b-m == 1 {
        // Use binary search to find the lowest index i
        // such that data[i] > data[m] for a <= i < m.
        // Exit the search loop with i == m in case no such index exists.
        i := a
        j := m
        for i < j {
            h := int(uint(i+j) >> 1)
            if !data.Less(m, h) {
                i = h + 1
            } else {
                j = h
            }
        }
        // Swap values until data[m] reaches the position i.
        for k := m; k > i; k-- {
            data.Swap(k, k-1)
        }
        return
    }

    // 看起來是要計算出一箇中位數, 找出 a 到 b 之前的中間點
    mid := int(uint(a+b) >> 1)
    // 根據傳入引數的是 symMerge(data, a, a+blockSize, b)
    // 這個 m 不就是中間點嗎 ? 有點疑惑
    // mid + m , 中間點加上中間點, 就是 = b ? 繼續往下看吧
    n := mid + m
    var start, r int
    // 這裡做判斷了, 如果 m > mid , 就是說真正的中心點不是傳入進來的
    // 應該是左邊一半比右邊一半的元素要多
    if m > mid {
        // 這 start 此時應該是 m 到 b 之前的位置
        start = n - b
        r = mid
    } else { // 正好數中間點, 或者 mid 在右邊部分
        start = a // 左邊元素的起點
        r = m // 左邊元素結束的位置
    }
    p := n - 1 // 真正的最後一個元素

    // 能進入這個迴圈應該只能是上面 else 的情況下
    // 然後最終還是要叫 start 不小於 r, 和上面 if 中的情況類似
    for start < r {
        // 再求一下中位數
        c := int(uint(start+r) >> 1)
        if !data.Less(p-c, c) {
            start = c + 1
        } else {
            r = c
        }
    }

    // 燒腦了, 到這裡分析不下去了, 過段時間再看了, 放一放, 看的多了就暈了
    // 如果有誰能看明白告訴下我也可以

    end := n - start
    if start < m && m < end {
        rotate(data, start, m, end)
    }

    if a < start && start < mid {
        symMerge(data, a, start, mid)
    }
    if mid < end && end < b {
        symMerge(data, mid, end, b)
    }
}

最後沒有分析完, 先總結下吧, 就是當從 Stable 函式進入的時候, 就會使用歸併排序進行排序,
中間還會使用插入排序預處理下資料, 因為只呼叫了插入和歸併, 所以它是穩定排序。

應用

內部實現的排序

sort 包本身完成了 int float64 和 string 型別的資料排序, 使用起來也很簡單, 分別呼叫:
sort.Ints()sort.Stringssort.Float64s 即可。

它們的實現也很簡單, 分別維護了一個 IntSliceFloat64SliceStringSlice 的結構,
並且實現了 Interface 介面的 Len 、Less、 Swap 方法, 這裡把 IntSlice 的實現複製出來看下,
其他的可以直接檢視原始碼, 非常簡單

type IntSlice []int

func (p IntSlice) Len() int           { return len(p) }
func (p IntSlice) Less(i, j int) bool { return p[i] < p[j] }
func (p IntSlice) Swap(i, j int)      { p[i], p[j] = p[j], p[i]

自定義排序實現

自己實現 sort 的介面也很簡單, 直接實現了介面要求的三個方法即可。這裡沾了一段我部落格程式碼裡的
片段
可以作為參考, 也是比較簡單的。

package mdfile

// Tags 標籤
type Tags []Tag

// Tag 標籤
type Tag struct {
    // 標籤名稱
    Title string

    // 標籤下文章的數量
    Number int

    // 是否是選中的
    Active bool
}

// Len 實現 Sort 的介面
func (tags Tags) Len() int {
    return len(tags)
}

// Swap 實現的 Sort 介面
func (tags Tags) Swap(i, j int) {
    tags[i], tags[j] = tags[j], tags[i]
}

// Less 實現的 Sort 介面, 按照標籤數量排序
func (tags Tags) Less(i, j int) bool {
    return tags[i].Number > tags[j].Number
}

Reverse 函式

這個函式很巧妙, 比如 go 預設實現的 3 種基礎型別, 它預設是從小到大, 如果想要從大到小排序,
實現也很方便, 用 Reverse 函式包裝一個即可, 如:

sort.Sort(sort.Reverse(sort.IntSlice([]int{1,2,3,4,5,6})))

這裡因為要被 Reverse 包裝, 索引只能使用原始的資料結構 IntSlice ,
其實 Ints 也只是呼叫了一下 Sort(IntSlice(...)) 而已, 下面看看程式碼實現:

type reverse struct {
    // This embedded Interface permits Reverse to use the methods of
    // another Interface implementation.
    Interface
}

// Less returns the opposite of the embedded implementation's Less method.
func (r reverse) Less(i, j int) bool {
    return r.Interface.Less(j, i)
}

// Reverse returns the reverse order for data.
func Reverse(data Interface) Interface {
    return &reverse{data}
}

可以看到這個函式非常簡單, 直接將我們自己的結構包裝起來, 並且也只是實現了一個 Less 方法,
這個方法也非常簡單,只是將我們原本的條件中的 i 和 j 互換了一下, 就完成了反序。

完結

看了幾個小時才把程式碼看完, 還有一些地方是模糊的, 這個程式碼有點燒腦, 等有空再完善了,
不過這個 sort 寫的真的很巧妙, 值得學習一下。

本文來自:
broqiang.com
可以隨意轉載, 帶上我就行。

相關文章