用Python編寫自己的微型Redis

Python魔法师發表於2024-03-12

building-a-simple-redis-server-with-python

前幾天我想到,寫一個簡單的東西會很整潔 雷迪斯-像資料庫伺服器。雖然我有很多 WSGI應用程式的經驗,資料庫伺服器展示了一種新穎 挑戰,並被證明是學習如何工作的不錯的實際方法 Python中的套接字。在這篇文章中,我將分享我在此過程中學到的知識。

我專案的目的是 編寫一個簡單的伺服器 我可以用 我的任務佇列專案稱為 Huey。 Huey使用Redis作為預設儲存引擎來跟蹤被引用的工作, 完成的工作和其他結果。就本職位而言, 我進一步縮小了原始專案的範圍,以免造成混亂 使用程式碼的水域,您可以很容易地自己寫,但是如果您 很好奇,你可以看看 最終結果 這裡 (檔案)。

我們將要構建的伺服器將能夠響應以下命令:

  • GET <key>
  • SET <key> <value>
  • DELETE <key>
  • FLUSH
  • MGET <key1> ... <keyn>
  • MSET <key1> <value1> ... <keyn> <valuen>

我們還將支援以下資料型別:

  • Strings and Binary Data
  • Numbers
  • NULL
  • Arrays (which may be nested)
  • Dictionaries (which may be nested)
  • Error messages

為了非同步處理多個客戶端,我們將使用 gevent, 但是您也可以使用標準庫的 SocketServer 模組與 要麼 ForkingMixinThreadingMixin

骨架

讓我們為伺服器設定一個框架。我們需要伺服器本身,以及 新客戶端連線時要執行的回撥。另外,我們需要 某種邏輯來處理客戶端請求併傳送響應。

這是一個開始:

from gevent import socket
from gevent.pool import Pool
from gevent.server import StreamServer

from collections import namedtuple
from io import BytesIO
from socket import error as socket_error


# We'll use exceptions to notify the connection-handling loop of problems.
class CommandError(Exception): pass
class Disconnect(Exception): pass

Error = namedtuple('Error', ('message',))


class ProtocolHandler(object):
    def handle_request(self, socket_file):
        # Parse a request from the client into it's component parts.
        pass

    def write_response(self, socket_file, data):
        # Serialize the response data and send it to the client.
        pass


class Server(object):
    def __init__(self, host='127.0.0.1', port=31337, max_clients=64):
        self._pool = Pool(max_clients)
        self._server = StreamServer(
            (host, port),
            self.connection_handler,
            spawn=self._pool)

        self._protocol = ProtocolHandler()
        self._kv = {}

    def connection_handler(self, conn, address):
        # Convert "conn" (a socket object) into a file-like object.
        socket_file = conn.makefile('rwb')

        # Process client requests until client disconnects.
        while True:
            try:
                data = self._protocol.handle_request(socket_file)
            except Disconnect:
                break

            try:
                resp = self.get_response(data)
            except CommandError as exc:
                resp = Error(exc.args[0])

            self._protocol.write_response(socket_file, resp)

    def get_response(self, data):
        # Here we'll actually unpack the data sent by the client, execute the
        # command they specified, and pass back the return value.
        pass

    def run(self):
        self._server.serve_forever()

希望以上程式碼相當清楚。我們已經分開了擔憂,以便 協議處理屬於自己的類,有兩種公共方法: handle_requestwrite_response。伺服器本身使用協議 處理程式解壓縮客戶端請求並將伺服器響應序列化回 客戶。The get_response() 該方法將用於執行命令 由客戶發起。

仔細檢視程式碼 connection_handler() 方法,你可以 看到我們在套接字物件周圍獲得了類似檔案的包裝紙。這個包裝器 讓我們抽象一些 怪癖 通常會遇到使用原始插座的情況。函式輸入 無窮迴圈,讀取客戶端的請求,傳送響應,最後 客戶端斷開連線時退出迴圈(由 read() 返回 一個空字串)。

我們使用鍵入的異常來處理客戶端斷開連線並通知使用者 錯誤處理命令。例如,如果使用者做錯了 對伺服器的格式化請求,我們將提出一個 CommandError, 哪個是 序列化為錯誤響應併傳送給客戶端。

在繼續之前,讓我們討論客戶端和伺服器將如何通訊。

執行緒

我面臨的第一個挑戰是如何處理透過 線。我在網上找到的大多數示例都是毫無意義的回聲伺服器,它們進行了轉換 套接字到類似檔案的物件,並且剛剛呼叫 readline()。如果我想 用新線儲存一些醃製的資料或字串,我需要一些 一種序列化格式。

在浪費時間嘗試發明合適的東西之後,我決定閱讀 有關文件 Redis協議, 其中 事實證明實施起來非常簡單,並且具有 支援幾種不同的資料型別。

Redis協議使用請求/響應通訊模式與 客戶。來自伺服器的響應將使用第一個位元組來指示 資料型別,然後是資料,以回車/線路進給終止。

資料型別.png

讓我們填寫協議處理程式的類,使其實現Redis 協議。

class ProtocolHandler(object):
    def __init__(self):
        self.handlers = {
            '+': self.handle_simple_string,
            '-': self.handle_error,
            ':': self.handle_integer,
            '$': self.handle_string,
            '*': self.handle_array,
            '%': self.handle_dict}

    def handle_request(self, socket_file):
        first_byte = socket_file.read(1)
        if not first_byte:
            raise Disconnect()

        try:
            # Delegate to the appropriate handler based on the first byte.
            return self.handlers[first_byte](socket_file)
        except KeyError:
            raise CommandError('bad request')

    def handle_simple_string(self, socket_file):
        return socket_file.readline().rstrip('\r\n')

    def handle_error(self, socket_file):
        return Error(socket_file.readline().rstrip('\r\n'))

    def handle_integer(self, socket_file):
        return int(socket_file.readline().rstrip('\r\n'))

    def handle_string(self, socket_file):
        # First read the length ($<length>\r\n).
        length = int(socket_file.readline().rstrip('\r\n'))
        if length == -1:
            return None  # Special-case for NULLs.
        length += 2  # Include the trailing \r\n in count.
        return socket_file.read(length)[:-2]

    def handle_array(self, socket_file):
        num_elements = int(socket_file.readline().rstrip('\r\n'))
        return [self.handle_request(socket_file) for _ in range(num_elements)]

    def handle_dict(self, socket_file):
        num_items = int(socket_file.readline().rstrip('\r\n'))
        elements = [self.handle_request(socket_file)
                    for _ in range(num_items * 2)]
        return dict(zip(elements[::2], elements[1::2]))

對於協議的序列化方面,我們將執行與上述相反的操作: 將Python物件轉換為序列化的物件!

class ProtocolHandler(object):
    # ... above methods omitted ...
    def write_response(self, socket_file, data):
        buf = BytesIO()
        self._write(buf, data)
        buf.seek(0)
        socket_file.write(buf.getvalue())
        socket_file.flush()

    def _write(self, buf, data):
        if isinstance(data, str):
            data = data.encode('utf-8')

        if isinstance(data, bytes):
            buf.write('$%s\r\n%s\r\n' % (len(data), data))
        elif isinstance(data, int):
            buf.write(':%s\r\n' % data)
        elif isinstance(data, Error):
            buf.write('-%s\r\n' % error.message)
        elif isinstance(data, (list, tuple)):
            buf.write('*%s\r\n' % len(data))
            for item in data:
                self._write(buf, item)
        elif isinstance(data, dict):
            buf.write('%%%s\r\n' % len(data))
            for key in data:
                self._write(buf, key)
                self._write(buf, data[key])
        elif data is None:
            buf.write('$-1\r\n')
        else:
            raise CommandError('unrecognized type: %s' % type(data))

將協議處理保持在其自己的類中的另一個好處是 我們可以重複使用 handle_requestwrite_response 建立方法 客戶端庫。

執行命令

Server 我們模擬的課程現在需要 get_response() 方法 已實施。命令將假定由客戶端以簡單方式傳送 字串或命令引數陣列,因此 data 傳遞給 get_response() 將是位元組或列表。為了簡化處理,如果 data 這是一個簡單的字串,我們將透過拆分將其轉換為列表 空格。

第一個引數將是命令名稱,並帶有任何其他引數 屬於指定命令。就像我們對第一個的對映一樣 位元組給處理者 ProtocolHandler, 讓我們建立一個的對映 命令回撥 Server:

class Server(object):
    def __init__(self, host='127.0.0.1', port=31337, max_clients=64):
        self._pool = Pool(max_clients)
        self._server = StreamServer(
            (host, port),
            self.connection_handler,
            spawn=self._pool)

        self._protocol = ProtocolHandler()
        self._kv = {}

        self._commands = self.get_commands()

    def get_commands(self):
        return {
            'GET': self.get,
            'SET': self.set,
            'DELETE': self.delete,
            'FLUSH': self.flush,
            'MGET': self.mget,
            'MSET': self.mset}

    def get_response(self, data):
        if not isinstance(data, list):
            try:
                data = data.split()
            except:
                raise CommandError('Request must be list or simple string.')

        if not data:
            raise CommandError('Missing command')

        command = data[0].upper()
        if command not in self._commands:
            raise CommandError('Unrecognized command: %s' % command)

        return self._commands[command](*data[1:])

我們的伺服器快完成了! 我們只需要執行六個命令 在 get_commands() 方法:

class Server(object):
    def get(self, key):
        return self._kv.get(key)

    def set(self, key, value):
        self._kv[key] = value
        return 1

    def delete(self, key):
        if key in self._kv:
            del self._kv[key]
            return 1
        return 0

    def flush(self):
        kvlen = len(self._kv)
        self._kv.clear()
        return kvlen

    def mget(self, *keys):
        return [self._kv.get(key) for key in keys]

    def mset(self, *items):
        data = zip(items[::2], items[1::2])
        for key, value in data:
            self._kv[key] = value
        return len(data)

而已! 我們的伺服器現在可以開始處理請求了。在下一個 本節,我們將實現一個客戶端與伺服器進行互動。

客戶端

要與伺服器互動,讓我們重新使用 ProtocolHandler 類到 實現一個簡單的客戶端。客戶端將連線到伺服器併傳送 命令編碼為列表。我們將同時使用 write_response()handle_request() 編碼請求和處理伺服器響應的邏輯 分別。

class Client(object):
    def __init__(self, host='127.0.0.1', port=31337):
        self._protocol = ProtocolHandler()
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._socket.connect((host, port))
        self._fh = self._socket.makefile('rwb')

    def execute(self, *args):
        self._protocol.write_response(self._fh, args)
        resp = self._protocol.handle_request(self._fh)
        if isinstance(resp, Error):
            raise CommandError(resp.message)
        return resp

execute() 方法上,我們可以傳遞任意引數列表,這些引數將被編碼為陣列併傳送到伺服器。來自伺服器的響應被解析並作為Python物件返回。為了方便起見,我們可以為各個命令編寫客戶端方法:

class Client(object):
    # ...
    def get(self, key):
        return self.execute('GET', key)

    def set(self, key, value):
        return self.execute('SET', key, value)

    def delete(self, key):
        return self.execute('DELETE', key)

    def flush(self):
        return self.execute('FLUSH')

    def mget(self, *keys):
        return self.execute('MGET', *keys)

    def mset(self, *items):
        return self.execute('MSET', *items)

為了測試我們的客戶端,讓我們配置Python指令碼以啟動伺服器 直接從命令列執行時:

測試伺服器

要測試伺服器,只需從命令列執行伺服器的Python模組即可。在另一個終端中,開啟Python直譯器並匯入 Client 來自伺服器模組的類。安裝客戶端將開啟連線,您可以開始執行命令!

>>> from server_ex import Client
>>> client = Client()
>>> client.mset('k1', 'v1', 'k2', ['v2-0', 1, 'v2-2'], 'k3', 'v3')
3
>>> client.get('k2')
['v2-0', 1, 'v2-2']
>>> client.mget('k3', 'k1')
['v3', 'v1']
>>> client.delete('k1')
1
>>> client.get('k1')
>>> client.delete('k1')
0
>>> client.set('kx', {'vx': {'vy': 0, 'vz': [1, 2, 3]}})
1
>>> client.get('kx')
{'vx': {'vy': 0, 'vz': [1, 2, 3]}}
>>> client.flush()
2

完整程式碼

from gevent import socket
from gevent.pool import Pool
from gevent.server import StreamServer

from collections import namedtuple
from io import BytesIO
from socket import error as socket_error
import logging


logger = logging.getLogger(__name__)


class CommandError(Exception): pass
class Disconnect(Exception): pass

Error = namedtuple('Error', ('message',))


class ProtocolHandler(object):
    def __init__(self):
        self.handlers = {
            '+': self.handle_simple_string,
            '-': self.handle_error,
            ':': self.handle_integer,
            '$': self.handle_string,
            '*': self.handle_array,
            '%': self.handle_dict}

    def handle_request(self, socket_file):
        first_byte = socket_file.read(1)
        if not first_byte:
            raise Disconnect()

        try:
            # Delegate to the appropriate handler based on the first byte.
            return self.handlers[first_byte](socket_file)
        except KeyError:
            raise CommandError('bad request')

    def handle_simple_string(self, socket_file):
        return socket_file.readline().rstrip('\r\n')

    def handle_error(self, socket_file):
        return Error(socket_file.readline().rstrip('\r\n'))

    def handle_integer(self, socket_file):
        return int(socket_file.readline().rstrip('\r\n'))

    def handle_string(self, socket_file):
        # First read the length ($<length>\r\n).
        length = int(socket_file.readline().rstrip('\r\n'))
        if length == -1:
            return None  # Special-case for NULLs.
        length += 2  # Include the trailing \r\n in count.
        return socket_file.read(length)[:-2]

    def handle_array(self, socket_file):
        num_elements = int(socket_file.readline().rstrip('\r\n'))
        return [self.handle_request(socket_file) for _ in range(num_elements)]
    
    def handle_dict(self, socket_file):
        num_items = int(socket_file.readline().rstrip('\r\n'))
        elements = [self.handle_request(socket_file)
                    for _ in range(num_items * 2)]
        return dict(zip(elements[::2], elements[1::2]))

    def write_response(self, socket_file, data):
        buf = BytesIO()
        self._write(buf, data)
        buf.seek(0)
        socket_file.write(buf.getvalue())
        socket_file.flush()

    def _write(self, buf, data):
        if isinstance(data, str):
            data = data.encode('utf-8')

        if isinstance(data, bytes):
            buf.write('$%s\r\n%s\r\n' % (len(data), data))
        elif isinstance(data, int):
            buf.write(':%s\r\n' % data)
        elif isinstance(data, Error):
            buf.write('-%s\r\n' % error.message)
        elif isinstance(data, (list, tuple)):
            buf.write('*%s\r\n' % len(data))
            for item in data:
                self._write(buf, item)
        elif isinstance(data, dict):
            buf.write('%%%s\r\n' % len(data))
            for key in data:
                self._write(buf, key)
                self._write(buf, data[key])
        elif data is None:
            buf.write('$-1\r\n')
        else:
            raise CommandError('unrecognized type: %s' % type(data))


class Server(object):
    def __init__(self, host='127.0.0.1', port=31337, max_clients=64):
        self._pool = Pool(max_clients)
        self._server = StreamServer(
            (host, port),
            self.connection_handler,
            spawn=self._pool)

        self._protocol = ProtocolHandler()
        self._kv = {}

        self._commands = self.get_commands()

    def get_commands(self):
        return {
            'GET': self.get,
            'SET': self.set,
            'DELETE': self.delete,
            'FLUSH': self.flush,
            'MGET': self.mget,
            'MSET': self.mset}

    def connection_handler(self, conn, address):
        logger.info('Connection received: %s:%s' % address)
        # Convert "conn" (a socket object) into a file-like object.
        socket_file = conn.makefile('rwb')

        # Process client requests until client disconnects.
        while True:
            try:
                data = self._protocol.handle_request(socket_file)
            except Disconnect:
                logger.info('Client went away: %s:%s' % address)
                break

            try:
                resp = self.get_response(data)
            except CommandError as exc:
                logger.exception('Command error')
                resp = Error(exc.args[0])

            self._protocol.write_response(socket_file, resp)

    def run(self):
        self._server.serve_forever()

    def get_response(self, data):
        if not isinstance(data, list):
            try:
                data = data.split()
            except:
                raise CommandError('Request must be list or simple string.')

        if not data:
            raise CommandError('Missing command')

        command = data[0].upper()
        if command not in self._commands:
            raise CommandError('Unrecognized command: %s' % command)
        else:
            logger.debug('Received %s', command)

        return self._commands[command](*data[1:])

    def get(self, key):
        return self._kv.get(key)

    def set(self, key, value):
        self._kv[key] = value
        return 1

    def delete(self, key):
        if key in self._kv:
            del self._kv[key]
            return 1
        return 0

    def flush(self):
        kvlen = len(self._kv)
        self._kv.clear()
        return kvlen

    def mget(self, *keys):
        return [self._kv.get(key) for key in keys]

    def mset(self, *items):
        data = zip(items[::2], items[1::2])
        for key, value in data:
            self._kv[key] = value
        return len(data)


class Client(object):
    def __init__(self, host='127.0.0.1', port=31337):
        self._protocol = ProtocolHandler()
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._socket.connect((host, port))
        self._fh = self._socket.makefile('rwb')

    def execute(self, *args):
        self._protocol.write_response(self._fh, args)
        resp = self._protocol.handle_request(self._fh)
        if isinstance(resp, Error):
            raise CommandError(resp.message)
        return resp

    def get(self, key):
        return self.execute('GET', key)

    def set(self, key, value):
        return self.execute('SET', key, value)

    def delete(self, key):
        return self.execute('DELETE', key)

    def flush(self):
        return self.execute('FLUSH')

    def mget(self, *keys):
        return self.execute('MGET', *keys)

    def mset(self, *items):
        return self.execute('MSET', *items)


if __name__ == '__main__':
    from gevent import monkey; monkey.patch_all()
    logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.DEBUG)
    Server().run()

各位看官,如對你有幫助歡迎點贊,收藏,轉發,關注公眾號【Python魔法師】獲取更多Python魔法知識~

qrcode.jpg

相關文章