A Brief Look at How WebSocket Works

zeeyang.com, published 2017-07-05

Background

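We previously used CocoaAsyncSocket as the underlying implementation and built our own Socket communication mechanism and business interfaces on top of it. Recently we started looking into WebSocket as a replacement for CocoaAsyncSocket. First, a word on how the two relate: despite the similar names, WebSocket and Socket are entirely different things. WebSocket is an application-layer protocol built on top of TCP/IP, whereas a socket is an abstraction layer between the application layer and the transport layer that wraps the complex TCP/IP operations in a few simple interfaces for the application layer to call. Why make the switch? Our server side is being reworked and the web version of our IM already uses WebSocket, so if the client uses it too, the server only has to maintain one code path, which is simpler and easier. WebSocket also has some advantages in overhead, real-time delivery and extensibility.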

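The latest version of the WebSocket protocol is version 13, specified in RFC 6455. To understand how WebSocket is implemented, you really need to understand its protocol!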

A WebSocket implementation has two parts: the handshake, and sending/reading data.

Handshake

The handshake is easiest to understand through its request headers.

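WebSocket first sends an HTTP request carrying an Upgrade header. This header is used to change the HTTP protocol version or to switch to a different protocol; here we set Upgrade to websocket to upgrade the connection to the WebSocket protocol.

Also note the Sec-WebSocket-Key header. The client generates it and sends it to the server. It lets the server confirm that what it received is a genuine WebSocket opening handshake, which helps the server reject connections initiated by non-WebSocket clients. Its value is a randomly generated 16-byte nonce encoded in base64.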

GET /chat HTTP/1.1
Host: server.example.com
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Origin: http://example.com
Sec-WebSocket-Protocol: chat, superchat
Sec-WebSocket-Version: 13
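As a minimal sketch of how a client could produce Sec-WebSocket-Key, the code below base64-encodes 16 random bytes. The helper names `base64_encode` and `ws_make_key` are hypothetical, and `rand()` is used only to keep the example short: it is not a cryptographically strong source, so a real client should use something like `arc4random_buf()` on Apple platforms or `/dev/urandom` instead.

```c
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>

static const char B64[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/* Standard base64 with '=' padding; out needs 4 * ((len + 2) / 3) + 1 bytes. */
static void base64_encode(const unsigned char *in, size_t len, char *out) {
    size_t i = 0, o = 0;
    for (; i + 2 < len; i += 3) {            /* full 3-byte groups -> 4 chars */
        uint32_t v = (uint32_t)in[i] << 16 | (uint32_t)in[i + 1] << 8 | in[i + 2];
        out[o++] = B64[v >> 18 & 63];
        out[o++] = B64[v >> 12 & 63];
        out[o++] = B64[v >> 6 & 63];
        out[o++] = B64[v & 63];
    }
    if (i < len) {                           /* one or two leftover bytes */
        uint32_t v = (uint32_t)in[i] << 16 | (i + 1 < len ? (uint32_t)in[i + 1] << 8 : 0);
        out[o++] = B64[v >> 18 & 63];
        out[o++] = B64[v >> 12 & 63];
        out[o++] = i + 1 < len ? B64[v >> 6 & 63] : '=';
        out[o++] = '=';
    }
    out[o] = '\0';
}

/* Hypothetical helper: writes a fresh 24-character Sec-WebSocket-Key into key. */
static void ws_make_key(char key[25]) {
    unsigned char nonce[16];
    for (int i = 0; i < 16; i++)
        nonce[i] = (unsigned char)(rand() & 0xff);  /* illustration only, NOT a CSPRNG */
    base64_encode(nonce, sizeof nonce, key);
}
```

Because 16 bytes always encode to 24 base64 characters ending in "==", every key has that shape. The sample key in the request above, dGhlIHNhbXBsZSBub25jZQ==, is simply the base64 encoding of the 16-byte string "the sample nonce".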

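We can trim the request headers down and send the request as a plain string. Don't forget the empty line at the end (the final "\r\n" after the last header), which terminates the header block: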

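const char *fmt = "GET %s HTTP/1.1\r\n"
                  "Upgrade: websocket\r\n"
                  "Connection: Upgrade\r\n"
                  "Host: %s\r\n"
                  "Sec-WebSocket-Key: %s\r\n"
                  "Sec-WebSocket-Version: 13\r\n"
                  "\r\n";  /* the empty line ends the header block */
/* +1 for the terminating NUL (the three "%s" in fmt already over-count slightly) */
size = strlen(fmt) + strlen(path) + strlen(host) + strlen(ws->key) + 1;
buf = (char *)malloc(size);
snprintf(buf, size, fmt, path, host, ws->key);
size = strlen(buf);
nbytes = ws->io_send(ws, ws->context, buf, size);
/* assuming io_send has sent or copied the data by the time it returns */
free(buf);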

On receiving the request, the server replies with a response:

HTTP/1.1 101 Switching Protocols
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=

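The important field here is Sec-WebSocket-Accept. The server reads Sec-WebSocket-Key from the client's request headers, appends a fixed globally unique identifier (the so-called magic string) "258EAFA5-E914-47DA-95CA-C5AB0DC85B11", computes the SHA-1 hash of the concatenation (160 bits, i.e. 20 bytes), base64-encodes that hash, and returns the result to the client as the value of Sec-WebSocket-Accept.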
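The computation can be checked against the example in RFC 6455. The sketch below assumes hypothetical helper names (`sha1`, `base64_encode`, `ws_accept`) and includes a compact SHA-1 only so that it compiles on its own; a real project would normally call a library such as OpenSSL or CommonCrypto instead.

```c
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

static const char B64[] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/* Standard base64 with '=' padding; out needs 4 * ((len + 2) / 3) + 1 bytes. */
static void base64_encode(const unsigned char *in, size_t len, char *out) {
    size_t i = 0, o = 0;
    for (; i + 2 < len; i += 3) {
        uint32_t v = (uint32_t)in[i] << 16 | (uint32_t)in[i + 1] << 8 | in[i + 2];
        out[o++] = B64[v >> 18 & 63];
        out[o++] = B64[v >> 12 & 63];
        out[o++] = B64[v >> 6 & 63];
        out[o++] = B64[v & 63];
    }
    if (i < len) {
        uint32_t v = (uint32_t)in[i] << 16 | (i + 1 < len ? (uint32_t)in[i + 1] << 8 : 0);
        out[o++] = B64[v >> 18 & 63];
        out[o++] = B64[v >> 12 & 63];
        out[o++] = i + 1 < len ? B64[v >> 6 & 63] : '=';
        out[o++] = '=';
    }
    out[o] = '\0';
}

static uint32_t rol32(uint32_t x, int n) { return x << n | x >> (32 - n); }

/* Compact SHA-1 (FIPS 180-4): writes the 20-byte digest. Returns -1 on allocation failure. */
static int sha1(const unsigned char *msg, size_t len, unsigned char digest[20]) {
    uint32_t h[5] = {0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0};
    size_t total = ((len + 8) / 64 + 1) * 64;   /* message + 0x80 + 64-bit length, padded */
    unsigned char *buf = calloc(total, 1);
    if (buf == NULL) return -1;
    memcpy(buf, msg, len);
    buf[len] = 0x80;
    for (int i = 0; i < 8; i++)                  /* message length in bits, big-endian */
        buf[total - 1 - i] = (unsigned char)((uint64_t)len * 8 >> (8 * i));
    for (size_t off = 0; off < total; off += 64) {
        uint32_t w[80], a = h[0], b = h[1], c = h[2], d = h[3], e = h[4], f, k, t;
        for (int i = 0; i < 16; i++)
            w[i] = (uint32_t)buf[off + 4 * i] << 24 | (uint32_t)buf[off + 4 * i + 1] << 16 |
                   (uint32_t)buf[off + 4 * i + 2] << 8 | buf[off + 4 * i + 3];
        for (int i = 16; i < 80; i++)
            w[i] = rol32(w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16], 1);
        for (int i = 0; i < 80; i++) {
            if (i < 20)      { f = (b & c) | (~b & d);          k = 0x5A827999; }
            else if (i < 40) { f = b ^ c ^ d;                   k = 0x6ED9EBA1; }
            else if (i < 60) { f = (b & c) | (b & d) | (c & d); k = 0x8F1BBCDC; }
            else             { f = b ^ c ^ d;                   k = 0xCA62C1D6; }
            t = rol32(a, 5) + f + e + k + w[i];
            e = d; d = c; c = rol32(b, 30); b = a; a = t;
        }
        h[0] += a; h[1] += b; h[2] += c; h[3] += d; h[4] += e;
    }
    free(buf);
    for (int i = 0; i < 20; i++)                 /* digest is big-endian h[0..4] */
        digest[i] = (unsigned char)(h[i / 4] >> (24 - 8 * (i % 4)));
    return 0;
}

static const char WS_GUID[] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

/* accept receives 28 base64 characters + NUL. Returns 0 on success. */
static int ws_accept(const char *key, char accept[29]) {
    char joined[128];
    unsigned char digest[20];
    int n = snprintf(joined, sizeof joined, "%s%s", key, WS_GUID);
    if (n < 0 || (size_t)n >= sizeof joined) return -1;
    if (sha1((const unsigned char *)joined, (size_t)n, digest) != 0) return -1;
    base64_encode(digest, sizeof digest, accept);
    return 0;
}
```

Calling `ws_accept("dGhlIHNhbXBsZSBub25jZQ==", acc)` yields `s3pPLMBiTxaQ9kYGzzhZRbK+xOo=`, the value in the sample response above. A client can compute the same value from the key it sent and compare it with the header it receives.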

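To parse the HTTP handshake response you can use Node.js's http-parser. Its approach is fairly simple: it reads the header data character by character and processes it as it goes; see its state machine for the details. Once parsing is done, you still need to check the content to see whether the response is correct, and manage your handshake state accordingly.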
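For a handshake response already held in one NUL-terminated string, the checks can be sketched without a full HTTP parser. The helpers `find_header` and `ws_check_response` are hypothetical. The sketch simplifies by requiring `Connection` to be exactly `Upgrade` (RFC 6455 only requires the header to contain that token) and ignores repeated headers; `expected_accept` is the value the client computes from its own key.

```c
#include <stddef.h>
#include <string.h>
#include <strings.h>   /* strncasecmp (POSIX) */

/* Finds header `name` (case-insensitive) after the status line; returns a pointer
 * to its value and stores the value length in *vlen, or returns NULL if absent. */
static const char *find_header(const char *resp, const char *name, size_t *vlen) {
    size_t nlen = strlen(name);
    const char *line = strstr(resp, "\r\n");           /* end of the status line */
    while (line && line[2] != '\r' && line[2] != '\0') {  /* stop at the empty line */
        line += 2;
        const char *end = strstr(line, "\r\n");
        if (!end) end = line + strlen(line);
        if ((size_t)(end - line) > nlen && line[nlen] == ':' &&
            strncasecmp(line, name, nlen) == 0) {
            const char *v = line + nlen + 1;
            while (*v == ' ' || *v == '\t') v++;        /* skip optional whitespace */
            *vlen = (size_t)(end - v);
            return v;
        }
        line = *end ? end : NULL;
    }
    return NULL;
}

/* Returns 1 if the response completes the handshake, 0 otherwise. */
static int ws_check_response(const char *resp, const char *expected_accept) {
    size_t n;
    const char *v;
    if (strncmp(resp, "HTTP/1.1 101", 12) != 0) return 0;
    v = find_header(resp, "Upgrade", &n);
    if (!v || n != 9 || strncasecmp(v, "websocket", 9) != 0) return 0;
    v = find_header(resp, "Connection", &n);
    if (!v || n != 7 || strncasecmp(v, "Upgrade", 7) != 0) return 0;
    v = find_header(resp, "Sec-WebSocket-Accept", &n);
    if (!v || n != strlen(expected_accept) || strncmp(v, expected_accept, n) != 0) return 0;
    return 1;
}
```

Header names are compared case-insensitively, as HTTP requires, while the Sec-WebSocket-Accept value is compared exactly because base64 is case-sensitive.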

Sending/reading data

Data handling is best explained with the frame format diagram:

0                   1                   2                   3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-------+-+-------------+-------------------------------+
|F|R|R|R| opcode|M| Payload len |    Extended payload length    |
|I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
|N|V|V|V|       |S|             |   (if payload len==126/127)   |
| |1|2|3|       |K|             |                               |
+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
|     Extended payload length continued, if payload len == 127  |
+ - - - - - - - - - - - - - - - +-------------------------------+
|                               |Masking-key, if MASK set to 1  |
+-------------------------------+-------------------------------+
| Masking-key (continued)       |          Payload Data         |
+-------------------------------- - - - - - - - - - - - - - - - +
:                     Payload Data continued ...                :
+ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
|                     Payload Data continued ...                |
+---------------------------------------------------------------+

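First, what the numbers mean: they are bit positions. The second row counts bits 0-9 over and over, and the first row marks each group of ten (0, 10, 20, 30). Bits 0-7 are 8 bits, i.e. one byte.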

0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1

So the start of a frame could be assembled like this:

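char *rev = (char *)malloc(4);
rev[0] = (char)(0x81 & 0xff); /* FIN = 1, RSV1-3 = 0, opcode = 0x1 (text) */
rev[1] = 126 & 0x7f;          /* MASK = 0 for simplicity (real client frames must set it); Payload len = 126 */
rev[2] = 1;                   /* extended payload length, high byte */
rev[3] = 0;                   /* extended payload length, low byte: 0x0100 = 256 bytes */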

OK, now that we know what the frame data looks like, let's work backwards and map each value to its frame field.

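First, what is 0x81? It is a hexadecimal value; in binary it is 1000 0001, exactly one byte, and each of its bits gives the value of the corresponding field in this part of the frame: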

0 1 2 3 4 5 6 7 8
+-+-+-+-+-------+
|F|R|R|R| opcode|
|I|S|S|S|  (4)  |
|N|V|V|V|       |
| |1|2|3|       |
+-+-+-+-+-------+
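  • FIN indicates whether this frame is the final fragment of a message: 1 means it is the last one, 0 means more frames follow.
  • RSV1, RSV2 and RSV3 must be 0 unless a negotiated extension defines a non-zero meaning. If a non-zero RSV value is received and no negotiated extension defines it, the receiving endpoint must fail the WebSocket connection.
  • opcode defines how the Payload data is interpreted. Receiving an unknown opcode likewise fails the WebSocket connection. The protocol defines these values:
    • %x0 denotes a continuation frame
    • %x1 denotes a text frame
    • %x2 denotes a binary frame
    • %x3-7 are reserved for further non-control frames
    • %x8 denotes a connection close frame
    • %x9 denotes a ping
    • %xA denotes a pong
    • %xB-F are reserved for further control frames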

The & 0xff simply keeps the low 8 bits, i.e. exactly the byte value we need.
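The field rules above translate directly into code. The sketch below decodes the first byte of a frame and applies the two failure rules from the list, assuming no extension has been negotiated (so any non-zero RSV bit is an error). The type `ws_byte0_t` and both function names are hypothetical.

```c
#include <stdint.h>

/* Hypothetical decoded view of the first byte of a frame. */
typedef struct {
    int fin;     /* 1 = final fragment of the message */
    int rsv;     /* RSV1..RSV3 as a 3-bit value */
    int opcode;  /* low 4 bits */
} ws_byte0_t;

/* Returns 0 on success, -1 if the connection must be failed:
 * a non-zero RSV (no extension negotiated) or a reserved opcode. */
static int ws_decode_byte0(uint8_t b, ws_byte0_t *out) {
    out->fin = (b >> 7) & 1;
    out->rsv = (b >> 4) & 0x7;
    out->opcode = b & 0x0f;
    if (out->rsv != 0) return -1;
    switch (out->opcode) {
    case 0x0: case 0x1: case 0x2:   /* continuation, text, binary */
    case 0x8: case 0x9: case 0xA:   /* close, ping, pong */
        return 0;
    default:                        /* 0x3-0x7 and 0xB-0xF are reserved */
        return -1;
    }
}

/* Control frames are the opcodes 0x8-0xF, i.e. the high bit of the 4-bit field is set. */
static int ws_is_control(int opcode) { return (opcode & 0x8) != 0; }
```

With 0x81 from the example above this gives FIN = 1 and opcode = 1 (a text frame).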

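Next comes 126, the value of the Payload len field. 126 is a special value: it means the actual payload length is stored in the 2-byte Extended payload length that follows (in the example above, the bytes 1 and 0, i.e. 0x0100 = 256):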

                8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
                +-+-------------+-------------------------------+
                |M| Payload len |    Extended payload length    |
                |A|     (7)     |             (16/64)           |
                |S|             |   (if payload len==126/127)   |
                |K|             |                               |
+-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
|     Extended payload length continued, if payload len == 127  |
+ - - - - - - - - - - - - - - - +-------------------------------+
|                               |Masking-key, if MASK set to 1  |
+-------------------------------+-------------------------------+
| Masking-key (continued)       |          Payload Data         |
+-------------------------------- - - - - - - - - - - - - - - - +
:                     Payload Data continued ...                :
+ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
|                     Payload Data continued ...                |
+---------------------------------------------------------------+
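  • MASK indicates whether the Payload data is masked. If it is set to 1, a Masking-key must be present. All frames sent from the client to the server must be masked.
  • Payload len gives the length of the payload, in one of three ways:
    • If the 7-bit value is 0-125, it is the payload length itself
    • If it is 126, the following 2 bytes (Extended payload length) hold the length as a 16-bit unsigned integer
    • If it is 127, the following 8 bytes hold the length as a 64-bit unsigned integer whose most significant bit is 0: the 2-byte Extended payload length plus the 6-byte Extended payload length continued
  • The payload length is the sum of the lengths of the Extension data and the Application data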
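The three length encodings can be sketched as a small helper that writes the second header byte (MASK bit plus 7-bit Payload len) and any extended length in network byte order. The name `ws_encode_len` is hypothetical; the article's own ws__wrap_packet does the same job inline.

```c
#include <stddef.h>
#include <stdint.h>

/* Writes the MASK/Payload-len byte plus any extended length into out (>= 9 bytes),
 * big-endian. Returns the number of bytes written, or 0 if len is too large
 * (the 64-bit form must have its most significant bit set to 0). */
static size_t ws_encode_len(uint64_t len, int masked, uint8_t *out) {
    uint8_t mask_bit = masked ? 0x80 : 0x00;
    if (len <= 125) {                      /* fits in the 7-bit field */
        out[0] = mask_bit | (uint8_t)len;
        return 1;
    }
    if (len <= 0xffff) {                   /* 126 + 16-bit extended length */
        out[0] = mask_bit | 126;
        out[1] = (uint8_t)(len >> 8);
        out[2] = (uint8_t)len;
        return 3;
    }
    if (len > 0x7fffffffffffffffULL) return 0;
    out[0] = mask_bit | 127;               /* 127 + 64-bit extended length */
    for (int i = 0; i < 8; i++)
        out[1 + i] = (uint8_t)(len >> (56 - 8 * i));
    return 9;
}
```

For a 256-byte payload without MASK it writes 0x7E 0x01 0x00, the same bytes as rev[1..3] in the frame-building example earlier.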

Sending and reading data therefore comes down to wrapping and parsing frames.

Reading data:

int ws_recv(websocket_t *ws) {
    if (ws->state != WS_STATE_HANDSHAKE_COMPLETED) {
        return ws_do_handshake(ws);
    }

    int ret;
    while(TRUE) {
        ret = ws__recv(ws);
        if (ret != OK) {
            break;
        }
    }
    return ret;
}
int ws__recv(websocket_t *ws) {
    int nbytes;
    int ret = OK, i;
    int state = ws->rd_state;
    char *rd_buf;
    uint64_t rd_buf_len = 0;
    switch(state) {
        case WS_READ_IDLE: {
            if (ws->buf_pos < 2) {
                rd_buf_len = 2 - ws->buf_pos;
                rd_buf = malloc(rd_buf_len);
                nbytes = ws->io_recv(ws, ws->context, rd_buf, (size_t) (rd_buf_len));
                if (nbytes < 0) {
                    free(rd_buf);
                    //TODO fix errno handling
                    ret = nbytes;
                    break;
                }
                ws__enqueue_buf(ws, rd_buf, (size_t)nbytes) ;
                free(rd_buf);
            }
            if (ws->buf_pos < 2) {
                ret = WS_WANT_READ;
                break;
            }
            ws_frame_t * frame;
            if (ws->frame == NULL) {
                frame__alloc(&ws->frame);
                frame = ws->frame;
            } else {
                frame = ws->frame;
            }
            rd_buf = ws->buf;
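            frame->fin = (*(rd_buf) & 0x80) == 0x80 ? 1 : 0;
            frame->op_code = *(rd_buf) & 0x0f;
            // The MASK bit (rd_buf[1] & 0x80) is not handled: servers must not mask
            // frames sent to the client (a strict client fails the connection if it is set).
            frame->payload_len = *(rd_buf + 1) & 0x7f;
            if (frame->payload_len < 126) {
                frame->payload_bit_offset = 2;   // 2-byte header
                ws->rd_state = WS_READ_PAYLOAD;
            } else if (frame->payload_len == 126) {
                frame->payload_bit_offset = 4;   // 2-byte header + 2-byte extended length
                ws->rd_state = WS_READ_EXTEND_PAYLOAD_2_WORDS;
            } else {
                frame->payload_bit_offset = 10;  // 2-byte header + 8-byte extended length
                ws->rd_state = WS_READ_EXTEND_PAYLOAD_8_WORDS;
            }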
            ws__reset_buf(ws, 2);
            break;
        }
        case WS_READ_EXTEND_PAYLOAD_2_WORDS: {
#define PAYLOAD_LEN_BITS 2
            if (ws->buf_pos < PAYLOAD_LEN_BITS) {
                rd_buf_len = PAYLOAD_LEN_BITS - ws->buf_pos;
                rd_buf = malloc(rd_buf_len);
                nbytes = ws->io_recv(ws, ws->context, rd_buf, (size_t) (rd_buf_len));
                if (nbytes < 0) {
                    free(rd_buf);
                    ret = nbytes;
                    break;
                }
                ws__enqueue_buf(ws, rd_buf, (size_t)nbytes) ;
                free(rd_buf);
            }
            if (ws->buf_pos < PAYLOAD_LEN_BITS) {
                ret = WS_WANT_READ;
                break;
            }
            rd_buf = ws->buf;
            ws_frame_t * frame = ws->frame;
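            // The extended length is big-endian (PAYLOAD_LEN_BITS counts bytes),
            // e.g. rd_buf[0] = 0, rd_buf[1] = 255 -> 255; shifting keeps it host-independent.
            frame->payload_len = 0;
            for (i = 0; i < PAYLOAD_LEN_BITS; i++) {
                frame->payload_len = (frame->payload_len << 8) | (unsigned char)rd_buf[i];
            }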
            ws__reset_buf(ws, PAYLOAD_LEN_BITS);
            ws->rd_state = WS_READ_PAYLOAD;
#undef PAYLOAD_LEN_BITS
            break;
        }
        case WS_READ_EXTEND_PAYLOAD_8_WORDS: {
#define PAYLOAD_LEN_BITS 8
            if (ws->buf_pos < PAYLOAD_LEN_BITS) {
                rd_buf_len = PAYLOAD_LEN_BITS - ws->buf_pos;
                rd_buf = malloc(rd_buf_len);
                nbytes = ws->io_recv(ws, ws->context, rd_buf, (size_t) (rd_buf_len));
                if (nbytes < 0) {
                    free(rd_buf);
                    ret = nbytes;
                    break;
                }
                ws__enqueue_buf(ws, rd_buf, (size_t)nbytes) ;
                free(rd_buf);
            }
            if (ws->buf_pos < PAYLOAD_LEN_BITS) {
                ret = WS_WANT_READ;
                break;
            }
            rd_buf = ws->buf;
            ws_frame_t * frame = ws->frame;
            // big-endian 64-bit length, assembled independently of host byte order
            frame->payload_len = 0;
            for (i = 0; i < PAYLOAD_LEN_BITS; i++) {
                frame->payload_len = (frame->payload_len << 8) | (unsigned char)rd_buf[i];
            }
            ws__reset_buf(ws, PAYLOAD_LEN_BITS);
            ws->rd_state = WS_READ_PAYLOAD;
#undef PAYLOAD_LEN_BITS
            break;
        }
        case WS_READ_PAYLOAD: {
            ws_frame_t * frame = ws->frame;
            uint64_t payload_len = frame->payload_len;
            if (ws->buf_pos < payload_len) {
                rd_buf_len = payload_len - ws->buf_pos;
                rd_buf = malloc(rd_buf_len);
                nbytes = ws->io_recv(ws, ws->context, rd_buf, (size_t) (rd_buf_len));
                if (nbytes < 0) {
                    free(rd_buf);
                    ret = nbytes;
                    break;
                }
                ws__enqueue_buf(ws, rd_buf, (size_t)nbytes) ;
                free(rd_buf);
            }
            if (ws->buf_pos < payload_len) {
                ret = WS_WANT_READ;
                break;
            }
            rd_buf = ws->buf;
            frame->payload = malloc(payload_len);
            memcpy(frame->payload, rd_buf, payload_len);
            ws__reset_buf(ws, payload_len);
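            if (frame->fin == 1) {
                // final fragment of the message (control frames are never fragmented)
                if (frame->op_code == OP_CLOSE) {
                    // TODO: RFC 6455 (section 5.5.1) requires replying with a Close frame
                    // if we have not sent one yet, then closing the connection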
                    if (ws->close_cb) {
                        ws->close_cb(ws);
                    }
                } else {
                    ws__dispatch_msg(ws, frame);
                    ws->frame = NULL;
                }
            } else {
                ws_frame_t *new_frame;
                frame__alloc(&new_frame);
                frame->next = new_frame;
                new_frame->prev = frame;
                ws->frame = new_frame;
            }
            ws->rd_state = WS_READ_IDLE;
            break;
        }
    }
    return ret;
}
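The state machine above reads the header incrementally but never looks at the MASK bit or a masking key. For comparison, here is a sketch that parses a complete header already sitting in a buffer and covers both. The type `ws_header_t` and the function `ws_parse_header` are hypothetical and not part of the code above.

```c
#include <stddef.h>
#include <stdint.h>
#include <string.h>

typedef struct {
    int fin, opcode, masked;
    uint64_t payload_len;
    uint8_t mask_key[4];
    size_t header_len;   /* bytes that precede the payload */
} ws_header_t;

/* Returns 1 when a full header was parsed, 0 if more bytes are needed. */
static int ws_parse_header(const uint8_t *buf, size_t n, ws_header_t *h) {
    size_t pos = 2;
    if (n < 2) return 0;
    h->fin = buf[0] >> 7;
    h->opcode = buf[0] & 0x0f;
    h->masked = buf[1] >> 7;
    h->payload_len = buf[1] & 0x7f;
    if (h->payload_len == 126 || h->payload_len == 127) {
        size_t ext = (h->payload_len == 126) ? 2 : 8;
        if (n < pos + ext) return 0;
        h->payload_len = 0;                    /* big-endian extended length */
        for (size_t i = 0; i < ext; i++)
            h->payload_len = (h->payload_len << 8) | buf[pos + i];
        pos += ext;
    }
    if (h->masked) {                           /* 4-byte masking key follows */
        if (n < pos + 4) return 0;
        memcpy(h->mask_key, buf + pos, 4);
        pos += 4;
    }
    h->header_len = pos;
    return 1;
}
```

Using the examples from RFC 6455 section 5.7, the masked "Hello" frame 0x81 0x85 0x37 0xfa 0x21 0x3d ... gives a 6-byte header with payload length 5, and the unmasked 256-byte binary frame 0x82 0x7E 0x01 0x00 ... gives a 4-byte header.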

Sending data (wrapping a frame):

void ws__wrap_packet(_WS_IN websocket_t *ws,
                     _WS_IN const char *payload,
                     _WS_IN unsigned long long payload_size,
                     _WS_IN int flags,
                     _WS_OUT char** out,
                     _WS_OUT uint64_t *out_size) {
    struct timeval tv;
    char mask[4];
	unsigned int mask_int;
	unsigned int payload_len_bits;
	unsigned int payload_bit_offset = 6;
    unsigned int extend_payload_len_bits, i;
	unsigned long long frame_size;
    const int MASK_BIT_LEN = 4;
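    gettimeofday(&tv, NULL);
    // NOTE: RFC 6455 (section 5.3) requires an unpredictable masking key.
    // srand()/rand() is weak and supplies at most 31 random bits; a cryptographically
    // strong source (e.g. arc4random() on Apple platforms) is preferable.
    srand(tv.tv_usec * tv.tv_sec);
    mask_int = rand();
    memcpy(mask, &mask_int, 4);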
    /**
     * payload_len bits
     * ref to https://tools.ietf.org/html/rfc6455#section-5.2
     * If 0-125, that is the payload length
     *
     * If 126, the following 2 bytes interpreted as a
     * 16-bit unsigned integer are the payload length
     * 
     * If 127, the following 8 bytes interpreted as a 64-bit unsigned integer (the
     * most significant bit MUST be 0) are the payload length.
     */
    if (payload_size <= 125) {
        // consists of (fin/rsv/opcode byte + mask/payload-len byte + masking key + payload)
        extend_payload_len_bits = 0;
        frame_size = 1 + 1 + MASK_BIT_LEN + payload_size;
        payload_len_bits = payload_size;
    } else if (payload_size <= 0xffff) {
        extend_payload_len_bits = 2;
        // consists of (fin/rsv/opcode byte + mask/payload-len byte + 2-byte extended length + masking key + payload)
        frame_size = 1 + 1 + extend_payload_len_bits + MASK_BIT_LEN + payload_size;
        payload_len_bits = 126;
        payload_bit_offset += extend_payload_len_bits;
    } else if (payload_size <= 0x7fffffffffffffffULL) {
        // the most significant bit of the 64-bit length MUST be 0
        extend_payload_len_bits = 8;
        // consists of (fin/rsv/opcode byte + mask/payload-len byte + 8-byte extended length + masking key + payload)
        frame_size = 1 + 1 + extend_payload_len_bits + MASK_BIT_LEN + payload_size;
        payload_len_bits = 127;
        payload_bit_offset += extend_payload_len_bits;
	} else {
        if (ws->error_cb) {
            ws_error_t *err = ws_new_error(WS_SEND_DATA_TOO_LARGE_ERR);
            ws->error_cb(ws, err);
            free(err);
        }
		return ;
	}
    *out_size = frame_size;
	char *data = (*out) = (char *)malloc(frame_size);
    char *buf_offset = data;
    bzero(data, frame_size);
	*data = flags & 0xff;
    buf_offset = data + 1;
    // set mask bit = 1
	*(buf_offset) = payload_len_bits | 0x80; //payload length with mask bit on
    buf_offset = data + 2;
    // write the extended payload length in network byte order (big-endian),
    // independent of the host's own byte order
    for (i = 0; i < extend_payload_len_bits; i++) {
        buf_offset[i] = (char)((payload_size >> (8 * (extend_payload_len_bits - i - 1))) & 0xff);
    }

    /**
     * according to https://tools.ietf.org/html/rfc6455#section-5.3
     * 
     * buf_offset is set to the 4-byte masking key
     */
    buf_offset = data + payload_bit_offset - 4;
	for (i = 0; i < 4; i++) {
		*(buf_offset + i) = mask[i] & 0xff;
    }
    /**
     * mask the payload data 
     */
    buf_offset = data + payload_bit_offset;
	memcpy(buf_offset, payload, payload_size);
	mask_payload(mask, buf_offset, payload_size);
}
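ws__wrap_packet calls mask_payload, which the article does not show. A minimal version consistent with that call is sketched below; its signature is an assumption inferred from the call site. RFC 6455 section 5.3 defines masking as XOR-ing payload byte i with byte i mod 4 of the masking key.

```c
/* XOR each payload byte with mask[i % 4] (RFC 6455 section 5.3).
 * Applying it twice restores the data, so the same routine also unmasks. */
static void mask_payload(const char *mask, char *data, unsigned long long len) {
    unsigned char *p = (unsigned char *)data;
    for (unsigned long long i = 0; i < len; i++)
        p[i] ^= (unsigned char)mask[i % 4];
}
```

With the masking key 37 fa 21 3d from the RFC's examples, "Hello" masks to 7f 9f 4d 51 58, and masking once more gives back "Hello".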

Summary

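Learning WebSocket is mostly about understanding the protocol. Once you understand the protocol, the seemingly complicated code above will make sense on its own.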
