Go的WaitGroup原始碼分析

鹿呦呦發表於2022-01-09

WaitGroup 是開發中經常用到的併發控制手段,其原始碼在 src/sync/waitgroup.go 檔案中,定義了 1 個結構體和 4 個方法:

  • WaitGroup{}:結構體。
  • state():內部方法,在 Add()Wait() 中呼叫。
  • Add():新增任務數。
  • Done():完成任務,其實就是 Add(-1)
  • Wait():阻塞等待所有任務的完成。

以下原始碼基於 Go 1.17.5 版本,有刪減。

$ go version
go version go1.17.5 darwin/amd64

在學習之前可以先了解一些概念:

  • 結構體對齊相關的內容,可參考之前的筆記
  • 訊號量函式有兩個:
    runtime_Semacquire 表示增加一個訊號量,並掛起當前 goroutine。在 Wait() 裡用到。
    runtime_Semrelease 表示減少一個訊號量,並喚醒 sema 上其中一個正在等待的 goroutine。在 Add() 裡用到。
  • unsafe.Pointer 用於各種指標相互轉換;
    uintptrgolang 的內建型別,能儲存指標的整型,其底層型別是 int,可以和 unsafe.Pointer 相互轉換。

一、結構體

1.1 state1 陣列的組成

type WaitGroup struct {
    // 表示 `WaitGroup` 是不可複製的,只能用指標傳遞,保證全域性唯一。
    noCopy noCopy
    // state1 = state(*unit64) + sema(*unit32)
    // state = counter + waiter
    state1 [3]uint32
}

state1 是一個 uint32 陣列,包含了counter 總數、waiter 等待數 和 sema 訊號量,其中:

  • counter:通過 Add() 設定的子 goroutine 的計數值。
  • waiter:通過 Wait() 陷入阻塞的 waiter 數。
  • sema:訊號量。

1.2 state 和 sema 的位置

實際上,counterwaiter 合在一起,當成一個 64 位的整數來使用,所以 state1 陣列又可以看成由 *unit64state*unit32sema 組成,即:

state1 = state + sema,
其中 state = counter + waiter。

32 位系統下4位元組對齊,64位系統下8位元組對齊,下面的內部方法 state() 有進行判斷。

state() 方法將 state1 陣列中儲存的狀態取出來,返回值 statep 就是計數器的狀態,也就是 counterwaiter 的整體,semap 是訊號量。

func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
    // 判斷是否64位對齊
    if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
        return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
    } else {
        return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
    }
}

state() 中,根據執行時分配的地址轉化成 uintptr 後,再 %8,判斷結果是否等於0,若為 0, 則說明分配的地址是 64 位對齊。

  • 如果是 64 位對齊,則陣列前兩位是 state,後一位是 sema
  • 如果不是64位對齊,則前面一位是 sema(32位) 後面兩位是 state
對齊方式 state[0] state[1] state[2]
64位 waiter counter sema
32位 sema waiter counter

當我們初始化一個 waitGroup 物件時,其 counter 值、waiter 值、sema 值均為 0。

1.3 為什麼這麼設計 state1 陣列

為什麼要把 counterwaiter 當成一個整體來設計?這是因為對 state 使用了 atomic.64 的操作,如:

  • Add()
state := atomic.AddUint64(statep, uint64(delta)<<32)
  • Wait()
state := atomic.LoadUint64(statep)
if atomic.CompareAndSwapUint64(statep, state, state+1) {}

要保證 state 的 64 位的原子性,就要保證資料是一次讀入記憶體的,而要保證這種一次性,就要保證 state 是 64 位對齊的。

二、Add()函式

利用64位的原子加,給 counterdeltadelta可能為負),當 counter 變零,通過訊號量喚醒等待的 goroutine。這裡將 Add() 分成幾步來分析:

  • step 1:獲取 counterwaiter、和 sema 對應的指標,並將 delta 加到 counter 上。
// 獲取statep、semap 的指標,也就是counter、waiter和sema
statep, semap := wg.state() 
// 把delta左移32位累加到state,也就是把等待的couter利用原子加,加上delta
state := atomic.AddUint64(statep, uint64(delta)<<32)
v := int32(state >> 32) // 低32位是couter,也就是增加的,注意,這裡轉換成了int32型別
w := uint32(state)      // 高32位是waiter
  • step 2counter 不允許為負數,否則報 panic
if v < 0 {
    panic("sync: negative WaitGroup counter")
}

counter 是活躍的 goroutine 數量,肯定大於 0,如果它為負數,有兩種情況:

第一種是 Add() 的時候,delta 直接就是負數,進行原子加操作後,counter 就小於0,我們一般不這麼寫;
第二種是執行 Done(),也就是 Add(-1) 的時候,前一個 goroutine 減到了 0,還沒執行完,被掛起了,又來了一個 Done(),邏輯就出錯了。

  • step 3:已經執行了 Wait,此時不允許 Add
if w != 0 && delta > 0 && v == int32(delta) {
    panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}

waiter 是等待的 goroutine 數量,只有加 1 和置零兩種操作,所以肯定大於等於 0,第一次 Add(n) 的時候,counter=n,waiter=0w != 0 說明已經執行了 Wait

delta > 0 說明這是一次加的操作。如果 v == int32(delta) 也就是 v + delta == delta,推匯出 v=0,那就可能是第一次 Add() 或者是執行 Add(-1) v 減到了 0,即先 WaitAdd 了。

  • step 4counter > 0waiter = 0,直接返回。
if v > 0 || w == 0 {
    return
}

經過累加後,此時,counter >= 0

如果 counter 為正,說明不需要釋放訊號量,直接退出;
如果 waiter 為 0,說明沒有等待者,也不需要釋放訊號量,直接退出。

  • step 5:檢查 WaitGroup 是否被濫用,即 Add 不能與 Wait 併發呼叫。
if *statep != state {
    panic("sync: WaitGroup misuse: Add called concurrently with Wait")    
}

執行到這裡,counter=0 && waiter>0,說明之前的 Done 已經完成了,計數器清零,該釋放訊號,喚醒所有在 waitgoroutine 了。如果這時候 state 狀態發生變化,則說明當前有人修改過,進行了 Add 操作,報 panic

這一步的判斷相當於鎖,保證 WaitGroup 沒有被濫用。

  • step 6:釋放所有排隊的 waiter
*statep = 0
    for ; w != 0; w-- {
        runtime_Semrelease(semap, false, 0)
}

如果執行到這裡,一定是負 delta 的操作,counter=0,waiter>0 說明已經完成任務,沒有活躍的 goroutine 了,需要釋放訊號量。將狀態全部歸 0,並釋放所有阻塞的 waiter

三、Wait()函式

執行 Wait() 函式的主 goroutine 會將 waiter 值加 1,並阻塞等待該值為 0,才能繼續執行後續程式碼。

func (wg *WaitGroup) Wait() {
    // 獲取statep、semap 的指標,也就是counter、waiter和sema
    statep, semap := wg.state()
    
    for {// 注意這裡在死迴圈中 
        state := atomic.LoadUint64(statep)// 原子操作
        v := int32(state >> 32) // couter
        w := uint32(state)      // waiter
        
        // counter為0,說明所有的goroutine都退出了,不需要等待
        if v == 0 {
            return
        }
        
        // CAS操作增加waiter
        if atomic.CompareAndSwapUint64(statep, state, state+1) {
            // 一旦訊號量sema大於0,就掛起當前goroutine
            runtime_Semacquire(semap)
            
            // Add()函式,觸發訊號量前,會將counter和waiter置為0,所以此時*statep一定為0。如果*statep不為0,說明還未等Waiter執行完Wait(),又執行了Add()或Wait()操作了,WaitGroup發生了複用。
            if *statep != 0 {
                panic("sync: WaitGroup is reused before previous Wait has returned")
            }
            return
        }
    }
}

四、競爭分析

Add()Wait() 中,對 state 資料的操作都存在資料競爭:

Add() delta 加到 counter 最後訊號釋放的時候,需要讀 waitersema
Wait() CAS 操作,給 waiter 加 1,增加 sema 訊號量 counter,為 0 直接返回

解決資料競爭,可以通過加鎖來實現,操作前給 state1 陣列加鎖,結束後釋放鎖,這樣肯定沒有安全性的問題但是低效。

原始碼裡解決資料競爭,沒有使用鎖,它分了幾種情況來解決:

  • AddAdd 併發

多個 Add 同時加,只加數,不管是加正數還是加負數,只要加過之後 counter 大於0,就直接 return。因為是原子加,總有先後順序,保證了不會加丟。

if v > 0 || w == 0 {
    return
}

如果加負數之後 counter 等於0,這個時候要進行訊號的釋放操作,不能允許其他的 Add 同時改這個資料了。

if w != 0 && delta > 0 && v == int32(delta) {
    panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
  • AddWait 併發
    如果 Add 加負數之後 counter 等於0,這個時候要進行訊號的釋放操作,不允許 Wait 去修改這個資料。如果 Wait 先讀出了 state 又改了 state,就會 panic
if *statep != state {
    panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}

五、例項分析

func main() {
    var wg sync.WaitGroup...............①

    wg.Add(2)...........................②
    
    go func() { 
        fmt.Println(1)
        wg.Done().......................③
    }()

	go func() {
        fmt.Println(2)
        wg.Done().......................④
    }()
	
    wg.Wait()...........................⑤
    
	fmt.Println("all work done!")
}

執行完 1,2 後,3、4、5 隨機執行。

  • 假設按【1、2、3、4、5】的順序執行,counterwaiter 的數值變化如下:
    ① counter=0,waiter=0 //初始化的預設值為0
    ② counter=2,waiter=0 //原子加操作,給counter加2
    ③ counter=1,waiter=0 //完成一個Done,給counter減1,counter從2變成1
    ④ counter=0,waiter=0 //又完成一個Done,給counter減1,counter變成0,因為滿足v>0或w=0,直接return了,不用發訊號
    ⑤ counter=0,waiter=0 //因為v=0,所以直接return,不用CAS操作

  • 假設按【1、2、5、3、4】的順序執行,counterwaiter 的數值變化如下:
    ① counter=0,waiter=0 //初始化的預設值為0
    ② counter=2,waiter=0 //原子加操作,給counter加2
    ⑤ counter=2,waiter=1 //CAS給waiter加1,所以waiter由0變成2
    ③ counter=1,waiter=1 完成一個Done,給counter減1,counter從2變成1
    ④ counter=0,waiter=1 又完成一個Done,給counter減1,counter變成0,發訊號,通知waiter不再阻塞,main繼續執行

相關文章