Treevalue(0x02)——函式樹化詳細解析(上篇)

HansBug發表於2021-11-05

本文將對 func_treelize 這一treevalue庫中的核心功能進行詳細的原理解析。

關於treevalue的概述,可以參考之前的文章:Treevalue(0x01)——功能概述

樹化函式基本原理

在treevalue庫中, func_treelize 是核心特性之一,可以將普通的函式快速作用於樹物件上。而這一“作用”的原理是什麼呢,我們來一起看看——首先準備一個普通的函式,並加上 func_treelize 裝飾器,就像這樣

from treevalue import func_treelize


@func_treelize()
def gcd(a, b):  # GCD calculation
    print('gcd', a, b)
    while True:
        r = a % b
        a, b = b, r
        if r == 0:
            break

    return a

函式的部分是一個最大公因數的計算,並且和之前文章(Treevalue(0x01)——功能概述)中的區別在於,新增了一行 print 輸出,用於體現函式內部在整個計算過程中是如何被呼叫的。基於這一函式,我們進行如下的呼叫,可以得到對應的輸出結果

from treevalue import FastTreeValue

gcd(9, 12)
# gcd 9 12
# 3

t1 = FastTreeValue({'a': 2, 'b': 30, 'x': {'c': 4, 'd': 9}})
t2 = FastTreeValue({'a': 4, 'b': 48, 'x': {'c': 6, 'd': 54}})
gcd(t1, t2)
# gcd 30 48
# gcd 9 54
# gcd 4 6
# gcd 2 4
# <TreeValue 0x7f12950e3be0>
# ├── a --> 2
# ├── b --> 6
# └── x --> <TreeValue 0x7f1296732310>
#     ├── c --> 2
#     └── d --> 9

根據輸出語句,不難發現——經過func_treelize裝飾後的函式,在被傳入TreeValue型別的時候,會自動基於其結構將內部的數值一一對應傳入原函式,並在執行計算後組裝成與原來相同的樹結構
基於以上基本特性,func_treelize這一過程也被稱為函式的樹化,經過樹化後的函式將滿足以下基本特性:

  1. 當所有傳入引數均為非樹物件時,函式行為與返回值與原函式保持嚴格一致,即樹化後的函式依然可以像原函式一樣地使用
  2. 樹化的函式本身不會對傳入的樹物件內部結構有顯式的限制,在函式的樹化邏輯中將基於傳入樹引數的結構生成最終的返回值結構。
  3. 函式的樹化邏輯部分不會對樹物件內部的值進行任何的判定與檢測,只是作為一箇中繼器將對應的值傳入原函式並獲取運算結果

樹化函式執行機制

通過開頭章節的簡單例子展示,相信各位已經對函式的樹化有了基本的概念和了解。在本章中,將對函式的樹化過程進行更加詳細的機制分析。

機制概述

在開頭章節的例子中,展現的只是兩種最為理想化的情況:

  1. 傳入的引數均為非樹物件
  2. 傳入的引數均為結構完全一致的樹物件

然而實際上,基於對“樹”這一資料結構的基本瞭解,不難發現實際上需要作出處理的情況依然有很多,包括但不限於:

  • 鍵值缺少——參與計算的某個樹物件在對應的位置上缺少了對應的鍵值,這樣的情況如何處理?例如下圖中, t2.x.d 缺失,這樣的情況該如何處理?

  • 鍵值型別不匹配——參與計算的某幾個樹物件對應位置上,有些是葉子節點值,有些是非葉子節點子樹,形成“值-子樹”之間的直接運算,這樣的情況如何定義?例如下圖中, t1.b 為子樹但是 t2.b 為值,這樣的情況如何定義?

  • 計算模式多樣性——當參與計算的樹物件之間的結構存在較多較大差異性時,如何設計計算策略使之能支援更多樣化的計算?例如下列的場景,如何組織對如此結構各異的樹之間的運算?

  • 資料格式多樣性——當參與計算的葉子節點值格式存在不統一時,如何處理?例如下面的場景,如何對 t1t2 下顯然不同尺寸的 torch.Tensor 進行處理?

因此,基於這些很現實的問題,我們為樹化函式定義瞭如下的選項:

  • 模式選項(mode)——決定樹化函式的整體執行機制。
  • 繼承選項(inherit)——對鍵值型別不匹配的情況進行了定義,並提供了處理機制。
  • 預設選項(missing)——為鍵值缺少的情況提供了預設值補全機制。

模式選項(mode)

模式選項是樹化函式中最為重要的選項,其將直接決定樹化函式的主體計算邏輯。目前定義了四種常用模式:

  • 嚴格模式(STRICT)
  • 內共同模式(INNER)
  • 外共有模式(OUTER)
  • 左優先模式(LEFT)

接下來的子章節中會結合例子進行逐一介紹。

嚴格模式(STRICT)

嚴格模式是最常用的模式選項,意味著當且僅當所有樹引數在當前子樹位置上的鍵一一對應時,會將其鍵值進行一一對應地代入計算,否則丟擲異常。程式碼實現如下,與開頭的例子等價,模式選項的預設值即為嚴格模式

from treevalue import func_treelize


@func_treelize(mode='strict')
def gcd(a, b):  # GCD calculation
    while True:
        r = a % b
        a, b = b, r
        if r == 0:
            break

    return a

在上述的樹化gcd函式中,完整的計算機制如下圖1所示, tr 為樹化gcd的運算結果


(圖1,t1、t2內的鍵值可以形成一一對應)

但是當出現如下所示的引數時,則應丟擲異常,因為部分鍵存在缺失,無法形成一一對應。


(圖2,t1.b與t1.x.c缺失,無法形成一一對應)

嚴格模式是一種最為常見的計算邏輯,適用於大部分常見情況,也是在業務邏輯上最為順理成章的一種模式。但是對非規則結構下的計算則不能相容,因此另外三種模式選項分別針對不同的情況來支援非規則結構下的計算。

內共同模式(INNER)

內共同模式下,僅會對全部樹引數當前子樹位置上均存在此鍵時,才會對將其鍵值進行一一對應地代入計算,而當此鍵值在某一樹引數當前子樹位置上存在缺失情況是,則會直接忽略該組鍵值。程式碼實現如下,將 mode 設定為 inner 即可

from treevalue import func_treelize


@func_treelize(mode='inner')
def gcd(a, b):  # GCD calculation
    while True:
        r = a % b
        a, b = b, r
        if r == 0:
            break

    return a

例如對圖2所示的例子,在內共同模式下可以正常計算,如圖3所示


(圖3,t1.x.c和t2.b因為t2.x.c和t1.b的缺失而被忽略)

內共同模式會忽略無法形成對應的多餘值,可以確保在幾乎所有情況下均能得出計算結果而不會產生錯誤。但是會不可避免地造成部分資訊丟失,而在一部分情況下這是不可接受的,因此請根據實際需求進行選擇。

外共有模式(OUTER)

外共有模式下,只要在任意一個樹引數的當前子樹位置上存在此鍵值,則會將其進行代入計算。而對於缺失的值,則會使用預設選項中設定的值或生成器進行獲取並代入。程式碼實現如下,將 mode 設定為 outer 即可,並將預設選項設定為值 1

from treevalue import func_treelize


@func_treelize(mode='outer', missing=1)
def gcd(a, b):  # GCD calculation
    while True:
        r = a % b
        a, b = b, r
        if r == 0:
            break

    return a

例如對圖2所示的例子,在外共有模式下可以正常計算,如圖4所示


(圖4,t1.b和t1.x.c缺失,將使用預設選項指定的預設值1)

外共有模式將會讓所有的數值參與運算,但是在絕大部分情況下均依賴預設選項的設定,因此在使用前請確保預設選項的正確配置,以及業務邏輯上的自洽。

左優先模式(LEFT)

左優先模式下,參與運算的鍵值將以全部樹引數中最左的一項為參考。其中最左的一項定義為,在python函式呼叫的位置引數(postional argument)中,如果存在樹引數,則取最左的一項;如果不存在,則在函式呼叫的鍵值引數(key-word argument)紅,取字典序最小的一項。程式碼實現如下,將 mode 設定為 left 即可,並將預設選項設定為值 1

from treevalue import func_treelize


@func_treelize(mode='left', missing=1)
def gcd(a, b):  # GCD calculation
    while True:
        r = a % b
        a, b = b, r
        if r == 0:
            break

    return a

例如對於圖2所示的 gcd(t1, t2) 例子中,在左優先模式下計算結果如下,如圖5所示


(圖5,t2.b因t1.b的缺失而被忽略,而t2.x.c取預設值1)

而在 gcd(t2, t1) 例子中,左優先計算結果如下,如圖6所示


(圖6,t1.x.c因t2.x.c的缺失而被忽略,而t1.b取預設值1)

左優先模式會按照最左樹引數的結構來進行計算,生成的計算結果也將和最左的引數保持一致。但是與外共有模式類似,左優先模式在絕大部分情況下依賴預設選項的配置,需要確保配置準確無誤且自洽。此外,對於原本滿足交換律的運算,經過左優先模式的樹化後將會失去原有的交換律性質,這一點請務必留意。

繼承選項(inherit)

繼承選項可以通過普通值的繼承機制,讓樹化函式在實際應用中使用起來更加簡潔,也讓樹引數可以和普通引數在樹化後的函式中被混用。在預設情況下,繼承選項是處於開啟狀態的,即等價於如下的程式碼

from treevalue import func_treelize


@func_treelize(inherit=True)
def gcd(a, b):  # GCD calculation
    while True:
        r = a % b
        a, b = b, r
        if r == 0:
            break

    return a

因此,有如下的例子 gcd(t1, t2) ,其計算結果如圖7所示


(圖7,t2.x.c和t2.x.d繼承t2.x的值6)

此外顯而易見的是,也可以直接將非樹值直接傳入,和樹引數混用,例如下面的例子 gcd(100, t1) ,其計算結果如圖x所示


(圖8,值100被完全繼承並作為第一棵樹的全部值)

而當繼承選項被關閉時,則上述兩個例子均會丟擲異常,因為存在值和子樹混用的情況。

從業務邏輯的角度來看,繼承選項可以良好地適應大部分真實存在的值複用情況,且值和子樹混用在大多數業務邏輯上也是有明確意義的。但是當混用在業務邏輯角度上意義不明且需要被顯式地檢測時,則建議關閉繼承選項

預設選項(missing)

預設選項可以為部分鍵值存在缺失的情況提供一個值的補充,主要作用於外共有模式和左優先模式。我們可以通過 missing 引數直接提供值,如下所示

from treevalue import func_treelize, FastTreeValue

@func_treelize(mode='outer', missing=0)
def total(*args):
    return sum(args)

上述的加法函式計算例子如下, total(t1, t2, t3) 計算結果如下圖9所示


(圖9,預設值0被全面用於填補空缺,並最終計算出了有效的總和)

此外考慮到有些情況下,直接使用值作為預設值可能會存在公用同一個物件導致錯誤的情況,因此我們提供了通過傳入生成函式來產生預設值的用法。可以通過 missing 引數傳入值生成器,如下所示

from treevalue import func_treelize, FastTreeValue

@func_treelize(mode='outer', missing=lambda: [])
def append(arr: list, *args):
    for item in args:
        if item:
            arr.append(item)
    return arr

上述的列表追加值計算例子如下, append(t0, t1, t2, t3) 運算結果如下圖10所示


(圖10,每次預設均會生成新的空列表)

通過預設選項的有效配置,結合外共有模式和左優先模式,可以有效擴充套件樹化函式對值預設情況的處理能力。不過值得注意的是,預設選項在嚴格模式下無法生效,因為當檢測到鍵缺失時將會直接丟擲異常;以及預設模式在內共同模式下永遠無法實質上生效,因此樹化函式會針對這一情況丟擲一個警告資訊。

上升、下沉選項

除了上述的基本機制選項之外,樹化函式還提供了上升(rise)和下沉(subside)選項,以簡化對結構化資料的處理。兩者的功能分別為:

  • 下沉(subside)——嘗試將引數中頂層結構非樹的物件,提取結構後將結構下沉至樹內,使原函式在執行過程中可以接收到。關於下沉函式的具體細節可以參考之前文章
  • 上升(rise)——嘗試從返回結果樹的葉子節點值中提取共同結構,向上升至樹外,使返回值的邏輯結構可以被外部直接訪問。關於上升函式的具體細節可以參考之前文章

因此我們可以在需要的時候開啟這兩個選項,程式碼如下,實現的效果是從列表 arr 中查詢首個滿足條件值的位置( position ),並統計共有多少個滿足條件的值( cnt

from treevalue import func_treelize, FastTreeValue


@func_treelize(subside=True, rise=True)
def check(arr: list, target):
    position = None
    cnt = 0
    for i, item in enumerate(arr):
        if target(item):
            if position is None:
                position = i
            cnt += 1

    return position, cnt


t1 = FastTreeValue({'a': 2, 'b': 4, 'x': {'c': 7, 'd': 9}})
t2 = FastTreeValue({'a': 4, 'b': 48, 'x': {'c': 2, 'd': 53}})
t3 = FastTreeValue({'a': 9, 'b': -12, 'x': {'c': 3, 'd': 7}})

tr1, tr2 = check([t1, t2, t3], lambda x: x % 2 == 0)

程式碼中可以看到三棵樹 t1t2t3 可以直接用列表裝載,在原函式 check 中可以接收到對應位置上的值列表。並且由於 rise 選項的開啟,位置和數量所構成的二元組也會被提取出來,形成兩棵樹,即 tr1tr2 ,如下圖11所示


(圖11,[t1, t2, t3]作為列表引數,tr1, tr2作為返回值樹)

此外,上升和下沉選項一個更加有效的使用例子是對 torch.splittorch.stack 函式進行裝飾,程式碼如下所示

import torch

from treevalue import func_treelize, TreeValue

stack = func_treelize(subside=True)(torch.stack)
split = func_treelize(rise=True)(torch.split)

trees = [TreeValue({
    'a': torch.randn(2, 4),
    'b': torch.randn(3, 4),
    'x': {'c': torch.randn(2, 1, 3)}
}) for _ in range(10)]

st = stack(trees)  # stack all the trees together
splitted = split(st, [1] * 10)  # split back to trees

# splitted should be equal to trees

其中 st 即為合併後的樹,而 splitted 為再次拆分後的樹, splittedtrees 等價。

後續預告

本文主要針對treevalue的核心特性——樹化函式,基於其自身進行了詳細的原理解析,受限於篇幅,本次只著重講述了原生樹化函式本身的原理、特性以及例子。在下一篇中將會針對更多衍生場景進行分析與展示,敬請期待。

同時歡迎瞭解其他OpenDILab的開源專案:https://github.com/opendilab

相關文章