玩轉 Go 生態|Hertz WebSocket 擴充套件簡析

白澤來了發表於2022-12-14

WebSocket 是一種可以在單個 TCP 連線上進行全雙工通訊,位於 OSI 模型的應用層。WebSocket 使得客戶端和伺服器之間的資料交換變得更加簡單,允許服務端主動向客戶端推送資料。在 WebSocket API 中,瀏覽器和伺服器只需要完成一次握手,兩者之間就可以建立永續性的連線,並進行雙向資料傳輸。

Hertz 提供了 WebSocket 的支援,參考 gorilla/websocket 庫使用 hijack 的方式在 Hertz 進行了適配,用法和引數基本保持一致。

安裝

go get github.com/hertz-contrib/websocket

示例程式碼

package main
​
import (
    "context"
    "flag"
    "html/template"
    "log"
​
    "github.com/cloudwego/hertz/pkg/app"
    "github.com/cloudwego/hertz/pkg/app/server"
    "github.com/hertz-contrib/websocket"
)
​
var addr = flag.String("addr", "localhost:8080", "http service address")
​
var upgrader = websocket.HertzUpgrader{} // use default options
​
func echo(_ context.Context, c *app.RequestContext) {
    err := upgrader.Upgrade(c, func(conn *websocket.Conn) {
        for {
            mt, message, err := conn.ReadMessage()
            if err != nil {
                log.Println("read:", err)
                break
            }
            log.Printf("recv: %s", message)
            err = conn.WriteMessage(mt, message)
            if err != nil {
                log.Println("write:", err)
                break
            }
        }
    })
    if err != nil {
        log.Print("upgrade:", err)
        return
    }
}
​
func home(_ context.Context, c *app.RequestContext) {
    c.SetContentType("text/html; charset=utf-8")
    homeTemplate.Execute(c, "ws://"+string(c.Host())+"/echo")
}
​
func main() {
    flag.Parse()
    h := server.Default(server.WithHostPorts(*addr))
    // https://github.com/cloudwego/hertz/issues/121
    h.NoHijackConnPool = true
    h.GET("/", home)
    h.GET("/echo", echo)
    h.Spin()
}
​
var homeTemplate = template.Must(template.New("").Parse(`
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<script>  
window.addEventListener("load", function(evt) {
​
    var output = document.getElementById("output");
    var input = document.getElementById("input");
    var ws;
​
    var print = function(message) {
        var d = document.createElement("div");
        d.textContent = message;
        output.appendChild(d);
        output.scroll(0, output.scrollHeight);
    };
​
    document.getElementById("open").onclick = function(evt) {
        if (ws) {
            return false;
        }
        ws = new WebSocket("{{.}}");
        ws.onopen = function(evt) {
            print("OPEN");
        }
        ws.onclose = function(evt) {
            print("CLOSE");
            ws = null;
        }
        ws.onmessage = function(evt) {
            print("RESPONSE: " + evt.data);
        }
        ws.onerror = function(evt) {
            print("ERROR: " + evt.data);
        }
        return false;
    };
​
    document.getElementById("send").onclick = function(evt) {
        if (!ws) {
            return false;
        }
        print("SEND: " + input.value);
        ws.send(input.value);
        return false;
    };
​
    document.getElementById("close").onclick = function(evt) {
        if (!ws) {
            return false;
        }
        ws.close();
        return false;
    };
​
});
</script>
</head>
<body>
<table>
<tr><td valign="top" width="50%">
<p>Click "Open" to create a connection to the server, 
"Send" to send a message to the server and "Close" to close the connection. 
You can change the message and send multiple times.
<p>
<form>
<button id="open">Open</button>
<button id="close">Close</button>
<p><input id="input" type="text" value="Hello world!">
<button id="send">Send</button>
</form>
</td><td valign="top" width="50%">
<div id="output" style="max-height: 70vh;overflow-y: scroll;"></div>
</td></tr></table>
</body>
</html>
`))

執行 server:

go run server.go

上述示例程式碼中,伺服器包括一個簡單的網路客戶端。要使用該客戶端,在瀏覽器中開啟 http://127.0.0.1:8080,並按照頁面上的指示操作。

Upgrade

websocket.Conn 型別代表一個 WebSocket 連線。伺服器應用程式從 HTTP 請求處理程式中呼叫 HertzUpgrader.Upgrade 方法,將 HTTP 協議的連線請求升級為 WebSocket 協議的連線請求。

這部分邏輯對應著示例程式碼的 echo() 函式,此處著重介紹 HertzUpgrader.Upgrade

函式簽名:

func (u *HertzUpgrader) Upgrade(ctx *app.RequestContext, handler HertzHandler) error

內部處理邏輯:

func (u *HertzUpgrader) Upgrade(ctx *app.RequestContext, handler HertzHandler) error {
    if !ctx.IsGet() {
        return u.returnError(ctx, consts.StatusMethodNotAllowed, fmt.Sprintf("%s request method is not GET", badHandshake))
    }
    // 校驗 requsetHeader 中與 websocket 相關的欄位(此處省略部分邏輯程式碼)
​
    subprotocol := u.selectSubprotocol(ctx)
    compress := u.isCompressionEnable(ctx)
​
    ctx.SetStatusCode(consts.StatusSwitchingProtocols)
    // 構造協議升級後的響應頭部資訊
    ctx.Response.Header.Set("Upgrade", "websocket")
    ctx.Response.Header.Set("Connection", "Upgrade")
    ctx.Response.Header.Set("Sec-WebSocket-Accept", computeAcceptKeyBytes(challengeKey))
    // “無上下文接管”模式
    if compress {
        ctx.Response.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
    }
    if subprotocol != nil {
        ctx.Response.Header.SetBytesV("Sec-WebSocket-Protocol", subprotocol)
    }
​
    // 透過 Hijack 的方式,實現 websocket 全雙工的通訊
    ctx.Hijack(func(netConn network.Conn) {
        writeBuf := poolWriteBuffer.Get().([]byte)
        c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, nil, writeBuf)
        if subprotocol != nil {
            c.subprotocol = b2s(subprotocol)
        }
​
        if compress {
            c.newCompressionWriter = compressNoContextTakeover
            c.newDecompressionReader = decompressNoContextTakeover
        }
​
        netConn.SetDeadline(time.Time{})
​
        handler(c)
​
        writeBuf = writeBuf[0:0]
        poolWriteBuffer.Put(writeBuf)
    })
​
    return nil
}

HertzHandler

HertzHandler 是上述 HertzUpgrader.Upgrade 函式的第二個引數。HertzHandler 在握手完成後接收一個 websocket 連線,透過劫持這個連線,完成全雙工的通訊。

HertzHandler 必須由使用者提供,內部定義了 WebSocket 請求和響應的具體流程。

函式簽名:

type HertzHandler func(*Conn)

上述 echo 伺服器的 websocket 處理流程:

err := upgrader.Upgrade(c, func(conn *websocket.Conn) {
    for {
        // 讀取客戶端傳送的資訊
        mt, message, err := conn.ReadMessage()
        if err != nil {
            log.Println("read:", err)
            break
        }
        log.Printf("recv: %s", message)
        // 向客戶端傳送資訊
        err = conn.WriteMessage(mt, message)
        if err != nil {
            log.Println("write:", err)
            break
        }
    }
})

配置

上述文件已經講述了Hertz WebSocket 最核心的協議升級連線劫持的邏輯,下面將羅列 Hertz WebSocket 使用過程中可選的配置引數。

這部分將圍繞 websocket.HertzUpgrader 結構展開說明。

引數 介紹
ReadBufferSize 用於設定輸入緩衝區的大小,單位為位元組。如果緩衝區大小為零,那麼就使用 HTTP 伺服器分配的大小。輸入緩衝區大小並不限制可以接收的資訊的大小。
WriteBufferSize 用於設定輸出緩衝區的大小,單位為位元組。如果緩衝區大小為零,那麼就使用 HTTP 伺服器分配的大小。輸出緩衝區大小並不限制可以傳送的資訊的大小。
WriteBufferPool 用於設定寫操作的緩衝池。
Subprotocols 用於按優先順序設定伺服器支援的協議。如果這個欄位不是 nil,那麼 Upgrade 方法透過選擇這個列表中與客戶端請求的協議的第一個匹配來協商一個子協議。如果沒有匹配,那麼就不協商協議(Sec-Websocket-Protocol 頭不包括在握手響應中)。
Error 用於設定生成 HTTP 錯誤響應的函式。
CheckOrigin 用於設定針對請求的 Origin 頭的校驗函式, 如果請求的 Origin 頭是可接受的,CheckOrigin 返回 true。
EnableCompression 用於設定伺服器是否應該嘗試協商每個訊息的壓縮(RFC 7692)。將此值設定為 true 並不能保證壓縮會被支援。

WriteBufferPool

如果該值沒有被設定,則額外初始化寫緩衝區,並在當前生命週期內分配給該連線。當應用程式在大量的連線上有適度的寫入量時,緩衝池是最有用的。

應用程式應該使用一個單一的緩衝池來為不同的連線分配緩衝區。

介面簽名:

// BufferPool represents a pool of buffers. The *sync.Pool type satisfies this
// interface.  The type of the value stored in a pool is not specified.
type BufferPool interface {
    // Get gets a value from the pool or returns nil if the pool is empty.
    Get() interface{}
    // Put adds a value to the pool.
    Put(interface{})
}

示例程式碼:

type simpleBufferPool struct {
    v interface{}
}
​
func (p *simpleBufferPool) Get() interface{} {
    v := p.v
    p.v = nil
    return v
}
​
func (p *simpleBufferPool) Put(v interface{}) {
    p.v = v
}
​
var upgrader = websocket.HertzUpgrader{
    WriteBufferPool: &simpleBufferPool{},
}

Subprotocols

WebSocket 只是定義了一種交換任意訊息的機制。這些訊息是什麼意思,客戶端在任何特定的時間點可以期待什麼樣的訊息,或者他們被允許傳送什麼樣的訊息,完全取決於實現應用程式。

所以你需要在伺服器和客戶端之間就這些事情達成協議。子協議引數只是讓客戶端和服務端正式地交換這些資訊。你可以為你想要的任何協議編造任何名字。伺服器可以簡單地檢查客戶在握手過程中是否遵守了該協議。

Error

如果 Error 為 nil,則使用 Hertz 提供的 API 來生成 HTTP 錯誤響應。

函式簽名:

func(ctx *app.RequestContext, status int, reason error)

示例程式碼:

var upgrader = websocket.HertzUpgrader{
    Error: func(ctx *app.RequestContext, status int, reason error) {
        ctx.Response.Header.Set("Sec-Websocket-Version", "13")
        ctx.AbortWithMsg(reason.Error(), status)
    },
}

CheckOrigin

如果 CheckOrigin 為nil,則使用一個安全的預設值:如果Origin請求頭存在,並且源主機不等於請求主機頭,則返回false。CheckOrigin 函式應該仔細驗證請求的來源,以防止跨站請求偽造。

函式簽名:

func(ctx *app.RequestContext) bool

預設實現:

func fastHTTPCheckSameOrigin(ctx *app.RequestContext) bool {
    origin := ctx.Request.Header.Peek("Origin")
    if len(origin) == 0 {
        return true
    }
    u, err := url.Parse(b2s(origin))
    if err != nil {
        return false
    }
    return equalASCIIFold(u.Host, b2s(ctx.Host()))
}

EnableCompression

服務端接受一個或者多個擴充套件欄位,這些擴充套件欄位是包含客戶端請求的 Sec-WebSocket-Extensions 頭欄位擴充套件中的。當 EnableCompression 為 true 時,服務端根據當前自身支援的擴充套件與其進行匹配,如果匹配成功則支援壓縮。

校驗邏輯:

var strPermessageDeflate = []byte("permessage-deflate")
​
func (u *HertzUpgrader) isCompressionEnable(ctx *app.RequestContext) bool {
    extensions := parseDataHeader(ctx.Request.Header.Peek("Sec-WebSocket-Extensions"))
​
    // Negotiate PMCE
    if u.EnableCompression {
        for _, ext := range extensions {
            if bytes.HasPrefix(ext, strPermessageDeflate) {
                return true
            }
        }
    }
​
    return false
}

目前僅支援“無上下文接管”模式,詳見上述 HertzUpgrader.Upgrade 程式碼部分。

Set Deadline

當使用 websocket 進行讀寫的時候,可以透過類似如下方式設定超時時間(在每次讀寫過程中都會生效)。

示例程式碼:

func echo(_ context.Context, c *app.RequestContext) {
    err := upgrader.Upgrade(c, func(conn *websocket.Conn) {
        defer conn.Close()
        // "github.com/cloudwego/hertz/pkg/network"
        conn.NetConn().(network.Conn).SetReadTimeout(1 * time.Second)
        ...
    })
    if err != nil {
        log.Print("upgrade:", err)
        return
    }
}

更多用法示例詳見 examples

相關文章