Python 中 lru_cache 的使用和實現

zikcheng發表於2021-01-24

在計算機軟體領域,快取(Cache)指的是將部分資料儲存在記憶體中,以便下次能夠更快地訪問這些資料,這也是一個典型的用空間換時間的例子。一般用於快取的記憶體空間是固定的,當有更多的資料需要快取的時候,需要將已快取的部分資料清除後再將新的快取資料放進去。需要清除哪些資料,就涉及到了快取置換的策略,LRU(Least Recently Used,最近最少使用)是很常見的一個,也是 Python 中提供的快取置換策略。

下面我們通過一個簡單的示例來看 Python 中的 lru_cache 是如何使用的。

def factorial(n):
    print(f"計算 {n} 的階乘")
    return 1 if n <= 1 else n * factorial(n - 1)

a = factorial(5)
print(f'5! = {a}')
b = factorial(3)
print(f'3! = {b}')

上面的程式碼中定義了函式 factorial,通過遞迴的方式計算 n 的階乘,並且在函式呼叫的時候列印出 n 的值。然後分別計算 5 和 3 的階乘,並列印結果。執行上面的程式碼,輸出如下

計算 5 的階乘
計算 4 的階乘
計算 3 的階乘
計算 2 的階乘
計算 1 的階乘
5! = 120
計算 3 的階乘
計算 2 的階乘
計算 1 的階乘
3! = 6

可以看到,factorial(3) 的結果在計算 factorial(5) 的時候已經被計算過了,但是後面又被重複計算了。為了避免這種重複計算,我們可以在定義函式 factorial 的時候加上 lru_cache 裝飾器,如下所示

import functools
# 注意 lru_cache 後的一對括號,證明這是帶引數的裝飾器
@functools.lru_cache()
def factorial(n):
    print(f"計算 {n} 的階乘")
    return 1 if n <= 1 else n * factorial(n - 1)

重新執行程式碼,輸入如下

計算 5 的階乘
計算 4 的階乘
計算 3 的階乘
計算 2 的階乘
計算 1 的階乘
5! = 120
3! = 6

可以看到,這次在呼叫 factorial(3) 的時候沒有列印相應的輸出,也就是說 factorial(3) 是直接從快取讀取的結果,證明快取生效了。

被 lru_cache 修飾的函式在被相同引數呼叫的時候,後續的呼叫都是直接從快取讀結果,而不用真正執行函式。下面我們深入原始碼,看看 Python 內部是怎麼實現 lru_cache 的。寫作時 Python 最新發行版是 3.9,所以這裡使用的是 Python 3.9 的原始碼,並且保留了原始碼中的註釋。

def lru_cache(maxsize=128, typed=False):
    """Least-recently-used cache decorator.
    If *maxsize* is set to None, the LRU features are disabled and the cache
    can grow without bound.
    If *typed* is True, arguments of different types will be cached separately.
    For example, f(3.0) and f(3) will be treated as distinct calls with
    distinct results.
    Arguments to the cached function must be hashable.
    View the cache statistics named tuple (hits, misses, maxsize, currsize)
    with f.cache_info().  Clear the cache and statistics with f.cache_clear().
    Access the underlying function with f.__wrapped__.
    See:  http://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU)
    """

    # Users should only access the lru_cache through its public API:
    #       cache_info, cache_clear, and f.__wrapped__
    # The internals of the lru_cache are encapsulated for thread safety and
    # to allow the implementation to change (including a possible C version).
    
    if isinstance(maxsize, int):
        # Negative maxsize is treated as 0
        if maxsize < 0:
            maxsize = 0
    elif callable(maxsize) and isinstance(typed, bool):
        # The user_function was passed in directly via the maxsize argument
        user_function, maxsize = maxsize, 128
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        wrapper.cache_parameters = lambda : {'maxsize': maxsize, 'typed': typed}
        return update_wrapper(wrapper, user_function)
    elif maxsize is not None:
        raise TypeError(
            'Expected first argument to be an integer, a callable, or None')
    
    def decorating_function(user_function):
        wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
        wrapper.cache_parameters = lambda : {'maxsize': maxsize, 'typed': typed}
        return update_wrapper(wrapper, user_function)
    
    return decorating_function

這段程式碼中有如下幾個關鍵點

  • 關鍵字引數

    maxsize 表示快取容量,如果為 None 表示容量不設限, typed 表示是否區分引數型別,註釋中也給出瞭解釋,如果 typed == True,那麼 f(3)f(3.0) 會被認為是不同的函式呼叫。

  • 第 24 行的條件分支

    如果 lru_cache 的第一個引數是可呼叫的,直接返回 wrapper,也就是把 lru_cache 當做不帶引數的裝飾器,這是 Python 3.8 才有的特性,也就是說在 Python 3.8 及之後的版本中我們可以用下面的方式使用 lru_cache,可能是為了防止程式設計師在使用 lru_cache 的時候忘記加括號。

    import functools
    # 注意 lru_cache 後面沒有括號,
    # 證明這是將其當做不帶引數的裝飾器
    @functools.lru_cache
    def factorial(n):
        print(f"計算 {n} 的階乘")
        return 1 if n <= 1 else n * factorial(n - 1)
    

    注意,Python 3.8 之前的版本執行上面程式碼會報錯:TypeError: Expected maxsize to be an integer or None。

lru_cache 的具體邏輯是在 _lru_cache_wrapper 函式中實現的,還是一樣,列出原始碼,保留註釋。

def _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo):
    # Constants shared by all lru cache instances:
    sentinel = object()          # unique object used to signal cache misses
    make_key = _make_key         # build a key from the function arguments
    PREV, NEXT, KEY, RESULT = 0, 1, 2, 3   # names for the link fields

    cache = {}
    hits = misses = 0
    full = False
    cache_get = cache.get    # bound method to lookup a key or return None
    cache_len = cache.__len__  # get cache size without calling len()
    lock = RLock()           # because linkedlist updates aren't threadsafe
    root = []                # root of the circular doubly linked list
    root[:] = [root, root, None, None]     # initialize by pointing to self

    if maxsize == 0:

        def wrapper(*args, **kwds):
            # No caching -- just a statistics update
            nonlocal misses
            misses += 1
            result = user_function(*args, **kwds)
            return result

    elif maxsize is None:

        def wrapper(*args, **kwds):
            # Simple caching without ordering or size limit
            nonlocal hits, misses
            key = make_key(args, kwds, typed)
            result = cache_get(key, sentinel)
            if result is not sentinel:
                hits += 1
                return result
            misses += 1
            result = user_function(*args, **kwds)
            cache[key] = result
            return result

    else:

        def wrapper(*args, **kwds):
            # Size limited caching that tracks accesses by recency
            nonlocal root, hits, misses, full
            key = make_key(args, kwds, typed)
            with lock:
                link = cache_get(key)
                if link is not None:
                    # Move the link to the front of the circular queue
                    link_prev, link_next, _key, result = link
                    link_prev[NEXT] = link_next
                    link_next[PREV] = link_prev
                    last = root[PREV]
                    last[NEXT] = root[PREV] = link
                    link[PREV] = last
                    link[NEXT] = root
                    hits += 1
                    return result
                misses += 1
            result = user_function(*args, **kwds)
            with lock:
                if key in cache:
                    # Getting here means that this same key was added to the
                    # cache while the lock was released.  Since the link
                    # update is already done, we need only return the
                    # computed result and update the count of misses.
                    pass
                elif full:
                    # Use the old root to store the new key and result.
                    oldroot = root
                    oldroot[KEY] = key
                    oldroot[RESULT] = result
                    # Empty the oldest link and make it the new root.
                    # Keep a reference to the old key and old result to
                    # prevent their ref counts from going to zero during the
                    # update. That will prevent potentially arbitrary object
                    # clean-up code (i.e. __del__) from running while we're
                    # still adjusting the links.
                    root = oldroot[NEXT]
                    oldkey = root[KEY]
                    oldresult = root[RESULT]
                    root[KEY] = root[RESULT] = None
                    # Now update the cache dictionary.
                    del cache[oldkey]
                    # Save the potentially reentrant cache[key] assignment
                    # for last, after the root and links have been put in
                    # a consistent state.
                    cache[key] = oldroot
                else:
                    # Put result in a new link at the front of the queue.
                    last = root[PREV]
                    link = [last, root, key, result]
                    last[NEXT] = root[PREV] = cache[key] = link
                    # Use the cache_len bound method instead of the len() function
                    # which could potentially be wrapped in an lru_cache itself.
                    full = (cache_len() >= maxsize)
            return result

    def cache_info():
        """Report cache statistics"""
        with lock:
            return _CacheInfo(hits, misses, maxsize, cache_len())

    def cache_clear():
        """Clear the cache and cache statistics"""
        nonlocal hits, misses, full
        with lock:
            cache.clear()
            root[:] = [root, root, None, None]
            hits = misses = 0
            full = False

    wrapper.cache_info = cache_info
    wrapper.cache_clear = cache_clear
    return wrapper

函式開始的地方 2~14 行定義了一些關鍵變數,

  • hitsmisses 分別表示快取命中和沒有命中的次數
  • root 雙向迴圈連結串列的頭結點,每個節點儲存前向指標、後向指標、key 和 key 對應的 result,其中 key 為 _make_key 函式根據引數結算出來的字串,result 為被修飾的函式在給定的引數下返回的結果。注意,root 是不儲存資料 key 和 result 的。
  • cache 是真正儲存快取資料的地方,型別為 dict。cache 中的 key 也是 _make_key 函式根據引數結算出來的字串,value 儲存的是 key 對應的雙向迴圈連結串列中的節點。

接下來根據 maxsize 不同,定義不同的 wrapper

  • maxsize == 0,其實也就是沒有快取,那麼每次函式呼叫都不會命中,並且沒有命中的次數 misses 加 1。

  • maxsize is None,不限制快取大小,如果函式呼叫不命中,將沒有命中次數 misses 加 1,否則將命中次數 hits 加 1。

  • 限制快取的大小,那麼需要根據 LRU 演算法來更新 cache,也就是 42~97 行的程式碼。

    • 如果快取命中 key,那麼將命中節點移到雙向迴圈連結串列的結尾,並且返回結果(47~58 行)

      這裡通過字典加雙向迴圈連結串列的組合資料結構,實現了用 O(1) 的時間複雜度刪除給定的節點。

    • 如果沒有命中,並且快取滿了,那麼需要將最久沒有使用的節點(root 的下一個節點)刪除,並且將新的節點新增到連結串列結尾。在實現中有一個優化,直接將當前的 root 的 key 和 result 替換成新的值,將 root 的下一個節點置為新的 root,這樣得到的雙向迴圈連結串列結構跟刪除 root 的下一個節點並且將新節點加到連結串列結尾是一樣的,但是避免了刪除和新增節點的操作(68~88 行)

    • 如果沒有命中,並且快取沒滿,那麼直接將新節點新增到雙向迴圈連結串列的結尾(root[PREV],這裡我認為是結尾,但是程式碼註釋中寫的是開頭)(89~96 行)

最後給 wrapper 新增兩個屬性函式 cache_infocache_clearcache_info 顯示當前快取的命中情況的統計資料,cache_clear 用於清空快取。對於上面階乘相關的程式碼,如果在最後執行 factorial.cache_info(),會輸出

CacheInfo(hits=1, misses=5, maxsize=128, currsize=5)

第一次執行 factorial(5) 的時候都沒命中,所以 misses = 5,第二次執行 factorial(3) 的時候,快取命中,所以 hits = 1。

最後需要說明的是,對於有多個關鍵字引數的函式,如果兩次呼叫函式關鍵字引數傳入的順序不同,會被認為是不同的呼叫,不會命中快取。另外,被 lru_cache 裝飾的函式不能包含可變型別引數如 list,因為它們不支援 hash。

總結一下,這篇文章首先簡介了一下快取的概念,然後展示了在 Python 中 lru_cache 的使用方法,最後通過原始碼分析了 Python 中 lru_cache 的實現細節。

相關文章