Treevalue(0x01)——功能概述

HansBug發表於2021-11-01

TreeValue——一個通用樹狀資料結構與函式計算庫

Treevalue v1.0.0版本已經於2021年10月24日正式釋出,歡迎下載體驗:opendilab / treevalue

這算是treevalue的第一個正式實用化版本,本文將會對其主要功能和特性進行一個概述。

一個直觀地展示

設想這樣一個實際的應用場景,我們需要使用numpy對機器學習中的一批樣本進行預處理,並組裝成一個訓練用的mini-batch。一個資料樣本的格式如下面的程式碼所示,即函式 get_data 的返回值:

import numpy as np

T, B = 3, 4

def get_data():
    return {
        'a': np.random.random(size=(T, 8)),
        'b': np.random.random(size=(6,)),
        'c': {
            'd': np.random.randint(0, 10, size=(1,))
        }
    }

如果使用最常見的寫法,大概會是這樣的

# without_treevalue.py
import numpy as np

T, B = 3, 4

def without_treevalue(batch_):
    mean_b_list = []
    even_index_a_list = []
    for i in range(len(batch_)):
        for k, v in batch_[i].items():
            if k == 'a':
                v = v.astype(np.float32)
                even_index_a_list.append(v[::2])
            elif k == 'b':
                v = v.astype(np.float32)
                transformed_v = np.power(v, 2) + 1.0
                mean_b_list.append(transformed_v.mean())
            elif k == 'c':
                for k1, v1 in v.items():
                    if k1 == 'd':
                        v1 = v1.astype(np.float32)
                    else:
                        print('ignore keys: {}'.format(k1))
            else:
                print('ignore keys: {}'.format(k))
    for i in range(len(batch_)):
        for k in batch_[i].keys():
            if k == 'd':
                batch_[i][k]['noise'] = np.random.random(size=(3, 4, 5))
    mean_b = sum(mean_b_list) / len(mean_b_list)
    even_index_a = np.stack(even_index_a_list, axis=0)
    return batch_, mean_b, even_index_a

而當我們有了treevalue庫之後,完全一致的功能可以被這樣簡短的程式碼實現

# with_treevalue.py
import numpy as np
from treevalue import FastTreeValue

T, B = 3, 4
power = FastTreeValue.func()(np.power)
stack = FastTreeValue.func(subside=True)(np.stack)
split = FastTreeValue.func(rise=True)(np.split)

def with_treevalue(batch_):
    batch_ = [FastTreeValue(b) for b in batch_]
    batch_ = stack(batch_)
    batch_ = batch_.astype(np.float32)
    batch_.b = power(batch_.b, 2) + 1.0
    batch_.c.noise = np.random.random(size=(B, 3, 4, 5))
    mean_b = batch_.b.mean()
    even_index_a = batch_.a[:, ::2]
    batch_ = split(batch_, indices_or_sections=B, axis=0)
    return batch_, mean_b, even_index_a

可以看到,實現一段同樣的基於樹結構的業務邏輯,有了treevalue庫的輔助後程式碼變得極為簡短和清晰,也更加易於維護。

這正是treevalue的最大亮點所在,接下來的章節中將會對其主要功能和特性進行概述,以便讀者對這個庫有一個整體的瞭解和認識。

樹結構及其基本操作

在treevalue庫中,我們提供一種核心資料結構—— TreeValue 類。該類為整個treevalue的核心特性,後續的一系列操作都是圍繞 TreeValue 類所展開的。

首先是 TreeValue 物件的構造(使用的是增強版子類 FastTreeValue ,關於 TreeValue 類與 FastTreeValue 類的區別可以參考文件,本文不作展開),只需要將dict格式的資料作為唯一的建構函式引數傳入即可完成TreeValue的構造

from treevalue import FastTreeValue

t = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
# <FastTreeValue 0x7f135a5ada30>
# ├── a --> 1
# ├── b --> 2
# └── x --> <FastTreeValue 0x7f135a5ad970>
#     ├── c --> 3
#     └── d --> 4

不僅如此, TreeValue 類還提供了樹狀結構的幾種基本操作介面供使用

from treevalue import FastTreeValue

t = FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})

## get value/node
t.a  # 1
t.x  # <FastTreeValue 0x7f135a5ad970>
     # ├── c --> 3
     # └── d --> 4
t.x.d  # 4

## set value/node
t.x.d = 35
# t after t.x.d = 35
# <FastTreeValue 0x7f135a5ada30>
# ├── a --> 1
# ├── b --> 2
# └── x --> <FastTreeValue 0x7f135a5ad970>
#     ├── c --> 3
#     └── d --> 35

## delete value/node
del t.b
# t after del t.b
# <FastTreeValue 0x7f135a5ada30>
# ├── a --> 1
# └── x --> <FastTreeValue 0x7f135a5ad970>
#     ├── c --> 3
#     └── d --> 35

## contains key or not
'a' in t  # True
'd' in t  # False
'd' in t.x  # True

## size of node
len(t)  # 2, 'a' and 'x'
len(t.x)  # 2, 'c' and 'd'

## iterate node
for k, v in t.x:
    print(k, '-->', v)
# c --> 3
# d --> 35

以上是 TreeValue 類的幾種常見基本操作,支援了最基本的增刪查改等。

樹的視覺化

當一個 TreeValue 物件被構造出來後,我們如何比較直觀地去觀察其整體結構呢?有兩種方式可以對 TreeValue 進行視覺化:

  • 通過 print 進行快速文字列印
  • 通過 treevalue graph 命令列工具進行影像匯出

實際上對於第一種情況,在上一節中已經有了展示,在此展示一個更加複雜的案例

# test_simple.py
import torch

from treevalue import FastTreeValue

t = FastTreeValue({
    'a': torch.randn(2, 3),  # objects with multiple lines
    'b': torch.randn(3, 1, 2),
    'x': {
        'c': torch.randn(3, 4),
    }
})
t.x.d = t.x  # nested pointer
print(t)

輸出結果如下,可以看到諸如 torch.Tensor 這樣多行的物件也一樣可以被有序地排版輸出,且對於存在巢狀引用的情況,輸出時也可以被準確地識別出來,避免無限迴圈列印

<FastTreeValue 0x7f642057bd00>
├── a --> tensor([[ 0.1050, -1.5681, -0.2849],
│                 [-0.9415,  0.2376,  0.7509]])
├── b --> tensor([[[ 0.6496, -1.3547]],
│         
│                 [[ 1.2426, -0.2171]],
│         
│                 [[-0.7168, -1.4415]]])
└── x --> <FastTreeValue 0x7f642057bd30>
    ├── c --> tensor([[-0.6518,  0.4035,  1.0721, -0.6657],
    │                 [ 0.0252,  0.4356,  0.1284, -0.3290],
    │                 [-0.6725,  0.2923,  0.0048,  0.0704]])
    └── d --> <FastTreeValue 0x7f642057bd30>
              (The same address as <root>.x)

除了基於文字的視覺化外,我們還提供了命令列工具以進行影像匯出。例如上面的程式碼,我們可以用如下的命令列匯出影像

treevalue graph -t 'test_simple.t' -o 'test_graph.png'

此外,對於更復雜的情況,例如這樣的一份原始碼

# test_main.py
import numpy as np

from treevalue import FastTreeValue

tree_0 = FastTreeValue({
    'a': [4, 3, 2, 1],
    'b': np.array([[5, 6], [7, 8]]),
    'x': {
        'c': np.array([[5, 7], [8, 6]]),
        'd': {'a', 'b', 'c'},
        'e': np.array([[1, 2], [3, 4]])
    },
})
tree_1 = FastTreeValue({
    'aa': tree_0.a,
    'bb': np.array([[5, 6], [7, 8]]),
    'xx': {
        'cc': tree_0.x.c,
        'dd': tree_0.x.d,
        'ee': np.array([[1, 2], [3, 4]])
    },
})
tree_2 = FastTreeValue({'a': tree_0, 'b': tree_1, 'c': [1, 2], 'd': tree_1.xx})

可以通過以下的命令列匯出為影像,不難發現對於共用節點和共用物件的情況也都進行了準確地體現(如需進一步瞭解,可以執行 treevalue graph -h 檢視幫助資訊)

treevalue graph -t 'test_main.tree_*' -o 'test_graph.png' -d numpy.ndarray -d list -d dict

以上是對於 TreeValue 物件的兩種視覺化方法。

函式的樹化

在treevalue中,我們可以快速地將函式進行裝飾,使之可以支援 TreeValue 物件作為引數進行運算。例如下面的例子

# test_gcd.py
from treevalue import FastTreeValue, func_treelize

@func_treelize(return_type=FastTreeValue)
def gcd(a, b):  # GCD calculation
    while True:
        r = a % b
        a, b = b, r
        if r == 0:
            break

    return a

if __name__ == '__main__':
    t1 = FastTreeValue({'a': 2, 'b': 30, 'x': {'c': 4, 'd': 9}})
    t2 = FastTreeValue({'a': 4, 'b': 48, 'x': {'c': 6, 'd': 54}})

    print("Result of gcd(12, 9):", gcd(12, 9))
    print("Result of gcd(t1, t2):")
    print(gcd(t1, t2))

將整數之間的最大公因數進行了裝飾後,可以形成相容 TreeValue 物件的函式,且不會影響普通物件的運算結果,因此上述程式碼的輸出結果如下所示

Result of gcd(12, 9): 3
Result of gcd(t1, t2):
<FastTreeValue 0x7f53fa67ff10>
├── a --> 2
├── b --> 6
└── x --> <FastTreeValue 0x7f53fa67ff40>
    ├── c --> 2
    └── d --> 9

樹結構的運算

除了 TreeValue 自帶的一系列基本操作之外,treevalue庫還提供了一些常用的樹結構運算函式。例如如下的四種:

  • map——值對映運算
  • reduce——值歸納運算
  • subside——頂層結構下沉運算
  • rise——值結構提取上浮運算

值對映運算(map)

TreeValue 物件的值對映運算和列表型別的map運算類似,會產生一個同結構且值為對映值的新 TreeValue 物件,例如下面的案例

from treevalue import FastTreeValue

t1 = FastTreeValue({'a': 2, 'b': 30, 'x': {'c': 4, 'd': 9}})
t2 = t1.map(lambda x: x * 2 + 1)

t1t2 的影像如下

值歸納運算(reduce)

TreeValue 物件的值歸納運算和 functools.reduce 函式的功能類似,可以將樹結構以子樹為單位進行歸納,最終計算出一個結果來,例如下面的案例

from treevalue import FastTreeValue

t1 = FastTreeValue({'a': 2, 'b': 30, 'x': {'c': 4, 'd': 9}})

# sum of all the values in t1
t1.reduce(lambda **kws: sum(kws.values()))  # 45

可以快捷的實現整棵樹的求和運算。

頂層結構下沉運算(subside)

TreeValue 還可以支援將其上層結構進行解析後,下沉到節點值上,形成一棵新的樹,例如下面的案例

from treevalue import FastTreeValue

dt = {
    'i1': FastTreeValue({'a': 1, 'b': 2, 'x': {'c': 3}}),
    'i2': (
        FastTreeValue({'a': 7, 'b': 4, 'x': {'c': 6}}),
        FastTreeValue({'a': 11, 'b': 9, 'x': {'c': -3}}),
    ),
}

t = FastTreeValue.subside(dt)

下沉後產生的樹 t 的影像為

值結構提取上浮運算(rise)

上浮運算(rise)為下沉運算的逆運算,可以從值節點中提取共同結構至 TreeValue 物件外,例如如下的案例

from treevalue import FastTreeValue, raw

t = FastTreeValue({
    'a': raw({'i1': 1, 'i2': (7, 11)}),
    'b': raw({'i1': 2, 'i2': (4, 9)}),
    'x': {
        'c': raw({'i1': 3, 'i2': (6, -3)}),
    }
})

dt = t.rise()
dt_i1 = dt['i1']
dt_i2_0, dt_i2_1 = dt['i2']

物件 dt 是一個字典型別的物件,包含 i1i2 兩個鍵,其中 i1 為一個 TreeValue 物件, i2 為一個長度為2,由 TreeValue 物件構成的元組。因此 dt_i1dt_i2_0dt_i2_1 的影像為

後續預告

本文作為一個對於treevalue現有主要功能的一個概述,受限於篇幅暫時做不到深入全面的講解內部原理和機制。因此後續會考慮繼續出包括但不限於以下的內容:

  • treevalue的整體體系結構
  • func_treelize 樹化函式的原理與機制
  • treevalue的一些神奇的用法與黑科技
  • treevalue的優化歷程與經驗
  • treevalue的具體實戰案例

敬請期待,同時歡迎瞭解其他OpenDILab的開源專案:https://github.com/opendilab

相關文章