TreeValue——一個通用樹狀資料結構與函式計算庫
Treevalue v1.0.0版本已經於2021年10月24日正式釋出,歡迎下載體驗:opendilab / treevalue。
這算是treevalue的第一個正式實用化版本,本文將會對其主要功能和特性進行一個概述。
一個直觀地展示
設想這樣一個實際的應用場景,我們需要使用numpy對機器學習中的一批樣本進行預處理,並組裝成一個訓練用的mini-batch。一個資料樣本的格式如下面的程式碼所示,即函式 get_data
的返回值:
import numpy as np
T, B = 3, 4
def get_data():
return {
'a': np.random.random(size=(T, 8)),
'b': np.random.random(size=(6,)),
'c': {
'd': np.random.randint(0, 10, size=(1,))
}
}
如果使用最常見的寫法,大概會是這樣的
# without_treevalue.py
import numpy as np
T, B = 3, 4
def without_treevalue(batch_):
mean_b_list = []
even_index_a_list = []
for i in range(len(batch_)):
for k, v in batch_[i].items():
if k == 'a':
v = v.astype(np.float32)
even_index_a_list.append(v[::2])
elif k == 'b':
v = v.astype(np.float32)
transformed_v = np.power(v, 2) + 1.0
mean_b_list.append(transformed_v.mean())
elif k == 'c':
for k1, v1 in v.items():
if k1 == 'd':
v1 = v1.astype(np.float32)
else:
print('ignore keys: {}'.format(k1))
else:
print('ignore keys: {}'.format(k))
for i in range(len(batch_)):
for k in batch_[i].keys():
if k == 'd':
batch_[i][k]['noise'] = np.random.random(size=(3, 4, 5))
mean_b = sum(mean_b_list) / len(mean_b_list)
even_index_a = np.stack(even_index_a_list, axis=0)
return batch_, mean_b, even_index_a
而當我們有了treevalue庫之後,完全一致的功能可以被這樣簡短的程式碼實現
# with_treevalue.py
import numpy as np
from treevalue import FastTreeValue
T, B = 3, 4
power = FastTreeValue.func()(np.power)
stack = FastTreeValue.func(subside=True)(np.stack)
split = FastTreeValue.func(rise=True)(np.split)
def with_treevalue(batch_):
batch_ = [FastTreeValue(b) for b in batch_]
batch_ = stack(batch_)
batch_ = batch_.astype(np.float32)
batch_.b = power(batch_.b, 2) + 1.0
batch_.c.noise = np.random.random(size=(B, 3, 4, 5))
mean_b = batch_.b.mean()
even_index_a = batch_.a[:, ::2]
batch_ = split(batch_, indices_or_sections=B, axis=0)
return batch_, mean_b, even_index_a
可以看到,實現一段同樣的基於樹結構的業務邏輯,有了treevalue庫的輔助後程式碼變得極為簡短和清晰,也更加易於維護。
這正是treevalue的最大亮點所在,接下來的章節中將會對其主要功能和特性進行概述,以便讀者對這個庫有一個整體的瞭解和認識。
樹結構及其基本操作
在treevalue庫中,我們提供一種核心資料結構—— TreeValue
類。該類為整個treevalue的核心特性,後續的一系列操作都是圍繞 TreeValue
類所展開的。
首先是 TreeValue
物件的構造(使用的是增強版子類 FastTreeValue
,關於 TreeValue
類與 FastTreeValue
類的區別可以參考文件,本文不作展開),只需要將dict格式的資料作為唯一的建構函式引數傳入即可完成TreeValue
的構造
from treevalue import FastTreeValue
t = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
# <FastTreeValue 0x7f135a5ada30>
# ├── a --> 1
# ├── b --> 2
# └── x --> <FastTreeValue 0x7f135a5ad970>
# ├── c --> 3
# └── d --> 4
不僅如此, TreeValue
類還提供了樹狀結構的幾種基本操作介面供使用
from treevalue import FastTreeValue
t = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
## get value/node
t.a # 1
t.x # <FastTreeValue 0x7f135a5ad970>
# ├── c --> 3
# └── d --> 4
t.x.d # 4
## set value/node
t.x.d = 35
# t after t.x.d = 35
# <FastTreeValue 0x7f135a5ada30>
# ├── a --> 1
# ├── b --> 2
# └── x --> <FastTreeValue 0x7f135a5ad970>
# ├── c --> 3
# └── d --> 35
## delete value/node
del t.b
# t after del t.b
# <FastTreeValue 0x7f135a5ada30>
# ├── a --> 1
# └── x --> <FastTreeValue 0x7f135a5ad970>
# ├── c --> 3
# └── d --> 35
## contains key or not
'a' in t # True
'd' in t # False
'd' in t.x # True
## size of node
len(t) # 2, 'a' and 'x'
len(t.x) # 2, 'c' and 'd'
## iterate node
for k, v in t.x:
print(k, '-->', v)
# c --> 3
# d --> 35
以上是 TreeValue
類的幾種常見基本操作,支援了最基本的增刪查改等。
樹的視覺化
當一個 TreeValue
物件被構造出來後,我們如何比較直觀地去觀察其整體結構呢?有兩種方式可以對 TreeValue
進行視覺化:
- 通過
print
進行快速文字列印 - 通過
treevalue graph
命令列工具進行影像匯出
實際上對於第一種情況,在上一節中已經有了展示,在此展示一個更加複雜的案例
# test_simple.py
import torch
from treevalue import FastTreeValue
t = FastTreeValue({
'a': torch.randn(2, 3), # objects with multiple lines
'b': torch.randn(3, 1, 2),
'x': {
'c': torch.randn(3, 4),
}
})
t.x.d = t.x # nested pointer
print(t)
輸出結果如下,可以看到諸如 torch.Tensor
這樣多行的物件也一樣可以被有序地排版輸出,且對於存在巢狀引用的情況,輸出時也可以被準確地識別出來,避免無限迴圈列印
<FastTreeValue 0x7f642057bd00>
├── a --> tensor([[ 0.1050, -1.5681, -0.2849],
│ [-0.9415, 0.2376, 0.7509]])
├── b --> tensor([[[ 0.6496, -1.3547]],
│
│ [[ 1.2426, -0.2171]],
│
│ [[-0.7168, -1.4415]]])
└── x --> <FastTreeValue 0x7f642057bd30>
├── c --> tensor([[-0.6518, 0.4035, 1.0721, -0.6657],
│ [ 0.0252, 0.4356, 0.1284, -0.3290],
│ [-0.6725, 0.2923, 0.0048, 0.0704]])
└── d --> <FastTreeValue 0x7f642057bd30>
(The same address as <root>.x)
除了基於文字的視覺化外,我們還提供了命令列工具以進行影像匯出。例如上面的程式碼,我們可以用如下的命令列匯出影像
treevalue graph -t 'test_simple.t' -o 'test_graph.png'
此外,對於更復雜的情況,例如這樣的一份原始碼
# test_main.py
import numpy as np
from treevalue import FastTreeValue
tree_0 = FastTreeValue({
'a': [4, 3, 2, 1],
'b': np.array([[5, 6], [7, 8]]),
'x': {
'c': np.array([[5, 7], [8, 6]]),
'd': {'a', 'b', 'c'},
'e': np.array([[1, 2], [3, 4]])
},
})
tree_1 = FastTreeValue({
'aa': tree_0.a,
'bb': np.array([[5, 6], [7, 8]]),
'xx': {
'cc': tree_0.x.c,
'dd': tree_0.x.d,
'ee': np.array([[1, 2], [3, 4]])
},
})
tree_2 = FastTreeValue({'a': tree_0, 'b': tree_1, 'c': [1, 2], 'd': tree_1.xx})
可以通過以下的命令列匯出為影像,不難發現對於共用節點和共用物件的情況也都進行了準確地體現(如需進一步瞭解,可以執行 treevalue graph -h
檢視幫助資訊)
treevalue graph -t 'test_main.tree_*' -o 'test_graph.png' -d numpy.ndarray -d list -d dict
以上是對於 TreeValue
物件的兩種視覺化方法。
函式的樹化
在treevalue中,我們可以快速地將函式進行裝飾,使之可以支援 TreeValue
物件作為引數進行運算。例如下面的例子
# test_gcd.py
from treevalue import FastTreeValue, func_treelize
@func_treelize(return_type=FastTreeValue)
def gcd(a, b): # GCD calculation
while True:
r = a % b
a, b = b, r
if r == 0:
break
return a
if __name__ == '__main__':
t1 = FastTreeValue({'a': 2, 'b': 30, 'x': {'c': 4, 'd': 9}})
t2 = FastTreeValue({'a': 4, 'b': 48, 'x': {'c': 6, 'd': 54}})
print("Result of gcd(12, 9):", gcd(12, 9))
print("Result of gcd(t1, t2):")
print(gcd(t1, t2))
將整數之間的最大公因數進行了裝飾後,可以形成相容 TreeValue
物件的函式,且不會影響普通物件的運算結果,因此上述程式碼的輸出結果如下所示
Result of gcd(12, 9): 3
Result of gcd(t1, t2):
<FastTreeValue 0x7f53fa67ff10>
├── a --> 2
├── b --> 6
└── x --> <FastTreeValue 0x7f53fa67ff40>
├── c --> 2
└── d --> 9
樹結構的運算
除了 TreeValue
自帶的一系列基本操作之外,treevalue庫還提供了一些常用的樹結構運算函式。例如如下的四種:
- map——值對映運算
- reduce——值歸納運算
- subside——頂層結構下沉運算
- rise——值結構提取上浮運算
值對映運算(map)
TreeValue
物件的值對映運算和列表型別的map運算類似,會產生一個同結構且值為對映值的新 TreeValue
物件,例如下面的案例
from treevalue import FastTreeValue
t1 = FastTreeValue({'a': 2, 'b': 30, 'x': {'c': 4, 'd': 9}})
t2 = t1.map(lambda x: x * 2 + 1)
t1
和 t2
的影像如下
值歸納運算(reduce)
TreeValue
物件的值歸納運算和 functools.reduce
函式的功能類似,可以將樹結構以子樹為單位進行歸納,最終計算出一個結果來,例如下面的案例
from treevalue import FastTreeValue
t1 = FastTreeValue({'a': 2, 'b': 30, 'x': {'c': 4, 'd': 9}})
# sum of all the values in t1
t1.reduce(lambda **kws: sum(kws.values())) # 45
可以快捷的實現整棵樹的求和運算。
頂層結構下沉運算(subside)
TreeValue
還可以支援將其上層結構進行解析後,下沉到節點值上,形成一棵新的樹,例如下面的案例
from treevalue import FastTreeValue
dt = {
'i1': FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3}}),
'i2': (
FastTreeValue({'a': 7, 'b': 4, 'x': {'c': 6}}),
FastTreeValue({'a': 11, 'b': 9, 'x': {'c': -3}}),
),
}
t = FastTreeValue.subside(dt)
下沉後產生的樹 t
的影像為
值結構提取上浮運算(rise)
上浮運算(rise)為下沉運算的逆運算,可以從值節點中提取共同結構至 TreeValue
物件外,例如如下的案例
from treevalue import FastTreeValue, raw
t = FastTreeValue({
'a': raw({'i1': 1, 'i2': (7, 11)}),
'b': raw({'i1': 2, 'i2': (4, 9)}),
'x': {
'c': raw({'i1': 3, 'i2': (6, -3)}),
}
})
dt = t.rise()
dt_i1 = dt['i1']
dt_i2_0, dt_i2_1 = dt['i2']
物件 dt
是一個字典型別的物件,包含 i1
和 i2
兩個鍵,其中 i1
為一個 TreeValue
物件, i2
為一個長度為2,由 TreeValue
物件構成的元組。因此 dt_i1
、 dt_i2_0
和 dt_i2_1
的影像為
後續預告
本文作為一個對於treevalue現有主要功能的一個概述,受限於篇幅暫時做不到深入全面的講解內部原理和機制。因此後續會考慮繼續出包括但不限於以下的內容:
- treevalue的整體體系結構
func_treelize
樹化函式的原理與機制- treevalue的一些神奇的用法與黑科技
- treevalue的優化歷程與經驗
- treevalue的具體實戰案例
敬請期待,同時歡迎瞭解其他OpenDILab的開源專案:https://github.com/opendilab。