numba-jitclass

一枚码农發表於2024-05-12

參考文件:https://numba.pydata.org/numba-doc/latest/user/jitclass.html#

jitclass

使用 `@jitclass` 裝飾類,並通過 spec 宣告各欄位的 Numba 型別:

import numba as nb
import numpy as np
from numba.experimental import jitclass

# Field name -> Numba type for every attribute the jitclass stores.
spec = [
    ("value", nb.int32),  # scalar stored on the instance
    ("array", nb.float32[:]),  # 1-D float32 array
]


@jitclass(spec)
class Bag:
    """jitclass holding an int32 ``value`` and a float32 ``array`` (see ``spec``)."""

    def __init__(self, value):
        # Every field declared in ``spec`` must be assigned here; a field that
        # is never initialized holds garbage data.
        self.value = value
        # ``value`` also sets the length of the zero-filled float32 array.
        self.array = np.zeros(value, dtype=np.float32)

    @property
    def size(self):
        """Number of elements in ``array``."""
        return self.array.size

    def increment(self, val):
        """Add ``val`` to every element of ``array`` in place and return the array."""
        for i in range(self.size):
            self.array[i] += val

        return self.array


# Build a Bag whose array has 2 elements, then add 1 to every element.
a = 2
b = Bag(a)
print("b.increment(1): ", b.increment(1))  # b.increment(1):  [1. 1.]

上面例子中,spec 是一個由二元組組成的列表,每個元組包含欄位名稱和對應的 Numba 型別。也可以使用有序字典(OrderedDict)來對映欄位名稱與型別。
類的 `__init__` 中應初始化每個宣告的欄位;如果不初始化,欄位會包含垃圾資料。

將 numba.typed 容器(container)作為類成員

  1. 顯式指定型別並構建容器
# Key/value types shared by the Dict field declaration and its constructor.
kv_ty = nb.types.int64, nb.types.unicode_type


@jitclass({
    "d": nb.types.DictType(*kv_ty),
    "l": nb.types.ListType(nb.types.float64),
})
class ContainerHolder:
    """jitclass whose fields are explicitly typed numba.typed containers."""

    def __init__(self):
        # Create empty containers whose types match the field declarations.
        self.l = nb.typed.List.empty_list(nb.types.float64)
        self.d = nb.typed.Dict.empty(*kv_ty)


# The typed containers held by the instance can be filled from the interpreter.
c = ContainerHolder()
c.d[1] = "apple"
c.d[2] = "orange"
c.l.append(1.0)
c.l.append(2.0)

print("c.d: ", c.d)  # c.d:  {1: apple, 2: orange}
print("c.l: ", c.l)  # c.l:  [1.0, 2.0]
  2. 另一個有用的模式是利用 numba.typed 容器的 `_numba_type_` 屬性來獲取容器的型別,該屬性可以在 Python 直譯器中直接從容器例項訪問;對容器例項呼叫 numba.typeof 也能得到相同的資訊。如下:
# Untyped constructors: no types are passed here; the container types are
# inferred from the values inserted below and later read with nb.typeof.
d = nb.typed.Dict()
d[1] = "apple"
d[2] = "orange"

l = nb.typed.List()
l.append(1.0)
l.append(2.0)


@jitclass({"d": nb.typeof(d), "l": nb.typeof(l)})
class ContainerInsHolder:
    """jitclass whose field types are taken from existing container instances.

    The field types come from ``nb.typeof`` applied to the module-level
    ``d`` and ``l`` when the class is defined.
    """

    def __init__(self, dict_instance, list_instance):
        # Attach the caller-supplied containers as the instance's fields.
        self.l = list_instance
        self.d = dict_instance


# Pass the pre-built containers in; they can still be read and mutated
# through the instance from the interpreter.
c = ContainerInsHolder(d, l)
print("c.d: ", c.d)  # c.d:  {1: apple, 2: orange}
print("c.l: ", c.l)  # c.l:  [1.0, 2.0]
c.d[3] = "banana"
print("c.d: ", c.d)  # c.d:  {1: apple, 2: orange, 3: banana}

需要注意的是,容器欄位在使用前必須先初始化,否則會出錯。例如下面的寫法是錯誤的:

# Dict type declared for the field below, but never instantiated in __init__.
d_ty = nb.types.DictType(nb.types.int64, nb.types.unicode_type)


@jitclass([("d", d_ty)])
class NotInitContainer:
    # Deliberately WRONG example: ``self.d`` is declared in the spec but is
    # used before any container has been assigned to it.
    def __init__(self):
        self.d[10] = "apple"  # d has not been initialized; this access is invalid


# Instantiation fails with an invalid memory access and the process aborts
# (observed as "Process finished with exit code -1073741819 (0xC0000005)").
NotInitContainer()

支援的操作

以下 jitclasses 操作在 python 直譯器和 numba 編譯的函式中都支援:

  • 用 jitclass 類例項化物件。(如: bag = Bag(123))
  • 讀/寫屬性。(如:bag.value)
  • 方法呼叫。(如:bag.increment(2))
  • 呼叫例項上的靜態方法。(如:bag.add(1, 2),假設 Bag 定義了靜態方法 add)
  • 呼叫類上的靜態方法。(如:Bag.add(1, 2))

侷限性

  • 在 Numba 編譯的函式內部,jitclass 類物件被視為一個函式(即其建構函式)
  • isinstance() 只能在 python 直譯器中使用
完整程式碼
import numba as nb
import numpy as np
from numba.experimental import jitclass

# Field name -> Numba type for every attribute the jitclass stores.
spec = [
    ("value", nb.int32),  # scalar stored on the instance
    ("array", nb.float32[:]),  # 1-D float32 array
]


@jitclass(spec)
class Bag:
    """jitclass holding an int32 ``value`` and a float32 ``array`` (see ``spec``)."""

    def __init__(self, value):
        # Every field declared in ``spec`` must be assigned here; a field that
        # is never initialized holds garbage data.
        self.value = value
        # ``value`` also sets the length of the zero-filled float32 array.
        self.array = np.zeros(value, dtype=np.float32)

    @property
    def size(self):
        """Number of elements in ``array``."""
        return self.array.size

    def increment(self, val):
        """Add ``val`` to every element of ``array`` in place and return the array."""
        for i in range(self.size):
            self.array[i] += val

        return self.array


# Build a Bag whose array has 2 elements, then add 1 to every element.
a = 2
b = Bag(a)
print("b.increment(1): ", b.increment(1))  # b.increment(1):  [1. 1.]

# Separator between examples.
print("-" * 20)
# Key/value types shared by the Dict field declaration and its constructor.
kv_ty = nb.types.int64, nb.types.unicode_type


@jitclass({
    "d": nb.types.DictType(*kv_ty),
    "l": nb.types.ListType(nb.types.float64),
})
class ContainerHolder:
    """jitclass whose fields are explicitly typed numba.typed containers."""

    def __init__(self):
        # Create empty containers whose types match the field declarations.
        self.l = nb.typed.List.empty_list(nb.types.float64)
        self.d = nb.typed.Dict.empty(*kv_ty)


# The typed containers held by the instance can be filled from the interpreter.
c = ContainerHolder()
c.d[1] = "apple"
c.d[2] = "orange"
c.l.append(1.0)
c.l.append(2.0)

print("c.d: ", c.d)  # c.d:  {1: apple, 2: orange}
print("c.l: ", c.l)  # c.l:  [1.0, 2.0]


# Separator between examples.
print("-" * 20)

# Untyped constructors: no types are passed here; the container types are
# inferred from the values inserted below and later read with nb.typeof.
d = nb.typed.Dict()
d[1] = "apple"
d[2] = "orange"

l = nb.typed.List()
l.append(1.0)
l.append(2.0)


@jitclass({"d": nb.typeof(d), "l": nb.typeof(l)})
class ContainerInsHolder:
    """jitclass whose field types are taken from existing container instances.

    The field types come from ``nb.typeof`` applied to the module-level
    ``d`` and ``l`` when the class is defined.
    """

    def __init__(self, dict_instance, list_instance):
        # Attach the caller-supplied containers as the instance's fields.
        self.l = list_instance
        self.d = dict_instance


# Pass the pre-built containers in; they can still be read and mutated
# through the instance from the interpreter.
c = ContainerInsHolder(d, l)
print("c.d: ", c.d)  # c.d:  {1: apple, 2: orange}
print("c.l: ", c.l)  # c.l:  [1.0, 2.0]
c.d[3] = "banana"
print("c.d: ", c.d)  # c.d:  {1: apple, 2: orange, 3: banana}


# Separator between examples.
print("-" * 20)
# Dict type declared for the field below, but never instantiated in __init__.
d_ty = nb.types.DictType(nb.types.int64, nb.types.unicode_type)


@jitclass([("d", d_ty)])
class NotInitContainer:
    # Deliberately WRONG example: ``self.d`` is declared in the spec but is
    # used before any container has been assigned to it.
    def __init__(self):
        self.d[10] = "apple"  # d has not been initialized; this access is invalid


# Instantiation fails with an invalid memory access and the process aborts
# (observed as "Process finished with exit code -1073741819 (0xC0000005)").
NotInitContainer()