64行程式碼實現零拷貝go的TCP拆包粘包

weixin_34236869發表於2018-01-01

64行程式碼實現零拷貝go的TCP拆包粘包

前言

這段時間想用go寫一個簡單IM系統,就思考了一下go語言TCP的拆包粘包。TCP的拆包粘包有一般有三種解決方案。

使用定長位元組

實際使用中,少於固定字長的,要用字元去填充,空間使用率不夠高。

使用分隔符

一般用文字傳輸的,使用分隔符,IM系統一般對效能要求高,不推薦使用文字傳輸。

用訊息的頭位元組標識訊息內容的長度

可以使用二進位制傳輸,效率高,推薦。下面看看怎麼實現。

嘗試使用系統庫自帶的bytes.Buffer實現

程式碼實現:

package tcp

import (
    "fmt"
    "net"
    "log"
    "bytes"
    "encoding/binary"
)

const (
    BYTES_SIZE uint16 = 1024
    HEAD_SIZE  int    = 2
)

func StartServer(address string) {
    listener, err := net.Listen("tcp", address)
    if err != nil {
        log.Println("Error listening", err.Error())
        return
    }
    for {
        conn, err := listener.Accept()
        fmt.Println(conn.RemoteAddr())
        if err != nil {
            fmt.Println("Error accepting", err.Error())
            return // 終止程式
        }
        go doConn(conn)
    }
}

func doConn(conn net.Conn) {
    var (
        buffer           = bytes.NewBuffer(make([]byte, 0, BYTES_SIZE))
        bytes            = make([]byte, BYTES_SIZE);
        isHead      bool = true
        contentSize int
        head        = make([]byte, HEAD_SIZE)
        content     = make([]byte, BYTES_SIZE)
    )
    for {
        readLen, err := conn.Read(bytes);
        if err != nil {
            log.Println("Error reading", err.Error())
            return
        }
        _, err = buffer.Write(bytes[0:readLen])
        if err != nil {
            log.Println("Error writing to buffer", err.Error())
            return
        }

        for {
            if isHead {
                if buffer.Len() >= HEAD_SIZE {
                    _, err := buffer.Read(head)
                    if err != nil {
                        fmt.Println("Error reading", err.Error())
                        return
                    }
                    contentSize = int(binary.BigEndian.Uint16(head))
                    isHead = false
                } else {
                    break
                }
            }
            if !isHead {
                if buffer.Len() >= contentSize {
                    _, err := buffer.Read(content[:contentSize])
                    if err != nil {
                        fmt.Println("Error reading", err.Error())
                        return
                    }
                    fmt.Println(string(content[:contentSize]))
                    isHead = true
                } else {
                    break
                }
            }
        }
    }
}

測試用例:

package tcp

import (
    "testing"
    "net"
    "fmt"
    "encoding/binary"
)

func TestStartServer(t *testing.T) {
    StartServer("localhost:50002")
}

func TestClient(t *testing.T) {
    conn, err := net.Dial("tcp", "localhost:50002")
    if err != nil {
        fmt.Println("Error dialing", err.Error())
        return // 終止程式
    }
    var headSize int
    var headBytes = make([]byte, 2)
    s := "hello world"
    content := []byte(s)
    headSize = len(content)
    binary.BigEndian.PutUint16(headBytes, uint16(headSize))
    conn.Write(headBytes)
    conn.Write(content)

    s = "hello go"
    content = []byte(s)
    headSize = len(content)
    binary.BigEndian.PutUint16(headBytes, uint16(headSize))
    conn.Write(headBytes)
    conn.Write(content)

    s = "hello tcp"
    content = []byte(s)
    headSize = len(content)
    binary.BigEndian.PutUint16(headBytes, uint16(headSize))
    conn.Write(headBytes)
    conn.Write(content)
}

執行結果

127.0.0.1:51062
hello world
hello go
hello tcp

用go系統庫的buffer,是不是感覺程式碼特別彆扭,兩大缺點

1.要寫大量的邏輯程式碼,來彌補buffer對這個場景的不適用。

2.效能不高,有三次次記憶體拷貝,coon->[]byte->Buffer->[]byte。

自己實現

既然輪子不合適,就自己造輪子,首先實現一個自己的Buffer,很簡單,只有六十幾行程式碼,所有過程只有一次byte陣列的拷貝,conn->buffer,剩下的全部操作都在原buffer的位元組陣列裡面操作

package tcp

import (
    "errors"
    "io"
)

type buffer struct {
    reader io.Reader
    buf    []byte
    start  int
    end    int
}

func newBuffer(reader io.Reader, len int) buffer {
    buf := make([]byte, len)
    return buffer{reader, buf, 0, 0}
}

func (b *buffer) Len() int {
    return b.end - b.start
}

//將有用的位元組前移
func (b *buffer) grow() {
    if b.start == 0 {
        return
    }
    copy(b.buf, b.buf[b.start:b.end])
    b.end -= b.start
    b.start = 0;
}

//從reader裡面讀取資料,如果reader阻塞,會發生阻塞
func (b *buffer) readFromReader() (int, error) {
    b.grow()
    n, err := b.reader.Read(b.buf[b.end:])
    if (err != nil) {
        return n, err
    }
    b.end += n
    return n, nil
}

//返回n個位元組,而不產生移位
func (b *buffer) seek(n int) ([]byte, error) {
    if b.end-b.start >= n {
        buf := b.buf[b.start:b.start+n]
        return buf, nil
    }
    return nil, errors.New("not enough")
}

//捨棄offset個欄位,讀取n個欄位
func (b *buffer) read(offset, n int) ([]byte) {
    b.start += offset
    buf := b.buf[b.start:b.start+n]
    b.start += n
    return buf
}

再看看怎樣使用它,將上面的doConn函式改成這樣就行了。

func doConn(conn net.Conn) {
    var (
        buffer      = newBuffer(conn, 16)
        headBuf     []byte
        contentSize int
        contentBuf  []byte
    )
    for {
        _, err := buffer.readFromReader()
        if err != nil {
            fmt.Println(err)
            return
        }
        for {
            headBuf, err = buffer.seek(HEAD_SIZE);
            if err != nil {
                break
            }
            contentSize = int(binary.BigEndian.Uint16(headBuf))
            if (buffer.Len() >= contentSize-HEAD_SIZE) {
                contentBuf = buffer.read(HEAD_SIZE, contentSize)
                fmt.Println(string(contentBuf))
                continue
            }
            break
        }
    }
}

跑下測試用例,看下結果

127.0.0.1:51062
hello world
hello go
hello tcp

原始碼地址:https://github.com/alberliu/goim

你有更好的方式,可以郵箱我,alber_liu@qq.com,讓我學習一下

相關文章