Treevalue(0x03)——函式樹化詳細解析(下篇)

HansBug發表於2021-11-29

好久不見,再一次回到 treevalue 系列。本文將基於上一篇treevalue講解,繼續對函式的樹化機制進行詳細解析,並且會更多的講述其衍生特性及應用。

樹化方法與類方法

首先,基於之前的樹化函式,我們可以對一般意義上的函式進行樹化擴充套件。而對“函式”這一範疇來說,其中自然也包含方法、類方法這兩種特殊的函式,它們在本質上和一般函式是類似的(關於這部分可以閱讀Python科普系列——類與方法(下篇)中“物件方法的本質”章節作進一步的瞭解)。也正是因為它們之間的相似性,所以無論是物件方法還是類方法,同樣都是可以被擴充套件的

基於上面所述的方法、類方法的性質,我們可以對其進行類似的樹化擴充套件。讓我們來看一個例子

from treevalue import TreeValue, method_treelize, classmethod_treelize


class MyTreeValue(TreeValue):
    @method_treelize()
    def plus(self, x):
        return self + x

    # with the usage of rise option, final return should be a tuple of 2 trees
    @classmethod
    @classmethod_treelize(rise=True)
    def add_all(cls, a, b):
        return cls, a + b

由此,我們構建了一個屬於自己的TreeValue類—— MyTreeValue 類,並且可以使用內部的方法與類方法進行物件導向程式的編寫。例如對於 MyTreeValue 類,我們可以執行以下的運算(程式碼接上文)

t1 = MyTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
t2 = MyTreeValue({'a': 5, 'b': 6, 'x': {'c': 7, 'd': 8}})

print(t1.plus(2))
# <MyTreeValue 0x7fe023375ee0>
# ├── a --> 3
# ├── b --> 4
# └── x --> <MyTreeValue 0x7fe023375eb0>
#     ├── c --> 5
#     └── d --> 6

print(t1.plus(t2))
# <MyTreeValue 0x7fe023375eb0>
# ├── a --> 6
# ├── b --> 8
# └── x --> <MyTreeValue 0x7fe021dd16a0>
#     ├── c --> 10
#     └── d --> 12

print(MyTreeValue.add_all(t1, t2))
# (<MyTreeValue 0x7effa62c6250>
# ├── a --> <class '__main__.MyTreeValue'>
# ├── b --> <class '__main__.MyTreeValue'>
# └── x --> <MyTreeValue 0x7effa62a0790>
#     ├── c --> <class '__main__.MyTreeValue'>
#     └── d --> <class '__main__.MyTreeValue'>
# , <MyTreeValue 0x7effa629df70>
# ├── a --> 6
# ├── b --> 8
# └── x --> <MyTreeValue 0x7effa62c6d90>
#     ├── c --> 10
#     └── d --> 12
# )

此外,對於物件方法,顯然存在一個運算主體,也就是 self ,並且常常會出現需要進行“原地運算”的情況,類似於torch庫裡面的sin_ 。在針對物件方法的樹化函式中,我們提供了 self_copy 選項,當開啟此選項時,計算完畢後會將各個節點上的執行結果掛載至當前的樹物件上,並將其作為返回值傳出。一個簡單的例子如下

from treevalue import TreeValue, method_treelize


class MyTreeValue(TreeValue):
    @method_treelize(self_copy=True)
    def plus_(self, x):
        return self + x


t1 = MyTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})

print(t1)
# <MyTreeValue 0x7f543c83cd60>
# ├── a --> 1
# ├── b --> 2
# └── x --> <MyTreeValue 0x7f543c83cd00>
#     ├── c --> 3
#     └── d --> 4

print(t1.plus_(2))
# <MyTreeValue 0x7f543c83cd60>
# ├── a --> 3
# ├── b --> 4
# └── x --> <MyTreeValue 0x7f543c83cd00>
#     ├── c --> 5
#     └── d --> 6

在上述程式碼中,可以看到 plus_ 方法的返回值仍是之前的樹物件,且內部的節點值均被替換為了計算結果值,此時如果訪問樹 t1 ,得到的也將會是這一物件。
延伸思考1:對於靜態方法,應該如何進行樹化?請通過編寫程式碼驗證你的猜想。

延伸思考2:對於屬性(property),僅考慮讀取( __get__ )功能的話,該如何進行樹化?請通過程式碼驗證你的猜想。

歡迎評論區討論!

樹化運算

如果你對算術運算的原理有所瞭解的話,應該知道在python中,算術運算也同樣是由一類特殊的物件方法支援的,例如加法運算是由 __add__ (self + x)、 __radd__ (x + self)和 __iadd__ (self += x)運算所共同支援的,而對運算子的過載也往往是通過此類魔術方法實現的。關於這部分,可以閱讀Python科普系列——類與方法(下篇)中“魔術方法的妙用”章節作更進一步的瞭解。

既然如此,不妨想一想,如果將樹化函式用在這類特殊的方法上,會產生什麼樣的奇妙效果呢?沒錯,如你所想,這類運算一樣是可以被擴充套件的,而效果就會像是如下的程式碼所示

from treevalue import TreeValue, method_treelize


class AddTreeValue(TreeValue):
    @method_treelize()
    def __add__(self, other):
        return self + other

    @method_treelize()
    def __radd__(self, other):
        return other + self

    @method_treelize(self_copy=True)
    def __iadd__(self, other):
        return self + other

執行起來的效果如下所示

t1 = AddTreeValue({'a': 1, 'x': {'c': 3}})
t2 = AddTreeValue({'a': 5, 'x': {'c': 7}})

print(t1)
# <AddTreeValue 0x7ff25d729e50>
# ├── a --> 1
# └── x --> <AddTreeValue 0x7ff25d729e20>
#     └── c --> 3
print(t1 + 2)
# <AddTreeValue 0x7ff25d72caf0>
# ├── a --> 3
# └── x --> <AddTreeValue 0x7ff25c17aa90>
#     └── c --> 5
print(3 + t1)
# <AddTreeValue 0x7ff25d72caf0>
# ├── a --> 4
# └── x --> <AddTreeValue 0x7ff25c17aa90>
#     └── c --> 6
print(t1 + t2)
# <AddTreeValue 0x7ff25d72caf0>
# ├── a --> 6
# └── x --> <AddTreeValue 0x7ff25c17aa90>
#     └── c --> 10

t1 += t2 + 10
print(t1)
# <AddTreeValue 0x7ff25d729e50>
# ├── a --> 16
# └── x --> <AddTreeValue 0x7ff25d729e20>
#     └── c --> 20

不僅如此,筆者作為treevalue的開發者也同樣是這麼想的。於是這裡提供了一個基於TreeValue,並提供了更多常用功能和運算,使之更加快捷易用的子類——FastTreeValue 。這個類從本系列的第一彈以來已經多次出場,在這裡我們終於得以揭曉其真正的奧祕。FastTreeValue類中,諸如上述的各類算術運算已經以類似的方式進行了實現,並可供使用。例如下面的這段程式碼

from treevalue import FastTreeValue

t1 = FastTreeValue({'a': 1, 'x': {'c': 3}})
t2 = FastTreeValue({'a': 5, 'x': {'c': 7}})

print(t1 * (1 - t1 + t2) % 10 + (t2 // t1))  # complex calculation
# <FastTreeValue 0x7f973be1eaf0>
# ├── a --> 10
# └── x --> <FastTreeValue 0x7f973be1ea00>
#     └── c --> 7

t3 = FastTreeValue({'a': 1, 'b': 'sdjkfh', 'x': {'c': [1, 2], 'd': 1.2}})
t4 = FastTreeValue({'a': 4, 'b': 'anstr', 'x': {'c': [4, 5, -2], 'd': -8.5}})

print(t3 + t4)  # add all together, not only int or float
# <FastTreeValue 0x7f973be1e970>
# ├── a --> 5
# ├── b --> 'sdjkfhanstr'
# └── x --> <FastTreeValue 0x7f973be1eac0>
#     ├── c --> [1, 2, 4, 5, -2]
#     └── d --> -7.3

t5 = FastTreeValue({'a': {2, 3}, 'x': {'c': 8937}})
t6 = FastTreeValue({'a': {1, 2, 4}, 'x': {'c': 910}})

print(t5 | t6)  # | and &, between sets and ints
# <FastTreeValue 0x7f973be1e640>
# ├── a --> {1, 2, 3, 4}
# └── x --> <FastTreeValue 0x7f973be1e8e0>
#     └── c --> 9199
print(t5 & t6)
# <FastTreeValue 0x7f973be1e640>
# ├── a --> {2}
# └── x --> <FastTreeValue 0x7f973be1e8e0>
#     └── c --> 648

至此,常規的算術運算已經被覆蓋,而且由於python對算術運算的支援方式,算術運算也並不受限於值的型別,而是可以廣泛地支援各種型別的運算。

延伸思考3:結合Python科普系列——類與方法(下篇)中“魔術方法的妙用”部分,想一想此類算術運算魔術方法各自應該被如何實現?然後去翻閱一下treevalue的原始碼驗證你的猜想。

歡迎評論區討論!

基於樹化運算的應用

實際上,python中以下劃線開頭和結尾的特殊運算並不只有上述的算術運算,還有一系列的操作類也一樣可以被用類似的方式擴充套件。其中最為典型的就是對於功能性的魔術方法所做的擴充套件,比如,我們可以對 __getitem____setitem__ 進行擴充套件,如下所示

from treevalue import TreeValue, method_treelize


class MyTreeValue(TreeValue):
    @method_treelize()
    def __getitem__(self, item):
        return self[item]

    @method_treelize()
    def __setitem__(self, key, value):
        self[key] = value

FastTreeValue 中也有類似的實現,由此可以產生的一個效果是這樣的,通過索引即可快速對下屬的所有物件進行訪問,程式碼如下

import torch

from treevalue import FastTreeValue

t1 = FastTreeValue({
    'a': torch.randn(2, 3),
    'x': {
        'c': torch.randn(3, 4),
    }
})

print(t1)
# <FastTreeValue 0x7f93f19b9c40>
# ├── a --> tensor([[-0.5878,  0.8615, -0.1703],
# │                 [ 1.5826, -0.5806,  1.5869]])
# └── x --> <FastTreeValue 0x7f93f19b9d00>
#     └── c --> tensor([[-0.3380, -0.6968,  0.7013, -0.8895],
#                       [-0.2798,  0.6196,  0.8141, -2.5651],
#                       [ 0.0113, -2.0468,  0.1121,  0.3606]])

print(t1[0])
# <FastTreeValue 0x7f93f19b9d30>
# ├── a --> tensor([-0.5878,  0.8615, -0.1703])
# └── x --> <FastTreeValue 0x7f93901c1fd0>
#     └── c --> tensor([-0.3380, -0.6968,  0.7013, -0.8895])
print(t1[:, 1:-1])
# <FastTreeValue 0x7f93f19b9d30>
# ├── a --> tensor([[ 0.8615],
# │                 [-0.5806]])
# └── x --> <FastTreeValue 0x7f93901c1fd0>
#     └── c --> tensor([[-0.6968,  0.7013],
#                       [ 0.6196,  0.8141],
#                       [-2.0468,  0.1121]])

除此之外,TreeValue類中,預留了一個_attr_extern方法,當嘗試獲取TreeValue物件包含的值時,一般通過直接訪問屬性實現,而噹噹前樹節點無此鍵時,則會進入_attr_extern方法。在原生的 TreeValue 類中,這一方法被實現為直接丟擲 KeyError 異常,而在 FastTreeValue 中進行了類似這樣的擴充套件(僅示意,與真實實現略有差異)

from treevalue import TreeValue, method_treelize


class MyTreeValue(TreeValue):
    @method_treelize()
    def _attr_extern(self, key):
        return getattr(self, key)

於是便可以實現類似這樣的效果

import torch

from treevalue import FastTreeValue

t1 = FastTreeValue({
    'a': torch.randn(2, 3),
    'x': {
        'c': torch.randn(3, 4),
    }
})

print(t1.shape)
# <FastTreeValue 0x7fac48ac66d0>
# ├── a --> torch.Size([2, 3])
# └── x --> <FastTreeValue 0x7fac48ac6700>
#     └── c --> torch.Size([3, 4])
print(t1.sin)
# <FastTreeValue 0x7f0fcd0e36a0>
# ├── a --> <built-in method sin of Tensor object at 0x7f0fcd0ea040>
# └── x --> <FastTreeValue 0x7f0fcd0e3df0>
#     └── c --> <built-in method sin of Tensor object at 0x7f0fcd0ea080>

可以看到,不僅一般意義上的屬性(例如 shape )可以被獲取並構建成樹,連物件的方法也被以同樣的方式進行了提取構造。這是因為在Python中,實際上屬性這一概念(更準確的說法是欄位,英文為Field)包含的內容有很多,其中包括方法(具體可以參考Python科普系列——類與方法(上篇)中“如何手動製造一個物件”章節作進一步瞭解),基於這一點,通過與上述程式碼類似的方式,我們可以獲得一棵由物件方法構成的樹,即如上述的 sin 方法一樣。

說到這裡,我們可以繼續去擴充套件一個魔術方法—— __call__ 方法,這個方法的作用是讓物件可以被以類似函式呼叫的方式直接執行。過載的方式如下所示

from treevalue import TreeValue, method_treelize


class MyTreeValue(TreeValue):
    @method_treelize()
    def __call__(self, *args, **kwargs):
        return self(*args, **kwargs)

FastTreeValue 中也作了類似的實現,因此上面獲取到的那棵由物件方法構成的樹,實際上是可以被執行的。而將對 _attr_extern__call__ 的擴充套件相結合,則可以形成這樣一種更為奇妙的用法——直接對樹物件執行其內部物件所包含的方法,如下所示

import torch

from treevalue import FastTreeValue

t1 = FastTreeValue({
    'a': torch.randn(2, 4),
    'x': {
        'c': torch.randn(3, 4),
    }
})

print(t1)
# <FastTreeValue 0x7f7e7534bc40>
# ├── a --> tensor([[ 1.4246,  0.4117, -1.1805,  0.1825],
# │                 [ 0.5865, -0.8895, -0.8055,  0.9112]])
# └── x --> <FastTreeValue 0x7f7e7534bd00>
#     └── c --> tensor([[ 1.6239e+00, -2.3074e+00, -2.8613e-01,  1.3310e+00],
#                       [-1.8917e-01,  1.6694e+00, -8.2944e-01,  2.8590e-01],
#                       [-4.0992e-01, -5.8827e-01,  2.0444e-03,  7.0647e-01]])

print(t1.sin())
# <FastTreeValue 0x7f7e7534bd30>
# ├── a --> tensor([[ 0.9893,  0.4002, -0.9248,  0.1814],
# │                 [ 0.5534, -0.7768, -0.7212,  0.7902]])
# └── x --> <FastTreeValue 0x7f7e7534bd60>
#     └── c --> tensor([[ 0.9986, -0.7407, -0.2822,  0.9714],
#                       [-0.1880,  0.9951, -0.7376,  0.2820],
#                       [-0.3985, -0.5549,  0.0020,  0.6491]])
print(t1.reshape((4, -1)))
# <FastTreeValue 0x7f7e13b43fa0>
# ├── a --> tensor([[ 1.4246,  0.4117],
# │                 [-1.1805,  0.1825],
# │                 [ 0.5865, -0.8895],
# │                 [-0.8055,  0.9112]])
# └── x --> <FastTreeValue 0x7f7e7534bd30>
#     └── c --> tensor([[ 1.6239e+00, -2.3074e+00, -2.8613e-01],
#                       [ 1.3310e+00, -1.8917e-01,  1.6694e+00],
#                       [-8.2944e-01,  2.8590e-01, -4.0992e-01],
#                       [-5.8827e-01,  2.0444e-03,  7.0647e-01]])

# different sizes
new_shapes = FastTreeValue({'a': (1, -1), 'x': {'c': (2, -1)}})
print(t1.reshape(new_shapes))
# <FastTreeValue 0x7f98d95241f0>
# ├── a --> tensor([[ 2.0423, -0.5339, -0.4458, -0.3386,  0.1002,  0.6809, -0.3839,  1.9945]])
# └── x --> <FastTreeValue 0x7f993b3e3d30>
#     └── c --> tensor([[ 0.9726,  0.2787,  1.2419, -0.4118,  2.2535, -0.7826],
#                       [-0.9467,  0.3230, -0.6319, -0.2424,  0.4348,  1.3872]])

可能讀者還會有些懵,這裡以上面的 reshape 為例,解釋一下其執行機理:

  • 首先,執行 t1.reshape ,進入已經被樹化的 _attr_extern 方法,獲取到一棵由方法物件組成的樹,設為 t1_m
  • 接下來,執行 t1_m((4, -1)) ,進入已經被樹化的 __call__ 方法,通過對樹內各個方法的執行與對返回值的組裝,形成一棵由最終結果組成的樹,即為 t1.reshape((4, -1))

有了這樣的功能,實際上整個 treevalue 已經足以實現非常豐富且靈活的功能,並且簡單易懂,易於維護。而針對torch進行了專用樹化封裝庫 treetensor ,目前也已經發布,感興趣可以去作進一步的瞭解:opendilab / DI-treetensor

延伸思考4:除了上述例子中的 reshapesin ,以及 numpytorch 等計算庫,還有哪些常見的庫以及物件可以通過上述動態特性實現類似的效果?

延伸思考5:如果上述例子中不是 reshape ,而是類似sum這樣的方法,並且在部分情況下可能希望獲取到整棵樹所有物件之和,這一需求該如何設計以滿足?

延伸思考6:對於類似 sum 方法這樣的情況,還有哪些運算是與之類似的?這些運算在邏輯上存在什麼共同點?與 reshapesin 這樣的方法在邏輯上的區別又在哪裡?

歡迎評論區討論!

後續預告

本文主要針對treevalue的核心特性——樹化函式,對其在類方法、魔術方法等的具體應用進行了展示,受限於篇幅,只能對這些頗有亮點的特性進行展示。在下一篇中,我們將針對treevalue在numpy、torch等計算模型庫的應用展開詳解,並與同類產品進行對比與分析,敬請期待。

此外,歡迎歡迎瞭解OpenDILab的開源專案:

相關文章