Golang 實現 Redis(6): 實現 pipeline 模式的 redis 客戶端

發表於2020-11-24

本文是使用 golang 實現 redis 系列的第六篇, 將介紹如何實現一個 Pipeline 模式的 Redis 客戶端。

本文的完整程式碼在Github:Godis/redis/client

通常 TCP 客戶端的通訊模式都是阻塞式的: 客戶端傳送請求 -> 等待服務端響應 -> 傳送下一個請求。因為需要等待網路傳輸資料,完成一次請求迴圈需要等待較多時間。

我們能否不等待服務端響應直接傳送下一條請求呢?答案是肯定的。

TCP 作為全雙工協議可以同時進行上行和下行通訊,不必擔心客戶端和服務端同時發包會導致衝突。

p.s. 打電話的時候兩個人同時講話就會衝突聽不清,只能輪流講。這種通訊方式稱為半雙工。廣播只能由電臺傳送到收音機不能反向傳輸,這種方式稱為單工。

我們為每一個 tcp 連線分配了一個 goroutine 可以保證先收到的請求先先回復。另一個方面,tcp 協議會保證資料流的有序性,同一個 tcp 連線上先傳送的請求服務端先接收,先回復的響應客戶端先收到。因此我們不必擔心混淆響應所對應的請求。

這種在服務端未響應時客戶端繼續向服務端傳送請求的模式稱為 Pipeline 模式。因為減少等待網路傳輸的時間,Pipeline 模式可以極大的提高吞吐量,減少所需使用的 tcp 連結數。

pipeline 模式的 redis 客戶端需要有兩個後臺協程程負責 tcp 通訊,呼叫方通過 channel 向後臺協程傳送指令,並阻塞等待直到收到響應,這是一個典型的非同步程式設計模式。

我們先來定義 client 的結構:

type Client struct {
    conn        net.Conn // 與服務端的 tcp 連線
    sendingReqs chan *Request // 等待傳送的請求
    waitingReqs chan *Request // 等待伺服器響應的請求
    ticker      *time.Ticker // 用於觸發心跳包的計時器
    addr        string

    ctx        context.Context
    cancelFunc context.CancelFunc
    writing    *sync.WaitGroup // 有請求正在處理不能立即停止,用於實現 graceful shutdown
}

type Request struct {
    id        uint64 // 請求id
    args      [][]byte // 上行引數
    reply     redis.Reply // 收到的返回值
    heartbeat bool // 標記是否是心跳請求
    waiting   *wait.Wait // 呼叫協程傳送請求後通過 waitgroup 等待請求非同步處理完成
    err       error
}

呼叫者將請求傳送給後臺協程,並通過 wait group 等待非同步處理完成:

func (client *Client) Send(args [][]byte) redis.Reply {
    request := &Request{
        args:      args,
        heartbeat: false,
        waiting:   &wait.Wait{},
    }
    request.waiting.Add(1) 
    client.sendingReqs <- request // 將請求發往處理佇列
    timeout := request.waiting.WaitWithTimeout(maxWait) // 等待請求處理完成或者超時
    if timeout {
        return reply.MakeErrReply("server time out")
    }
    if request.err != nil {
        return reply.MakeErrReply("request failed: " + err.Error())
    }
    return request.reply
}

client 的核心部分是後臺的讀寫協程。先從寫協程開始:

// 寫協程入口
func (client *Client) handleWrite() {
loop:
    for {
        select {
        case req := <-client.sendingReqs: // 從 channel 中取出請求
            client.writing.Add(1) // 未完成請求數+1
            client.doRequest(req) // 傳送請求
        case <-client.ctx.Done():
            break loop
        }
    }
}

// 傳送請求
func (client *Client) doRequest(req *Request) {
    bytes := reply.MakeMultiBulkReply(req.args).ToBytes() // 序列化
    _, err := client.conn.Write(bytes) // 通過 tcp connection 傳送
    i := 0
    for err != nil && i < 3 { // 失敗重試
        err = client.handleConnectionError(err) 
        if err == nil {
            _, err = client.conn.Write(bytes)
        }
        i++
    }
    if err == nil {
        client.waitingReqs <- req // 將傳送成功的請求放入等待響應的佇列
    } else {
        // 傳送失敗
        req.err = err
        req.waiting.Done() // 結束呼叫者的等待
        client.writing.Done() // 未完成請求數 -1
    }
}

讀協程是我們熟悉的協議解析器模板, 不熟悉的朋友可以到實現 Redis 協議解析器瞭解更多。

// 收到服務端的響應
func (client *Client) finishRequest(reply redis.Reply) {
    request := <-client.waitingReqs // 取出等待響應的 request
    request.reply = reply
    if request.waiting != nil {
        request.waiting.Done() // 結束呼叫者的等待
    }
    client.writing.Done() // 未完成請求數-1
}

// 讀協程是個 RESP 協議解析器,不熟悉的朋友可以
func (client *Client) handleRead() error {
    reader := bufio.NewReader(client.conn)
    downloading := false
    expectedArgsCount := 0
    receivedCount := 0
    msgType := byte(0) // first char of msg
    var args [][]byte
    var fixedLen int64 = 0
    var err error
    var msg []byte
    for {
        // read line
        if fixedLen == 0 { // read normal line
            msg, err = reader.ReadBytes('\n')
            if err != nil {
                if err == io.EOF || err == io.ErrUnexpectedEOF {
                    logger.Info("connection close")
                } else {
                    logger.Warn(err)
                }

                return errors.New("connection closed")
            }
            if len(msg) == 0 || msg[len(msg)-2] != '\r' {
                return errors.New("protocol error")
            }
        } else { // read bulk line (binary safe)
            msg = make([]byte, fixedLen+2)
            _, err = io.ReadFull(reader, msg)
            if err != nil {
                if err == io.EOF || err == io.ErrUnexpectedEOF {
                    return errors.New("connection closed")
                } else {
                    return err
                }
            }
            if len(msg) == 0 ||
                msg[len(msg)-2] != '\r' ||
                msg[len(msg)-1] != '\n' {
                return errors.New("protocol error")
            }
            fixedLen = 0
        }

        // parse line
        if !downloading {
            // receive new response
            if msg[0] == '*' { // multi bulk response
                // bulk multi msg
                expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
                if err != nil {
                    return errors.New("protocol error: " + err.Error())
                }
                if expectedLine == 0 {
                    client.finishRequest(&reply.EmptyMultiBulkReply{})
                } else if expectedLine > 0 {
                    msgType = msg[0]
                    downloading = true
                    expectedArgsCount = int(expectedLine)
                    receivedCount = 0
                    args = make([][]byte, expectedLine)
                } else {
                    return errors.New("protocol error")
                }
            } else if msg[0] == '$' { // bulk response
                fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64)
                if err != nil {
                    return err
                }
                if fixedLen == -1 { // null bulk
                    client.finishRequest(&reply.NullBulkReply{})
                    fixedLen = 0
                } else if fixedLen > 0 {
                    msgType = msg[0]
                    downloading = true
                    expectedArgsCount = 1
                    receivedCount = 0
                    args = make([][]byte, 1)
                } else {
                    return errors.New("protocol error")
                }
            } else { // single line response
                str := strings.TrimSuffix(string(msg), "\n")
                str = strings.TrimSuffix(str, "\r")
                var result redis.Reply
                switch msg[0] {
                case '+':
                    result = reply.MakeStatusReply(str[1:])
                case '-':
                    result = reply.MakeErrReply(str[1:])
                case ':':
                    val, err := strconv.ParseInt(str[1:], 10, 64)
                    if err != nil {
                        return errors.New("protocol error")
                    }
                    result = reply.MakeIntReply(val)
                }
                client.finishRequest(result)
            }
        } else {
            // receive following part of a request
            line := msg[0 : len(msg)-2]
            if line[0] == '$' {
                fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
                if err != nil {
                    return err
                }
                if fixedLen <= 0 { // null bulk in multi bulks
                    args[receivedCount] = []byte{}
                    receivedCount++
                    fixedLen = 0
                }
            } else {
                args[receivedCount] = line
                receivedCount++
            }

            // if sending finished
            if receivedCount == expectedArgsCount {
                downloading = false // finish downloading progress

                if msgType == '*' {
                    reply := reply.MakeMultiBulkReply(args)
                    client.finishRequest(reply)
                } else if msgType == '$' {
                    reply := reply.MakeBulkReply(args[0])
                    client.finishRequest(reply)
                }


                // finish reply
                expectedArgsCount = 0
                receivedCount = 0
                args = nil
                msgType = byte(0)
            }
        }
    }
}

最後編寫 client 的構造器和啟動非同步協程的程式碼:

func MakeClient(addr string) (*Client, error) {
    conn, err := net.Dial("tcp", addr)
    if err != nil {
        return nil, err
    }
    ctx, cancel := context.WithCancel(context.Background())
    return &Client{
        addr:        addr,
        conn:        conn,
        sendingReqs: make(chan *Request, chanSize),
        waitingReqs: make(chan *Request, chanSize),
        ctx:         ctx,
        cancelFunc:  cancel,
        writing:     &sync.WaitGroup{},
    }, nil
}

func (client *Client) Start() {
    client.ticker = time.NewTicker(10 * time.Second)
    go client.handleWrite()
    go func() {
        err := client.handleRead()
        logger.Warn(err)
    }()
    go client.heartbeat()
}

關閉 client 的時候記得等待請求完成:

func (client *Client) Close() {
    // 先阻止新請求進入佇列
    close(client.sendingReqs)

    // 等待處理中的請求完成
    client.writing.Wait()

    // 釋放資源
    _ = client.conn.Close() // 關閉與服務端的連線,連線關閉後讀協程會退出
    client.cancelFunc() // 使用 context 關閉讀協程
    close(client.waitingReqs) // 關閉佇列
}

測試一下:

func TestClient(t *testing.T) {
    client, err := MakeClient("localhost:6379")
    if err != nil {
        t.Error(err)
    }
    client.Start()

    result = client.Send([][]byte{
        []byte("SET"),
        []byte("a"),
        []byte("a"),
    })
    if statusRet, ok := result.(*reply.StatusReply); ok {
        if statusRet.Status != "OK" {
            t.Error("`set` failed, result: " + statusRet.Status)
        }
    }

    result = client.Send([][]byte{
        []byte("GET"),
        []byte("a"),
    })
    if bulkRet, ok := result.(*reply.BulkReply); ok {
        if string(bulkRet.Arg) != "a" {
            t.Error("`get` failed, result: " + string(bulkRet.Arg))
        }
    }
}

相關文章