Golang WaitGroup 底層原理及原始碼詳解

程式設計師小杜發表於2023-04-27

0 知識背景

在進入正文前,先對 WaitGroup 及其相關背景知識做個簡單的介紹,這裡主要是 WaitGroup 的基本使用,以及系統訊號量的基礎知識。對這些比較熟悉的小夥伴可以直接跳過這一節。

0.1 WaitGroup

WaitGroup 是 Golang 中最常見的併發控制技術之一,它的作用我們可以簡單類比為其他語言中多執行緒併發控制中的 join(),例項程式碼如下:

package main

import (
    "fmt"
    "sync"
    "time"
)

func main() {
    fmt.Println("Main starts...")
    var wg sync.WaitGroup

    // 2 指的是下面有兩個協程需要等待
    wg.Add(2)

    go waitFunc(&wg, 3)
    go waitFunc(&wg, 1)

    // 阻塞等待
    wg.Wait()

    fmt.Println("Main ends...")
}

func waitFunc(wg *sync.WaitGroup, num int) {
    // 函式結束時告知 WaitGroup 自己已經結束
    defer wg.Done()
    time.Sleep(time.Duration(num) * time.Second)
    fmt.Printf("Hello World from %v\n", num)
}

// 結果輸出:
Main starts...
Hello World from 1
Hello World from 3
Main ends...

如果這裡沒有 WaitGroup,主協程(main 函式)會直接跑到最後的 Main ends...,而沒有中間兩個 goroutine 的輸出,加了 WaitGroup 後,main 就會在 wg.Wait() 處阻塞等待兩個協程都結束後才繼續執行。

上面我們看到的 WaitGroup 的三個方法:Wait()Add(int)Done() 也是 WaitGroup 物件僅有的三個方法。

0.2 訊號量(Semaphore)

訊號量(Semaphore)是一種用於實現多程式或多執行緒之間同步和互斥的機制,也是 WaitGroup 中所採用的技術。並且 WaitGroup 自身的同步原理,也與訊號量很相似。

由於翻譯問題,不熟悉的小夥伴經常將訊號量(Semaphore)和訊號(Signal)搞混,這倆實際上是兩個完全不同的東西。Semaphore 在英文中的本意是旗語,也就是航海領域的那個旗語,利用手旗或旗幟傳遞訊號的溝通方式。在計算機領域,Semaphore,即訊號量,在廣義上也可以理解為一種程式、執行緒間的通訊方式,但它的主要作用,正如前面所說,是用於實現程式、執行緒間的同步和互斥。

訊號量本質上可以簡單理解為一個整型數,主要包含兩種操作:P(Proberen,測試)操作和 V(Verhogen,增加)操作。其中,P 操作會嘗試獲取一個訊號量,如果訊號量的值大於 0,則將訊號量的值減 1 並繼續執行;否則,當前程式或執行緒就會被阻塞,直到有其他程式或執行緒釋放這個訊號量為止。V 操作則是釋放一個訊號量,將訊號量的值加 1。

可以把訊號量看作是一種類似鎖的東西,P 操作相當於獲取鎖,而 V 操作相當於釋放鎖。由於訊號量是一種作業系統級別的機制,通常由核心提供支援,因此我們不用擔心上述對訊號量的操作本身會產生競態條件,相信核心能搞定這種東西。

本文的重點不是訊號量,因此不會過多展開關於訊號量的技術細節,有興趣的小夥伴可以查閱相關資料。

最後提一嘴技術之外的東西,Proberen 和 Verhogen 這倆單詞眼生吧?因為它們是荷蘭語,不是英語。為啥是荷蘭語嘞?因為發明訊號量的人,是上古計算機大神,來自荷蘭的計算機先驅 Edsger W. Dijkstra 先生。嗯,對,就是那個 Dijkstra。

1 WaitGroup 底層原理

宣告:本文所用原始碼均基於 Go 1.20.3 版本,不同版本 Go 的 WaitGroup 原始碼可能略有不同,但設計思想基本是一致的。

WaitGroup 相關原始碼非常短,加上註釋和空行也只有 120 多行,它們全都在 src/sync/waitgroup.go 中。

1.1 定義

先來看 WaitGroup 的定義,這裡我把原始檔中的註釋都簡單翻譯了一下:

// WaitGroup 等待一組 Goroutine 完成。
// 主 Goroutine 呼叫 Add 方法設定要等待的 Goroutine 數量,
// 然後每個 Goroutine 執行並在完成後呼叫 Done 方法。
// 同時,可以使用 Wait 方法阻塞,直到所有 Goroutine 完成。
//
// WaitGroup 在第一次使用後不能被複制。
//
// 根據 Go 記憶體模型的術語,Done 呼叫“同步於”任何它解除阻塞的 Wait 呼叫的返回。
type WaitGroup struct {
    noCopy noCopy

    state atomic.Uint64 // 高 32 位是計數器, 低 32 位是等待者數量(後文解釋)。
    sema  uint32
}

WaitGroup 型別是一個結構體,它有三個私有成員,我們一個一個來看。

1.1.1 noCopy

首先是 noCopy,這個東西是為了告訴編譯器,WaitGroup 結構體物件不可複製,即 wg2 := wg 是非法的。之所以禁止複製,是為了防止可能發生的死鎖。但實際上如果我們對 WaitGroup 物件進行復制後,至少在 1.20 版本下,Go 的編譯器只是發出警告,沒有阻止編譯過程,我們依然可以編譯成功。警告的內容如下:

assignment copies lock value to wg2: sync.WaitGroup contains sync.noCopy

為什麼編譯器沒有編譯失敗,我猜應該是 Go 官方想盡量減少編譯器對程式的干預,而更多地交給程式設計師自己去處理(此時 Rust 發出了一陣笑聲)。總之,我們在使用 WaitGroup 的過程中,不要去複製它就對了,不然非常容易產生死鎖(其實結構體註釋上也說了,WaitGroup 在第一次使用後不能被複制)。譬如我將文章開頭程式碼中的 main 函式稍微改了改:

func main() {
    fmt.Println("Main starts...")
    var wg sync.WaitGroup

    // 2 指的是下面有兩個協程需要等待
    wg.Add(1)
    wg2 := wg
    wg2.Add(1)

    go waitFunc(&wg, 3)
    go waitFunc(&wg2, 1)

    // 阻塞等待
    wg.Wait()
    wg2.Wait()

    fmt.Println("Main ends...")
}

// 輸出結果
Main starts...
Hello World from 1
Hello World from 3
fatal error: all goroutines are asleep - deadlock!

goroutine 1 [semacquire]:
sync.runtime_Semacquire(0xc000042060?)
        C:/Program Files/Go/src/runtime/sema.go:62 +0x27
sync.(*WaitGroup).Wait(0xe76b28?)
        C:/Program Files/Go/src/sync/waitgroup.go:116 +0x4b
main.main()
        D:/Codes/Golang/waitgroup/main.go:23 +0x139
exit status 2

為什麼會這樣?因為 wg 已經 Add(1) 了,這時我們複製了 wg 給 wg2,並且是個淺複製,意味著 wg2 內實際上已經是 Add(1) 後的狀態了(state 成員儲存的狀態,即它的值),此時我們再執行 wg2.Add(1),其實相當於執行了兩次 wg2.Add(1)。而後面 waitFunc() 中對 wg2 只進行了一次 Done() 釋放操作,main 函式在 wg2.Wait() 時就陷入了無限等待,即 all goroutines are asleep。等看了後面 Add()Done() 的原理後,再回頭來看這段死鎖的程式碼,會更加清晰。

那麼這段程式碼能既複製,又不死鎖嗎?當然可以,只需要把 wg2 := wg 提到 wg.Add(1) 前面即可。

1.1.2 state atomic.Uint64

stateWaitGroup 的核心,它是一個無符號的 64 位整型,並且用的是 atomic 包中的 Uint64,所以 state 本身是執行緒安全的。至於 atomic.Uint64 為什麼能保證執行緒安全,因為它使用了 CompareAndSwap(CAS) 操作,而這個操作依賴於 CPU 提供的原子性指令,是 CPU 級的原子操作。

state 的高 32 位是計數器(counter),低 32 位是等待者數量(waiters)。其中計數器其實就是 Add(int) 數量的總和,譬如 Add(1) 後再 Add(2),那麼這個計數器就是 1 + 2 = 3;而等待數量就是現在有多少 goroutine 在執行 Wait() 等待 WaitGroup 被釋放。

1.1.3 sema uint32

這玩意兒就是訊號量,它的用法我們到後文結合程式碼再講。

1.2 Add(delta int)

首先是 Add(delta int) 方法。WaitGroup 所有三個方法都沒有返回值,並且只有 Add 擁有引數,整個設計可謂簡潔到了極點。

Add 方法的第一句程式碼是:

if race.Enabled {
    if delta < 0 {
        // Synchronize decrements with Wait.
        race.ReleaseMerge(unsafe.Pointer(wg))
    }
    race.Disable()
    defer race.Enable()
}

race.Enabled 是判斷當前程式是否開啟了競態條件檢查,這個檢查是在編譯時需要我們手動指定的:go build -race main.go,預設情況下並不開啟,即 race.Enabled 在預設情況下就是 false。這段程式碼裡如果程式開啟了競態條件檢查,會將其關閉,最後再重新開啟。其他有關 race 的細節本文不再討論,這對我們理解 WaitGroup 也沒有太大影響,將其考慮進去反而會增加我們理解 WaitGroup 核心機制的複雜度,因此後續程式碼中也會忽略所有與 race 相關的部分。

Add 方法整理後的程式碼如下:

// Add 方法將 delta 值加上計數器,delta 可以為負數。如果計數器變為 0,
// 則所有在 Wait 上阻塞的 Goroutine 都會被釋放。
// 如果計數器變為負數,則 Add 方法會 panic。
//
// 注意:當計數器為 0 時呼叫 delta 值為正數的 Add 方法必須在 Wait 方法之前執行。
// 而 delta 值為負數或者 delta 值為正數但計數器大於 0 時,則可以在任何時間點執行。
// 通常情況下,這意味著應該在建立 Goroutine 或其他等待事件的語句之前執行 Add 方法。
// 如果一個 WaitGroup 用於等待多組獨立的事件,
// 那麼必須在所有先前的 Wait 呼叫返回之後再進行新的 Add 呼叫。
// 詳見 WaitGroup 示例程式碼。
func (wg *WaitGroup) Add(delta int) {
    // 將 int32 的 delta 變成 unint64 後左移 32 位再與 state 累加。
    // 相當於將 delta 與 state 的高 32 位累加。
    state := wg.state.Add(uint64(delta) << 32)
    // 高 32 位,就是 counter,計數器
    v := int32(state >> 32)
    // 低 32 位,就是 waiters,等待者數量
    w := uint32(state)
    // 計數器為負數時直接 panic
    if v < 0 {
        panic("sync: negative WaitGroup counter")
    }
    // 當 Wait 和 Add 併發執行時,會有機率觸發下面的 panic
    if w != 0 && delta > 0 && v == int32(delta) {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    // 如果計數器大於 0,或者沒有任何等待者,即沒有任何 goroutine 在 Wait(),那麼就直接返回
    if v > 0 || w == 0 {
        return
    }
    // 當 waiters > 0 時,這個 Goroutine 將計數器設定為 0。
    // 現在不可能有對狀態的併發修改:
    // - Add 方法不能與 Wait 方法同時執行,
    // - Wait 不會在看到計數器為 0 時增加等待者。
    // 仍然需要進行簡單的健全性檢查來檢測 WaitGroup 的誤用情況。
    if wg.state.Load() != state {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    // 重置 state 為 0
    wg.state.Store(0)
    // 喚醒所有等待者
    for ; w != 0; w-- {
        // 使用訊號量控制喚醒等待者
        runtime_Semrelease(&wg.sema, false, 0)
    }
}

這裡我將原始碼中的註釋翻譯成了中文,並且自己在每句程式碼前也都加了註釋。

一開始,方法將引數 delta 變成 uint64 後左移 32 位,和 state 相加。因為 state 的高 32 位是這個 WaitGroup 的計數器,所以這裡其實就是把計數器進行了累加操作:

state := wg.state.Add(uint64(delta) << 32)

接著,程式會分別取出已經累加後的計數器 v,和當前的等待者數量 w

v := int32(state >> 32)
w := uint32(state)

然後是幾個判斷:

// 計數器為負數時直接 panic
if v < 0 {
    panic("sync: negative WaitGroup counter")
}
// 當 Wait 和 Add 併發執行時,會有機率觸發下面的 panic
if w != 0 && delta > 0 && v == int32(delta) {
    panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// 如果計數器大於 0,或者沒有任何等待者,
// 即沒有任何 goroutine 在 Wait(),那麼就直接返回
if v > 0 || w == 0 {
    return
}

註釋已經比較清晰了,這裡主要展開解釋一下第二個 ifif w != 0 && delta > 0 && v == int32(delta)

  1. w != 0 意味著當前有 goroutine 在 Wait()
  2. delta > 0 意味著 Add() 傳入的是正整數,也就是正常呼叫;
  3. v == int32(delta) 意味著累加後的計數器等於傳入的 delta,這裡最容易想到的符合這個等式的場景是:原計數器等於 0 時,也就是 wg 第一次使用,或前面的 Wait() 已經全部結束時。

上述三個條件看上去有些衝突:w != 0 表示存在 Wait(),而 v == int32(delta) 按照分析應該不存在 Wait()。再往下分析,其實應該是 v 在獲取的時候不存在 Wait(),而 w 在獲取的時候存在 Wait()。會有這種可能嗎?會!就是併發的時候:當前 goroutine 獲取了 v,然後另一個 goroutine 立刻進行了 Wait(),接著本 goroutine 又獲取了 w,過程如下:

過程時序

我們可以用下面這段程式碼來複現這個 panic

func main() {
    var wg sync.WaitGroup

    // 併發問題不易復現,所以迴圈多次
    for i := 0; i < 100000; i++ {
        go addDoneFunc(&wg)
        go waitFunc(&wg)
    }

    wg.Wait()
}

func addDoneFunc(wg *sync.WaitGroup) {
    wg.Add(1)
    wg.Done()
}

func waitFunc(wg *sync.WaitGroup) {
    wg.Wait()
}

// 輸出結果
panic: sync: WaitGroup misuse: Add called concurrently with Wait

goroutine 71350 [running]:
sync.(*WaitGroup).Add(0x0?, 0xbf8aa5?)
        C:/Program Files/Go/src/sync/waitgroup.go:65 +0xce      
main.addDoneFunc(0xc1cf66?, 0x0?)
        D:/Codes/Golang/waitgroup/main.go:19 +0x1e
created by main.main
        D:/Codes/Golang/waitgroup/main.go:11 +0x8f
exit status 2

這段程式碼可能要多執行幾次才會看到上述效果,因為這種併發操作在整個 WaitGroup 的生命週期中會造成好幾種 panic,包括 Wait() 方法中的。

因此,我們在使用 WaitGroup 的時候應當注意一點:不要在被呼叫的 goroutine 內部使用 Add,而應當在外面使用,也就是:

// 正確
wg.Add(1)
go func(wg *sync.WaitGroup) {
    defer wg.Done()
}(&wg)
wg.Wait()

// 錯誤
go func(wg *sync.WaitGroup) {
    wg.Add(1)
    defer wg.Done()
}(&wg)
wg.Wait()

從而避免併發導致的異常。

上面三個 if 都結束後,會再次對 state 的一致性進行判斷,防止併發異常:

if wg.state.Load() != state {
    panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}

這裡 state.Load() 包括後面會出現的 Store() 都是 atomic.Uint64 的原子操作。

根據前面程式碼的邏輯,當程式執行到這裡時,計數器一定為 0,而等待者則可能 >= 0,於是程式碼會執行一次 wg.state.Store(0)state 設為 0,接著執行通知等待者結束等待的操作:

wg.state.Store(0)
for ; w != 0; w-- {
    runtime_Semrelease(&wg.sema, false, 0)
}

好了,這裡又是讓人迷惑的地方,我第一次看到這段程式碼時產生了下面幾個疑問:

  1. 為什麼 Add 方法會有計數器為 0 的分支邏輯?計數器不是累加的嗎?
  2. 為什麼要在 Add 中通知等待者結束,不應該是 Done 方法嗎?
  3. 那個 runtime_Semrelease(&wg.sema, false, 0) 為什麼需要迴圈 w 次?

一個一個來看。

  • 為什麼 Add 方法會有計數器為 0 的分支邏輯?

首先,按照前面程式碼的邏輯,只有計數器 v 為 0 的時候,程式碼才會走到最後兩句,而之所以為 0,是因為 Add(delta int) 的引數 delta 是一個 int,也就是說,delta 可以為負數!那什麼時候會傳入負數進來呢?Done 的時候。我們去看 Done() 的程式碼,會發現它非常簡單:

// Done 給 WaitGroup 的計數器減 1。
func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

所以,Done 操作或是我們手動給 Add 傳入負數時,就會進入到 Add 最後幾行邏輯,而 Done 本身也意味著當前 goroutine 的 WaitGroup 結束,需要同步給外部的 Wait 讓它不再阻塞。

  • 為什麼要在 Add 中通知等待者結束,不應該是 Done 方法嗎?

嗯,這個問題其實在上一個問題已經一起解決了,因為 Done() 實際上呼叫了 Add(-1)

  • 那個 runtime_Semrelease(&wg.sema, false, 0) 為什麼需要迴圈 w 次?

這個函式按照字面意思,就是釋放訊號量。原始碼在 src/sync/runtime.go 中,函式宣告如下:

// Semrelease 函式用於原子地增加 *s 的值,
// 並在有等待 Semacquire 函式被阻塞的協程時通知它們繼續執行。
// 它旨在作為同步庫使用的簡單喚醒基元,不應直接使用。
// 如果 handoff 引數為 true,則將 count 直接傳遞給第一個等待者。
// skipframes 參數列示在跟蹤時要忽略的幀數,從 runtime_Semrelease 的呼叫者開始計數。
func runtime_Semrelease(s *uint32, handoff bool, skipframes int)

第一個引數就是訊號量的值本身,釋放時會 +1。

第二個引數 handoff 在我查閱了資料後,根據我的理解,應該是:當 handofffalse 時,僅正常喚醒其他等待的協程,但是不會立即排程被喚醒的協程;而當 handofftrue 時,會立刻排程被喚醒的協程。

第三個引數 skipframes,看上去應當也和排程有關,但具體含義我不太確定,這裡就不猜了(水平有限,見諒哈)。

按照訊號量本身的機制,這裡釋放時會 +1,同理還存在一個訊號量獲取函式 runtime_Semacquire(s *uint32) 會在訊號量 > 0 時將訊號量 -1,否則等待,它會在 Wait() 中被呼叫。這也是 runtime_Semrelease 需要迴圈 w 次的原因:因為那 wWait() 中會呼叫 runtime_Semacquire 並不斷將訊號量 -1,也就是減了 w 次,所以兩個地方需要對沖一下嘛。

訊號量和 WaitGroup 的機制很像,但計數器又是反的,所以這裡再多嘴補充幾句:

訊號量獲取時(runtime_Semacquire),其實就是在阻塞等待,P(Proberen,測試)操作,如果此時訊號量 > 0,則獲取成功,並將訊號量 -1,否則繼續等待;

訊號量釋放時(runtime_Semrelease),會把訊號量 +1,也就是 V(Verhogen,增加)操作。

1.2 Done()

Done() 方法我們在上面已經看到過了:

// Done 給 WaitGroup 的計數器減 1。
func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

1.3 Wait()

同樣的,這裡我會把與 race 相關的程式碼都刪掉:

// Wait 會阻塞,直到計數器為 0。
func (wg *WaitGroup) Wait() {
    for {
        state := wg.state.Load()
        v := int32(state >> 32)  // 計數器
        w := uint32(state)       // 等待者數量
        if v == 0 {
            // 計數器為 0,直接返回。
            return
        }
        // 增加等待者數量
        if wg.state.CompareAndSwap(state, state+1) {
            // 獲取訊號量
            runtime_Semacquire(&wg.sema)
            // 這裡依然是為了防止併發問題
            if wg.state.Load() != 0 {
                panic("sync: WaitGroup is reused before previous Wait has returned")
            }
            return
        }
    }
}

Add 簡單多了,而且有了前面 Add 的長篇大論為基礎,Wait 的程式碼看上去一目瞭然。

當計數器為 0,即沒有任何 goroutine 呼叫 Add 時,直接呼叫 Wait,沒有任何意義,因此直接返回,也不操作訊號量。

最後 Wait 也有一個防止併發問題的判斷,而這個 panic 同樣可以用前面 Add 中的那段併發問題程式碼復現,大家可以試試。

Wait 中唯一不同的是,它用了一個無限迴圈 for{},為什麼?這是因為,wg.state.CompareAndSwap(state, state+1) 這個原子操作因為併發等原因有可能失敗,此時就需要重新獲取 state,把整個過程再走一遍。而一旦操作成功,Wait 會在 runtime_Semacquire(&wg.sema) 處阻塞,直到 Done 操作將計數器減為 0,Add 中釋放了訊號量。

2 結語

至此,WaitGroup 的原始碼已全部解析完畢。作為 Golang 中最重要的併發元件之一,WaitGroup 的原始碼居然只有這麼寥寥百行程式碼,倒是給我們理解它的原理降低了不少難度。

開文之前我也沒想到會寫這麼多東西,能看到這裡的小夥伴們,感謝你們的耐心。

本人水平有限,若文中有什麼紕漏或錯誤,還請大家不吝指出,再次感謝!

相關文章