go中waitGroup原始碼解讀

Rick.lz發表於2021-03-07

waitGroup原始碼刨銑

前言

學習下waitGroup的實現

本文是在go version go1.13.15 darwin/amd64上進行的

WaitGroup實現

看一個小demo

func waitGroup() {
	var wg sync.WaitGroup

	wg.Add(4)
	go func() {
		defer wg.Done()
		fmt.Println(1)
	}()

	go func() {
		defer wg.Done()
		fmt.Println(2)
	}()

	go func() {
		defer wg.Done()
		fmt.Println(3)
	}()

	go func() {
		defer wg.Done()
		fmt.Println(4)
	}()

	wg.Wait()
	fmt.Println("1 2 3 4 end")
}

1、啟動goroutine前將計數器通過Add(4)將計數器設定為待啟動的goroutine個數。

2、啟動goroutine後,使用Wait()方法阻塞主協程,等待計數器變為0。

3、每個goroutine執行結束通過Done()方法將計數器減1。

4、計數器變為0後,阻塞的goroutine被喚醒。

看下具體的實現

// WaitGroup 不能被copy
type WaitGroup struct {
	noCopy noCopy

	state1 [3]uint32
}

noCopy

意思就是不讓copy,是如何實現的呢?

Go中沒有原生的禁止拷貝的方式,所以如果有的結構體,你希望使用者無法拷貝,只能指標傳遞保證全域性唯一的話,可以這麼幹,定義 一個結構體叫 noCopy,實現如下的介面,然後嵌入到你想要禁止拷貝的結構體中,這樣go vet就能檢測出來。

// noCopy may be embedded into structs which must not be copied
// after the first use.
//
// See https://golang.org/issues/8005#issuecomment-190753527
// for details.
type noCopy struct{}

// Lock is a no-op used by -copylocks checker from `go vet`.
func (*noCopy) Lock()   {}
func (*noCopy) Unlock() {}

測試下

type noCopy struct{}

func (*noCopy) Lock()   {}
func (*noCopy) Unlock() {}

type Person struct {
	noCopy noCopy
	name   string
}

// go中的函式傳參都是值拷貝
func test(person Person) {
	fmt.Println(person)
}

func main() {
	var person Person
	test(person)
}

go vet main.go

$ go vet main.go
# command-line-arguments
./main.go:18:18: test passes lock by value: command-line-arguments.Person contains command-line-arguments.noCopy
./main.go:19:14: call of fmt.Println copies lock value: command-line-arguments.Person contains command-line-arguments.noCopy
./main.go:24:7: call of test copies lock value: command-line-arguments.Person contains command-line-arguments.noCopy

使用vet檢測到了不能copy的錯誤

state1

	// 64 位值: 高 32 位用於計數,低 32 位用於等待計數
	// 64 位的原子操作要求 64 位對齊,但 32 位編譯器無法保證這個要求
	// 因此分配 12 位元組然後將他們對齊,其中 8 位元組作為狀態,其他 4 位元組用於儲存原語
	state1 [3]uint32

這點是wait_group很巧妙的一點,大神寫程式碼的思路就是驚奇

這個設計很奇妙,通過記憶體對齊來處理wait_group中的waiter數、計數值、訊號量。什麼是記憶體對齊可參考什麼是記憶體對齊,go中記憶體對齊分析

來分析下state1是如何記憶體對齊來處理幾個計數值的儲存

計算機為了加快記憶體的訪問速度,會對記憶體進行對齊處理,CPU把記憶體當作一塊一塊的,塊的大小可以是2、4、8、16位元組大小,因此CPU讀取記憶體是一塊一塊讀取的。

合理的記憶體對齊可以提高記憶體讀寫的效能,並且便於實現變數操作的原子性。

在不同平臺上的編譯器都有自己預設的 “對齊係數”,可通過預編譯命令#pragma pack(n)進行變更,n就是代指 “對齊係數”。

一般來講,我們常用的平臺的係數如下:

  • 32 位:4

  • 64 位:8

state1這塊就相容了兩種平臺的對齊係數

對於64未系統來講。記憶體訪問的步長是8。也就是cpu一次訪問8位偏移量的記憶體空間。當時對於32未的系統,記憶體的對齊係數是4,也就是訪問的步長是4個偏移量。

所以為了相容這兩種模式,這裡採用了uint32結構的陣列,保證在不同型別的機器上都是12個位元組,一個uint32是4位元組。這樣對於32位的4步長訪問是沒有問題了,64位的好像也沒有解決,8步長的訪問會一次讀入兩個uint32的長度。

所以,下面的讀取也進行了操作,將兩個uint32的記憶體放到一個uint64中返回,這樣就同時解決了32位和64位訪問步長的問題。

所以,64位系統和32位系統,state1counter,waiter,semaphore的記憶體佈局是不一樣的。

state [0] state [1] state [2]
64位 waiter counter semaphore
32位 semaphore waiter counter

counter位於高地址位,waiter位於地址位

waitgroup

下面是state的程式碼

func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	// 判斷是否是64位
	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
	} else {
		return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
	}
}

對於count和wait在高低地址位的體現,在add中的程式碼可體現

        // // 將 delta 加到 statep 的前 32 位上,即加到計數器上
	state := atomic.AddUint64(statep, uint64(delta)<<32)
	v := int32(state >> 32) // count
	w := uint32(state)  // wait

通過補碼的移位來看下分析下

statep               0000 0000 0000 0000 0000 0000 0000 0001 0000 0000 0000 0000 0000 0000 0000 0000  

int64(1)             0000 0000 0000 0000 0000 0000 0000 0000 0000 0000 0000 0000 0000 0000 0000 0001

int64(1)<<32:       0000 0000 0000 0000 0000 0000 0000 0001 0000 0000 0000 0000 0000 0000 0000 0000  

statep+int64(1)<<32  

state:              0000 0000 0000 0000 0000 0000 0000 0001 0000 0000 0000 0000 0000 0000 0000 0000  

int32(state >> 32)   0000 0000 0000 0000 0000 0000 0000 0001

uint32(state):      0000 0000 0000 0000 0000 0000 0000 0000

再來看下-1操作

statep               0000 0000 0000 0000 0000 0000 0000 0001 0000 0000 0000 0000 0000 0000 0000 0000  

int64(-1)            1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111

int64(-11)<<32:     1111 1111 1111 1111 1111 1111 1111 1111 0000 0000 0000 0000 0000 0000 0000 0000  

statep+int64(-11)<<32  

state:              0000 0000 0000 0000 0000 0000 0000 0000 0000 0000 0000 0000 0000 0000 0000 0000  

int32(state >> 32)   0000 0000 0000 0000 0000 0000 0000 0000

uint32(state):      0000 0000 0000 0000 0000 0000 0000 0000

關於補碼的操作可參考原碼, 反碼, 補碼 詳解

訊號量(semaphore)

訊號量是Unix系統提供的一種保護共享資源的機制,用於防止多個執行緒同時訪問某個資源。

可簡單理解為訊號量為一個數值:

  • 當訊號量>0時,表示資源可用,獲取訊號量時系統自動將訊號量減1;

  • 當訊號量==0時,表示資源暫不可用,獲取訊號量時,當前執行緒會進入睡眠,當訊號量為正時被喚醒。

WaitGroup中的實現就用到了這個,在下面的程式碼實現就能看到

Add(Done)

// Add將增量(可能為負)新增到WaitGroup計數器中。
// 如果計數器為零,則釋放等待時阻塞的所有goroutine。
// 如果計數器變為負數,請新增恐慌。
//
// 請注意,當計數器為 0 時發生的帶有正的 delta 的呼叫必須在 Wait 之前。
// 當計數器大於 0 時,帶有負 delta 的呼叫或帶有正 delta 呼叫可能在任何時候發生。
// 通常,這意味著對Add的呼叫應在語句之前執行建立要等待的goroutine或其他事件。
// 如果將WaitGroup重用於等待幾個獨立的事件集,新的Add呼叫必須在所有先前的Wait呼叫返回之後發生。
func (wg *WaitGroup) Add(delta int) {
	// 獲取counter,waiter,以及semaphore對應的指標
	statep, semap := wg.state()
	...
	// 將 delta 加到 statep 的前 32 位上,即加到計數器上
	state := atomic.AddUint64(statep, uint64(delta)<<32)
	// 高地址位counter
	v := int32(state >> 32)
	// 低地址為waiter
	w := uint32(state)
	...
	// 計數器不允許為負數
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}
	// wait不等於0說明已經執行了Wait,此時不容許Add
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// 計數器的值大於或者沒有waiter在等待,直接返回
	if v > 0 || w == 0 {
		return
	}
	// 執行到這裡只有一種情況 v == 0 && w != 0

	// 這時 Goroutine 已經將計數器清零,且等待器大於零(併發呼叫導致)
	// 這時不允許出現併發使用導致的狀態突變,否則就應該 panic
	// - Add 不能與 Wait 併發呼叫
	// - Wait 在計數器已經歸零的情況下,不能再繼續增加等待器了
	// 仍然檢查來保證 WaitGroup 不會被濫用

	// 這一點很重要,這段程式碼同時也保證了這是最後的一個需要等待阻塞的goroutine
	// 然後在下面通過runtime_Semrelease,喚醒被訊號量semap阻塞的waiter
	if *statep != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// 結束後將等待器清零
	*statep = 0
	for ; w != 0; w-- {
		// 釋放訊號量,通過runtime_Semacquire喚醒被阻塞的waiter
		runtime_Semrelease(semap, false, 0)
	}
}

梳理下流程

1、首先獲取儲存在state1中對應的幾個變數的指標;

2、counter儲存在高位,增加的時候需要左移32位;

3、counter的數量不能小於0,小於0丟擲panic;

4、同樣也會判斷,已經的執行wait之後,不能在增加counter;

5、(這點很重要,我自己看了好久才明白)計數器的值大於或者沒有waiter在等待,直接返回

	// 計數器的值大於或者沒有waiter在等待,直接返回
	if v > 0 || w == 0 {
		return
	}

因為waiter的值只會被執行一次+1操作,所以這段程式碼保證了只有在v == 0 && w != 0,也就是最後一個Done()操作的時候,走到下面的程式碼,釋放訊號量,喚醒被訊號量阻塞的Wait(),結束整個WaitGroup

Done()也是呼叫了Add

func (wg *WaitGroup) Done() {
	wg.Add(-1)
}

Wait

// Wait blocks until the WaitGroup counter is zero.
func (wg *WaitGroup) Wait() {
	// 獲取counter,waiter,以及semaphore對應的指標
	statep, semap := wg.state()
	...
	for {
		// 獲取對應的counter和waiter數量
		state := atomic.LoadUint64(statep)
		v := int32(state >> 32)
		w := uint32(state)
		// Counter為0,不需要等待
		if v == 0 {
			if race.Enabled {
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
		// 原子(cas)增加waiter的數量(只會被+1操作一次)
		if atomic.CompareAndSwapUint64(statep, state, state+1) {
			...
			// 這塊用到了,我們上文講的那個訊號量
			// 等待被runtime_Semrelease釋放的訊號量喚醒
			// 如果 *semap > 0 則會減 1,等於0則被阻塞
			runtime_Semacquire(semap)

			// 在這種情況下,如果 *statep 不等於 0 ,則說明使用失誤,直接 panic
			if *statep != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			...
			return
		}
	}
}

梳理下流程

1、首先獲取儲存在state1中對應的幾個變數的指標;

2、一個for迴圈,來阻塞等待所有的goroutine退出;

3、如果counter為0,不需要等待,直接退出即可;

4、原子(cas)增加waiter的數量(只會被+1操作一次);

5、整個Wait()會被runtime_Semacquire阻塞,直到等到退出的訊號量;

6、Done()會在最後一次的時候通過runtime_Semrelease發出取消阻塞的訊號,然後被runtime_Semacquire阻塞的Wait()就可以退出了;

7、整個WaitGroup執行成功。

waitgroup

總結

程式碼中我感到設計比較巧妙的有兩個部分:

1、state1的處理,保證記憶體對齊,設定高低位記憶體來儲存不同的值,同時32位和64位平臺的處理方式還不同;

2、訊號量的阻塞退出,這塊最後一個Done退出的時候,才會觸發阻塞訊號量,退出Wait(),然後結束整個waitGroup。再此之前,當Wait()在成功將waiter變數+1操作之後,就會被runtime_Semacquire阻塞,直到最後一個Done,訊號的發出。

對於WaitGroup的使用

1、計數器的值不能為負數,可能是Add(-1)觸發的,也可能是Done()觸發的,否則會panic;

2、Add數量的新增,要發生在Wait()之前;

3、WaitGroup是可以重用的,但是需要等上一批的goroutine 都呼叫Wait完畢後才能繼續重用WaitGroup

參考

【《Go專家程式設計》Go WaitGroup實現原理】https://my.oschina.net/renhc/blog/2249061
【Go中由WaitGroup引發對記憶體對齊思考】https://cloud.tencent.com/developer/article/1776930
【Golang 之 WaitGroup 原始碼解析】https://www.linkinstar.wiki/2020/03/15/golang/source-code/sync-waitgroup-source-code/
【sync.WaitGroup】https://golang.design/under-the-hood/zh-cn/part1basic/ch05sync/waitgroup/

相關文章