一個golang並行庫原始碼解析

葛平發表於2017-04-07

場景

有這樣一種場景:四個任務A、B、C, D,其中任務B和C需要併發執行,得到結果1, 任務A執行得到結果2, 結果1和2作為任務D的引數傳入,然後執行任務D得到最終結果。我們可以將任務執行順序用如下圖示識:

jobA  jobB   jobC
            /
           /
         middle
          /
         /
     jobD

這是一個典型的多工併發場景,實際上隨著任務數量的增多,任務邏輯會更加複雜,如何編寫可維護健壯的邏輯程式碼變得十分重要,雖然golang提供了同步機制,但是需要寫很多重複無用的Add/Wait/Done程式碼,而且程式碼可讀性也很差,這是不能容忍的。

本文介紹一個開源的golang並行庫,原始碼地址https://github.com/buptmiao/parallel

資料結構

1. parallel結構體

type Parallel struct {
        wg        *sync.WaitGroup
        pipes     []*Pipeline
        wgChild   *sync.WaitGroup
        children  []*Parallel
        exception *Handler
}

parallel定義了一個多工併發例項,主要包括:併發任務管道(pipes)、子任務併發例項(children)、子任務例項等待鎖(wgChild)、當前併發任務例項等待鎖(wg)

2. pipeline結構體

type Pipeline struct {
        handlers []*Handler
}  
type Handler struct {
        f    interface{}
        args []interface{}
        receivers []interface{}
}       

這裡pipeline實際上是一系列併發任務例項handler,每一個handler包括任務函式f, 傳入引數args以及返回結果receivers

parallel相關程式碼

新建parallel例項

func NewParallel() *Parallel {
        res := new(Parallel)
        res.wg = new(sync.WaitGroup)
        res.wgChild = new(sync.WaitGroup)
        res.pipes = make([]*Pipeline, 0, 10)
        return res
}       

註冊handler

func (p *Parallel) Register(f interface{}, args ...interface{}) *Handler {
        return p.NewPipeline().Register(f, args...)
}
func (p *Parallel) NewPipeline() *Pipeline {
        pipe := NewPipeline()
        p.Add(pipe)
        return pipe
} 
func (p *Parallel) Add(pipes ...*Pipeline) *Parallel {
        p.wg.Add(len(pipes))
        p.pipes = append(p.pipes, pipes...)
        return p
}

新建子parallel例項

func (p *Parallel) NewChild() *Parallel {
        child := NewParallel()
        child.exception = p.exception
        p.AddChildren(child)
        return child
}
func (p *Parallel) AddChildren(children ...*Parallel) *Parallel {
        p.wgChild.Add(len(children))
        p.children = append(p.children, children...)
        return p
}

任務執行

func (p *Parallel) Run() {
        for _, child := range p.children {
                // this func will never panic
                go func(ch *Parallel) {
                        ch.Run()
                        p.wgChild.Done()
                }(child)
        }
        p.wgChild.Wait() //wait children instance done
        p.do() //run
        p.wg.Wait() //wait all job done
}
func (p *Parallel) do() {
        for _, pipe := range p.pipes {
                go p.Do()
        }
}

pipeline相關程式碼

新建pipeline例項

func NewPipeline() *Pipeline {
        res := new(Pipeline)
        return res
}       

註冊handler

func (p *Pipeline) Register(f interface{}, args ...interface{}) *Handler {
        h := NewHandler(f, args...)
        p.Add(h)
        return h
}       

新增handler

func (p *Pipeline) Add(hs ...*Handler) *Pipeline {
        p.handlers = append(p.handlers, hs...)
        return p
}

任務執行

func (p *Pipeline) Do() {
        for _, h := range p.handlers {
                h.Do()
        }
}

handler相關程式碼

新建handler例項

func NewHandler(f interface{}, args ...interface{}) *Handler {
        res := new(Handler)
        res.f = f
        res.args = args
        return res
}

執行任務

func (h *Handler) Do() {
        f := reflect.ValueOf(h.f)
        typ := f.Type()
        //check if f is a function
        if typ.Kind() != reflect.Func {
                panic(ErrArgNotFunction)
        }
        //check input length, only check `>` is to allow varargs.
        if typ.NumIn() > len(h.args) {
                panic(ErrInArgLenNotMatch)
        }
        //check output length
        if typ.NumOut() != len(h.receivers) {
                panic(ErrOutArgLenNotMatch)
        }
        //check if output args is ptr
        for _, v := range h.receivers {
                t := reflect.ValueOf(v)
                if t.Type().Kind() != reflect.Ptr {
                        panic(ErrRecvArgTypeNotPtr)
                }
                if t.IsNil() {
                        panic(ErrRecvArgNil)
                }
        }

        inputs := make([]reflect.Value, len(h.args))
        for i := 0; i < len(h.args); i++ {
                if h.args[i] == nil {
                        inputs[i] = reflect.Zero(f.Type().In(i))
                } else {
                        inputs[i] = reflect.ValueOf(h.args[i])
                }
        }
        out := f.Call(inputs)

        for i := 0; i < len(h.receivers); i++ {
                v := reflect.ValueOf(h.receivers[i])
                v.Elem().Set(out[i])
        }
}

demo

package main

import "github.com/buptmiao/parallel"

func testJobA(x, y int) int {
        return x - y
}

func testJobB(x, y int) int {
        return x + y
}

func testJobC(x, y *int, z int) float64 {
        return float64((*x)*(*y)) / float64(z)
}

func main() {
        var x, y int
        var z float64

        p := parallel.NewParallel()

        ch1 := p.NewChild()
        ch1.Register(testJobA, 1, 2).SetReceivers(&x)

        ch2 := p.NewChild()
        ch2.Register(testJobB, 1, 2).SetReceivers(&y)

        p.Register(testJobC, &x, &y, 2).SetReceivers(&z)

        p.Run()

        if x != -1 || y != 3 || z != -1.5 {
                panic("unexpected result")
        }
}


相關文章