參考文件: https://apachecn.github.io/numba-doc-zh/#/docs/17
vectorize裝飾器:
允許python函式的標量入參使用 numpy 的 ufuncs。
import numba as nb
@nb.vectorize([nb.int32(nb.int32, nb.int32)], nopython=True)
def f(x, y):
return x + y
a = [1, 2, 3]
b = [4, 5, 6]
print("a + b: ", a + b) # a + b: [1, 2, 3, 4, 5, 6]
print("f(a, b): ", f(a, b)) # f(a, b): [5 7 9]
print("type(f(a, b)): ", type(f(a, b))) # type(f(a, b)): <class 'numpy.ndarray'>
c = [[1, 2, 3], [4, 5, 6]]
d = [[7, 8, 9], [10, 11, 12]]
print("f(c,d): ", f(c, d)) # [[8 10 12] [14 16 18]]
vectorize支援多個函式簽名,多如果用多個函式簽名的時候注意順序,比如單精度浮點數在雙精度浮點數之前。
@vectorize([int32(int32, int32),
int64(int64, int64),
float32(float32, float32),
float64(float64, float64)])
使用vectorize裝飾之後,函式會自動獲得numpy ufuncs的其他功能,比如縮小,累計或廣播。ufunc的標準功能。
print(
"f.reduce(c) axis=0: %s, axis=1: %s" % (f.reduce(c), f.reduce(c, axis=1))
) # f.reduce(c) axis=0: [5 7 9], axis=1: [ 6 15]
guvectorize 裝飾器
vectorize 是一次寫一個元素的 ufunc, guvectorize 是允許編寫輸入任意數量陣列元素的 ufunc,並獲取和返回不同維度的陣列。典型的例子是中值或卷積濾波器。
guvectorize 不返回結果,是將其作為函式引數,必須有函式填充。
@nb.guvectorize([(nb.int32[:], nb.int32, nb.int32[:])], "(n), ()->(n)", nopython=True)
def g(x, y, r):
for i in range(x.shape[0]):
r[i] = x[i] + y
a = [1, 2, 3, 4]
b = 10
print("g(a, b): ", g(a, b)) # g(a, b): [11 12 13 14]
底層的 python 程式碼是將一個給定的標量 y 新增到一維陣列的所有元素中。
"(n),()->(n)" 表示輸入一個n個元素的一維陣列和一個標量(用符號()表示)並返回n個元素的一維陣列。
1D 陣列型別也可以接收標量引數,在上邊例子中,第二個引數也可以宣告為int[:] (), 在這種情況下,該值必須由 y[0] 讀取。
我們可以檢查一下能否支援不同維度的陣列
# 2d + int
c = [[1, 2, 3], [4, 5, 6]]
print("g(c, 10): ", g(c, 10)) # g(c, 10): [[11 12 13] [14 15 16]]
# 2d + 1d
d = [10, 20]
print("g(c, d): ", g(c, d)) # g(c, d): [[11 12 13] [24 25 26]]
動態通用功能
如果給裝飾器 vectorize 沒有傳遞任何簽名,python 函式將會構建動態通用函式或者DUDunc。
@nb.vectorize()
def d(x, y):
return x + y
print("d(1, 2)", d(1, 2)) # d(1, 2) 3
print(d.types) # ['ll->q']
print("d(1, 2.0)", d(1, 2.0)) # d(1, 2) 3.0
print(d.types) # ['ll->q', 'ld->d']
print("d([1, 2, 3], 10): ", d([1, 2, 3], 10)) # d([1, 2, 3], 10): [11 12 13]
print(d.types) # ['ll->q', 'ld->d']
在這種情況下,呼叫的順序很重要,比如先傳入浮點引數,那麼任何帶有整形的引數都會被轉換成雙精度浮點值。如: 先呼叫 d(1.0, 2.0) 得到3.0,再呼叫 d(1, 2) 也會得到3.0。
完整程式碼
import numba as nb
@nb.vectorize(
[nb.int32(nb.int32, nb.int32), nb.float64(nb.float64, nb.float64)], nopython=True
)
def f(x, y):
return x + y
a = [1.0, 2.0, 3.0]
b = [4, 5, 6]
# b = [4.0, 5.0, 6.0]
print("a + b: ", a + b) # a + b: [1, 2, 3, 4, 5, 6]
print("f(a, b): ", f(a, b)) # f(a, b): [5 7 9]
print("type(f(a, b)): ", type(f(a, b))) # type(f(a, b)): <class 'numpy.ndarray'>
c = [[1, 2, 3], [4, 5, 6]]
d = [[7, 8, 9], [10, 11, 12]]
print("f(c,d): ", f(c, d)) # [[8 10 12] [14 16 18]]
print(
"f.reduce(c) axis=0: %s, axis=1: %s" % (f.reduce(c), f.reduce(c, axis=1))
) # f.reduce(c) axis=0: [5 7 9], axis=1: [ 6 15]
print("-" * 20)
@nb.guvectorize([(nb.int32[:], nb.int32, nb.int32[:])], "(n), ()->(n)", nopython=True)
def g(x, y, r):
for i in range(x.shape[0]):
r[i] = x[i] + y
a = [1, 2, 3, 4]
b = 10
print("g(a, b): ", g(a, b)) # g(a, b): [11 12 13 14]
# 2d + int
c = [[1, 2, 3], [4, 5, 6]]
print("g(c, 10): ", g(c, 10)) # g(c, 10): [[11 12 13] [14 15 16]]
# 2d + 1d
d = [10, 20]
print("g(c, d): ", g(c, d)) # g(c, d): [[11 12 13] [24 25 26]]
print("-" * 20)
@nb.vectorize()
def d(x, y):
return x + y
print("d(1, 2)", d(1, 2)) # d(1, 2) 3
print(d.types) # ['ll->q']
print("d(1, 2.0)", d(1, 2.0)) # d(1, 2) 3.0
print(d.types) # ['ll->q', 'ld->d']
print("d([1, 2, 3], 10): ", d([1, 2, 3], 10)) # d([1, 2, 3], 10): [11 12 13]
print(d.types) # ['ll->q', 'ld->d']