完成日期: 2024-03-01
更新日期: 2024-03-01
問題
Numpy 中有眾多操作會涉及到一個引數 axis
, 也就是 軸. 這到底是什麼? 沿著某軸操作 (例如 np.sum(axis=0)
) 又是什麼意思?
對於低維陣列, 或許可以按 行 和 列 來理解, 但如果上升到了四維、五維乃至更高, 就變得十分抽象了. 因為這裡要討論更高維度的情況, 所以就不使用 "行" 或 "列" 之類的詞語描述陣列的幾何意義了
一維陣列
我們先來看一維陣列的情況, 也就是 np.shape = (1,)
的情況 (這裡的等號 =
是數學上的符號, 不是賦值的意思), arange()
方法可以接受一個引數 n
, 生成從 0 到 n 的整數序列
>>> a = np.arange(8)
>>> a
array([0, 1, 2, 3, 4, 5, 6, 7])
我們現在想要計算這個序列的和, 使用 np.sum(axis=0)
方法, 由於序列 a
只有一個維度, axis
引數只能為 0
>>> a
array([0, 1, 2, 3, 4, 5, 6, 7])
>>> a.shape
(8,)
>>> a.sum(axis=0)
28
輸出值是 28, np.shape
方法可以列印陣列的維度資訊, 這裡序列 a
只有一個, 這個維度有 8 個元素. axis=0
指在第 0
維上, 遍歷所有元素, 執行求和 sum()
的操作, 也就是把索引為 0, 1, 2, ..., 7
的元素累加起來, 因為它們的層級都是第 0
維, 得到 [28]
, 寫成虛擬碼的形式就是:
for i in range(8)
sum = sum + i
Numpy 中 sum()
、min()
等方法還會有一個降維的效果, 也就是 reduce, 於是用於表示維度的方括號被抽離, [28]
變成 28
, 成為 0
維的標量
二維陣列
axis=0
>>> a = np.arange(6).reshape(2,3)
>>> a.shape
(2, 3)
>>> a
array([[0, 1, 2],
[3, 4, 5]])
>>> a.sum(axis=0)
array([3, 5, 7])
reshape()
方法可以調整陣列維度, arange(6)
生成了有 6 個元素的一維陣列, reshape(2,3)
將它調整成 2*3 的二維陣列
在Numpy顯示多維陣列的方式中, 可以透過數方括號來確定維度, 第一個方括號是第 axis=0
維, 第二個方括號是第 axis=1
維. 我們調整一下上面 Numpy 顯示二維陣列的方式, 方便指示維度
>>> a
array([ -----------> 表示第 axis=0 維
[ --------> 表示第 axis=1 維
0, 1, 2],
[3, 4, 5]])
那麼, axis=0
要求按第 0
軸去求和, 我們把陣列 a
再換個表示方式
>>>a
array([[0, 1, 2],
[3, 4, 5]])
array([ A,
B ])
這裡, 將 [0, 1, 2]
看作 A
, 將 [3, 4, 5]
看作 B
(還記得矩陣中的子矩陣嗎?)
那麼, 我們要遍歷 axis=0
軸的元素執行求和, 就是要計算 [A+B]
, 也就是 [[0, 1, 2] + [3, 4, 5]]
, 結果是 [[3, 5, 7]]
, 由於 np.sum()
會降維, 抽離最外邊第 axis=0
的方括號, 於是變成 [3, 5, 7]
如果寫成虛擬碼, 即使有 N 行
for i in range(A, B, C, ... , N)
sum = sum + i
如果用切片的方式來表示, 是在計算 a[0, ...] + a[1, ...]
, 符號 ...
表示此維度不作指定, 會自動推導選擇全部
>>> a[0, ...]
array([0, 1, 2])
>>> a[1, ...]
array([3, 4, 5])
[0, 1, 2] + [3, 4, 5] = [3, 5, 7]
陣列 a
的 shape
從 (2,3)
, 執行完 sum(axis=0)
後, 變成 (3,)
, 第一個維度沒有了
axis=1
>>> a
array([[0, 1, 2],
[3, 4, 5]])
>>> a.shape
(2, 3)
>>> a.sum(axis=1)
array([ 3, 12])
如果 np.sum(axis=1)
, 我們需要遍歷同屬於 axis=1
層級的元素並求和
>>> a
array([
[0, 1, 2], 左 ---> 右, 遍歷求和, 得 [3]
[3, 4, 5] 左 ---> 右, 遍歷求和, 的 [12]
])
array([
[3],
[12]
])
最終, 得到陣列 [[3], [12]]
, 由於降維, 抽離裡面 axis=1
的方括號, 變成 [3, 12]
如果用切片的方式檢視, 實質上是在計算 a[..., 0] + a[..., 1] + a[..., 2]
>>> a[..., 0]
array([0, 3])
>>> a[..., 1]
array([1, 4])
>>> a[..., 2]
array([2, 5])
[0, 3] + [1, 4] + [2, 5] = [3, 12]
陣列 a
的 shape
從 (2,3)
, 執行完 sum(axis=1)
後, 變成 (2,)
三維陣列
到這裡, 我們可以看到:
對於一維陣列, axis=0
, 我們需要遍歷陣列中的每一個元素, 也就是 a[...]
對於二維陣列, axis=0
, 我們需要遍歷陣列第 0
軸的每一個元素, 也就是 a[0, ...] + a[1, ...] + a[2, ...] ... a[n, ...]
如果 axis=1
, 我們需要遍歷陣列第 1
軸的每一個元素, 也就是 a[..., 0] + a[..., 1] + a[..., 2] ... a[..., n]
那麼, 是不是可以認為, 對某一軸操作, 就是去遍歷那個軸的切片?
axis=0
for i in {a[0, :, :, ...], a[1, :, :, ...], ..., a[n, :, :, ...]}
axis=1
for i in {a[:, 0, :, ...], a[:, 1, :, ...], ..., a[:, n, :, ...]}
axis=2
for i in {a[:, :, 0, ...], a[:, :, 1, ...], ..., a[:, :, n, ...]}
...
...
我們用三維陣列驗證一下
>>> a = np.arange(24).reshape(2,3,4)
>>> a
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
>>> a.sum(axis=0)
array([[12, 14, 16, 18],
[20, 22, 24, 26],
[28, 30, 32, 34]])
# 列印切片
>>> a[0,...]
array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
>>> a[1,...]
array([[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]])
[[ 0, 1, 2, 3], [[12, 13, 14, 15], [[12, 14, 16, 18],
[ 4, 5, 6, 7], + [16, 17, 18, 19], = [20, 22, 24, 26],
[ 8, 9, 10, 11]] [20, 21, 22, 23]] [28, 30, 32, 34]]
同一層級 axis=0
的元素有兩個, 是切片 a[0, ...]
與 切片 a[1, ...]
, 它們的和正好就是 a.sum(axis=0)
現在我們在來看看 axis=1
的情況
>>> a
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
>>> a.sum(axis=1)
array([[12, 15, 18, 21],
[48, 51, 54, 57]])
# 列印切片
>>> a[:, 0, :]
array([[ 0, 1, 2, 3],
[12, 13, 14, 15]])
>>> a[:, 1, :]
array([[ 4, 5, 6, 7],
[16, 17, 18, 19]])
>>> a[:, 2, :]
array([[ 8, 9, 10, 11],
[20, 21, 22, 23]])
>>> a[:,0,:] + a[:,1,:] + a[:,2,:] == a.sum(axis=1)
array([[ True, True, True, True],
[ True, True, True, True]])
切片的和也確實等於 a.sum(axis=1)
, 回到陣列 a
本身, 第一個維度(也就是 axis=0
) 不動, axis=1
的元素有 3 組, 需要遍歷它們求和
array([[ ------------------------> axis=1 的軸
[ 0, 1, 2, 3], |
+ |
[ 4, 5, 6, 7], | 把這三組數加起來, 有
+ | [12, 15, 18, 21]
[ 8, 9, 10, 11]], V
[
[12, 13, 14, 15], |
+ |
[16, 17, 18, 19], | 同上操作, 有
+ | [48, 51, 54, 57]
[20, 21, 22, 23]]]) V
可以看到, 切片出來元素再相加, 和我們用這種方式取出元素再相加, 它們本質上是相同的
現在再來看看 axis=2
的情況, 它們都是一致的
array([[
[ 0, 1, 2, 3], 左---->右, 遍歷求和, 有 [6]
[ 4, 5, 6, 7], 有 [22]
[ 8, 9, 10, 11]], 有 [38]
[
[12, 13, 14, 15], 有 [54]
[16, 17, 18, 19], 有 [70]
[20, 21, 22, 23] 有 [86]
]])
>>> a.sum(axis=2)
array([[ 6, 22, 38],
[54, 70, 86]])
# 列印切片
>>> a[:,:,0]
array([[ 0, 4, 8],
[12, 16, 20]])
>>> a[:,:,1]
array([[ 1, 5, 9],
[13, 17, 21]])
>>> a[:,:,2]
array([[ 2, 6, 10],
[14, 18, 22]])
>>> a[:,:,3]
array([[ 3, 7, 11],
[15, 19, 23]])
>>> a[:,:,0] + a[:,:,1] + a[:,:,2] + a[:,:,3] == a.sum(axis=2)
array([[ True, True, True],
[ True, True, True]])
切片相加, 也確實等於 sum(axis=2)
四維陣列
我們再來簡單驗證一下四維陣列
>>> a
array([[[ -----------------------------> axis=2 的軸
[ 0, 1, 2, 3, 4], |
[ 5, 6, 7, 8, 9], | 遍歷求和, 有
[ 10, 11, 12, 13, 14], | [ 30, 34, 38, 42, 46]
[ 15, 16, 17, 18, 19]], V
[[ 20, 21, 22, 23, 24], |
[ 25, 26, 27, 28, 29], | 同上, 有
[ 30, 31, 32, 33, 34], | [110, 114, 118, 122, 126]
[ 35, 36, 37, 38, 39]], V
[[ 40, 41, 42, 43, 44], |
[ 45, 46, 47, 48, 49], | 同上, 以下不再贅述
[ 50, 51, 52, 53, 54], |
[ 55, 56, 57, 58, 59]]], V
[[[ 60, 61, 62, 63, 64],
[ 65, 66, 67, 68, 69],
[ 70, 71, 72, 73, 74],
[ 75, 76, 77, 78, 79]],
[[ 80, 81, 82, 83, 84],
[ 85, 86, 87, 88, 89],
[ 90, 91, 92, 93, 94],
[ 95, 96, 97, 98, 99]],
[[100, 101, 102, 103, 104],
[105, 106, 107, 108, 109],
[110, 111, 112, 113, 114],
[115, 116, 117, 118, 119]]]])
>>> a.sum(axis=2)
array([[[ 30, 34, 38, 42, 46],
[110, 114, 118, 122, 126],
[190, 194, 198, 202, 206]],
[[270, 274, 278, 282, 286],
[350, 354, 358, 362, 366],
[430, 434, 438, 442, 446]]])
# 列印切片
>>> a[:,:,0,:]
array([[[ 0, 1, 2, 3, 4],
[ 20, 21, 22, 23, 24],
[ 40, 41, 42, 43, 44]],
[[ 60, 61, 62, 63, 64],
[ 80, 81, 82, 83, 84],
[100, 101, 102, 103, 104]]])
>>> a[:,:,1,:]
array([[[ 5, 6, 7, 8, 9],
[ 25, 26, 27, 28, 29],
[ 45, 46, 47, 48, 49]],
[[ 65, 66, 67, 68, 69],
[ 85, 86, 87, 88, 89],
[105, 106, 107, 108, 109]]])
>>> a[:,:,2,:]
array([[[ 10, 11, 12, 13, 14],
[ 30, 31, 32, 33, 34],
[ 50, 51, 52, 53, 54]],
[[ 70, 71, 72, 73, 74],
[ 90, 91, 92, 93, 94],
[110, 111, 112, 113, 114]]])
>>> a[:,:,3,:]
array([[[ 15, 16, 17, 18, 19],
[ 35, 36, 37, 38, 39],
[ 55, 56, 57, 58, 59]],
[[ 75, 76, 77, 78, 79],
[ 95, 96, 97, 98, 99],
[115, 116, 117, 118, 119]]])
>>> a[:,:,0,:] + a[:,:,1,:] + a[:,:,2,:] + a[:,:,3,:] == a.sum(axis=2)
array([[[ True, True, True, True, True],
[ True, True, True, True, True],
[ True, True, True, True, True]],
[[ True, True, True, True, True],
[ True, True, True, True, True],
[ True, True, True, True, True]]])
可以看到結果仍然是正確的, 後續五維、六維乃至更高維度, 也是如此
高維陣列的切片
高維陣列的第 axis
維度切片過程, 需要將高於 axis
的維度保持不動, 將低於它的維度看作整體 (或者說是子矩陣), 我們還是以四維陣列為例, 切片 a[:,1:2, :, 2:3]
, axis=1
為 1:2
, aixs=3
為 2:3
>>> a = np.arange(120).reshape(2,3,4,5)
array([ --------------------------------> axis=0, 全選
[[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[ 10, 11, 12, 13, 14],
[ 15, 16, 17, 18, 19]],
[ -----------------------------> axis=1, 選取
[ 20, 21, 22, 23, 24],
[ 25, 26, 27, 28, 29],
[ 30, 31, 32, 33, 34],
[ 35, 36, 37, 38, 39]],
[[ 40, 41, 42, 43, 44],
[ 45, 46, 47, 48, 49],
[ 50, 51, 52, 53, 54],
[ 55, 56, 57, 58, 59]]],
# 下面是 axis=0 的第二個子矩陣
[[[ 60, 61, 62, 63, 64],
[ 65, 66, 67, 68, 69],
[ 70, 71, 72, 73, 74],
[ 75, 76, 77, 78, 79]],
[ -----------------------------> axis=1, 選取
[ 80, 81, 82, 83, 84],
[ 85, 86, 87, 88, 89],
[ 90, 91, 92, 93, 94],
[ 95, 96, 97, 98, 99]],
[[100, 101, 102, 103, 104],
[105, 106, 107, 108, 109],
[110, 111, 112, 113, 114],
[115, 116, 117, 118, 119]]]])
axis=0
需要全選, 它的維度是 2, 所以這兩組子矩陣中, 都需要進行選取. axis=1
是 1:2
, 需要選擇第 1
到第 2
個子矩陣, 但不包含第 2
個, 得到
[[[ -----------------------------> axis=1, 選取
[ 20, 21, 22, 23, 24], ----> 選擇 [22]
[ 25, 26, 27, 28, 29], ----> 選擇 [27]
[ 30, 31, 32, 33, 34], ----> 選擇 [32]
[ 35, 36, 37, 38, 39]]],----> 選擇 [37]
[[ -----------------------------> axis=1, 選取
[ 80, 81, 82, 83, 84], ----> 同上, 不再贅述
[ 85, 86, 87, 88, 89],
[ 90, 91, 92, 93, 94],
[ 95, 96, 97, 98, 99]]]],
axis=2
是全選, 所以上面每一行都需要保留, 進入下一層選擇. axis=3
是 2:3
, 也就是選擇第 2
個元素, 最終的結果如下:
>>> a[:,1:2, :, 2:3]
array([[[[22],
[27],
[32],
[37]]],
[[[82],
[87],
[92],
[97]]]])
可以看到, 切片後的陣列, 仍然有 4 個軸, 切片不會降維