Numpy 高維空間中的軸

Asnelin發表於2024-03-01
Numpy 高維空間中的軸 (axis)

完成日期: 2024-03-01
更新日期: 2024-03-01

問題

Numpy 中有眾多操作會涉及到一個引數 axis, 也就是 . 這到底是什麼? 沿著某軸操作 (例如 np.sum(axis=0)) 又是什麼意思?

對於低維陣列, 或許可以按 來理解, 但如果上升到了四維、五維乃至更高, 就變得十分抽象了. 因為這裡要討論更高維度的情況, 所以就不使用 "行" 或 "列" 之類的詞語描述陣列的幾何意義了

一維陣列

我們先來看一維陣列的情況, 也就是 np.shape = (1,) 的情況 (這裡的等號 = 是數學上的符號, 不是賦值的意思), arange() 方法可以接受一個引數 n, 生成從 0 到 n 的整數序列

>>> a = np.arange(8)
>>> a
array([0, 1, 2, 3, 4, 5, 6, 7])

我們現在想要計算這個序列的和, 使用 np.sum(axis=0) 方法, 由於序列 a 只有一個維度, axis 引數只能為 0

>>> a
array([0, 1, 2, 3, 4, 5, 6, 7])
>>> a.shape
(8,)
>>> a.sum(axis=0)
28

輸出值是 28, np.shape 方法可以列印陣列的維度資訊, 這裡序列 a 只有一個, 這個維度有 8 個元素. axis=0 指在第 0 維上, 遍歷所有元素, 執行求和 sum() 的操作, 也就是把索引為 0, 1, 2, ..., 7 的元素累加起來, 因為它們的層級都是第 0 維, 得到 [28], 寫成虛擬碼的形式就是:

for i in range(8)
    sum = sum + i

Numpy 中 sum()min() 等方法還會有一個降維的效果, 也就是 reduce, 於是用於表示維度的方括號被抽離, [28] 變成 28, 成為 0 維的標量

二維陣列

axis=0

>>> a = np.arange(6).reshape(2,3)
>>> a.shape
(2, 3)
>>> a
array([[0, 1, 2],
       [3, 4, 5]])
>>> a.sum(axis=0)
array([3, 5, 7])

reshape() 方法可以調整陣列維度, arange(6) 生成了有 6 個元素的一維陣列, reshape(2,3) 將它調整成 2*3 的二維陣列

在Numpy顯示多維陣列的方式中, 可以透過數方括號來確定維度, 第一個方括號是第 axis=0 維, 第二個方括號是第 axis=1 維. 我們調整一下上面 Numpy 顯示二維陣列的方式, 方便指示維度

>>> a
array([ -----------> 表示第 axis=0 維
        [  --------> 表示第 axis=1 維
         0, 1, 2],
        [3, 4, 5]])

那麼, axis=0 要求按第 0 軸去求和, 我們把陣列 a 再換個表示方式

>>>a
array([[0, 1, 2],
       [3, 4, 5]])

array([ A,
        B ])

這裡, 將 [0, 1, 2] 看作 A, 將 [3, 4, 5] 看作 B (還記得矩陣中的子矩陣嗎?)

那麼, 我們要遍歷 axis=0 軸的元素執行求和, 就是要計算 [A+B], 也就是 [[0, 1, 2] + [3, 4, 5]], 結果是 [[3, 5, 7]], 由於 np.sum() 會降維, 抽離最外邊第 axis=0 的方括號, 於是變成 [3, 5, 7]

如果寫成虛擬碼, 即使有 N 行

for i in range(A, B, C, ... , N)
    sum = sum + i

如果用切片的方式來表示, 是在計算 a[0, ...] + a[1, ...], 符號 ... 表示此維度不作指定, 會自動推導選擇全部

>>> a[0, ...]
array([0, 1, 2])

>>> a[1, ...]
array([3, 4, 5])

[0, 1, 2] + [3, 4, 5] = [3, 5, 7]

陣列 ashape(2,3), 執行完 sum(axis=0) 後, 變成 (3,), 第一個維度沒有了

axis=1

>>> a
array([[0, 1, 2],
       [3, 4, 5]])
>>> a.shape
(2, 3)
>>> a.sum(axis=1)
array([ 3, 12])

如果 np.sum(axis=1), 我們需要遍歷同屬於 axis=1 層級的元素並求和

>>> a
array([
        [0, 1, 2],  左 ---> 右, 遍歷求和, 得 [3]
        [3, 4, 5]   左 ---> 右, 遍歷求和, 的 [12]
                 ])

array([
        [3],
        [12]
            ])

最終, 得到陣列 [[3], [12]], 由於降維, 抽離裡面 axis=1 的方括號, 變成 [3, 12]

如果用切片的方式檢視, 實質上是在計算 a[..., 0] + a[..., 1] + a[..., 2]

>>> a[..., 0]
array([0, 3])

>>> a[..., 1]
array([1, 4])

>>> a[..., 2]
array([2, 5])

[0, 3] + [1, 4] + [2, 5] = [3, 12]

陣列 ashape(2,3), 執行完 sum(axis=1) 後, 變成 (2,)

三維陣列

到這裡, 我們可以看到:

對於一維陣列, axis=0, 我們需要遍歷陣列中的每一個元素, 也就是 a[...]

對於二維陣列, axis=0, 我們需要遍歷陣列第 0 軸的每一個元素, 也就是 a[0, ...] + a[1, ...] + a[2, ...] ... a[n, ...]

如果 axis=1, 我們需要遍歷陣列第 1 軸的每一個元素, 也就是 a[..., 0] + a[..., 1] + a[..., 2] ... a[..., n]

那麼, 是不是可以認為, 對某一軸操作, 就是去遍歷那個軸的切片?

axis=0
for i in {a[0, :, :, ...], a[1, :, :, ...], ..., a[n, :, :, ...]}

axis=1
for i in {a[:, 0, :, ...], a[:, 1, :, ...], ..., a[:, n, :, ...]}

axis=2
for i in {a[:, :, 0, ...], a[:, :, 1, ...], ..., a[:, :, n, ...]}

...
...

我們用三維陣列驗證一下

>>> a = np.arange(24).reshape(2,3,4)
>>> a
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])
>>> a.sum(axis=0)
array([[12, 14, 16, 18],
       [20, 22, 24, 26],
       [28, 30, 32, 34]])

# 列印切片
>>> a[0,...]
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])
>>> a[1,...]
array([[12, 13, 14, 15],
       [16, 17, 18, 19],
       [20, 21, 22, 23]])

[[ 0,  1,  2,  3],     [[12, 13, 14, 15],       [[12, 14, 16, 18],
 [ 4,  5,  6,  7],  +   [16, 17, 18, 19],   =    [20, 22, 24, 26],
 [ 8,  9, 10, 11]]      [20, 21, 22, 23]]        [28, 30, 32, 34]]

同一層級 axis=0 的元素有兩個, 是切片 a[0, ...] 與 切片 a[1, ...], 它們的和正好就是 a.sum(axis=0)

現在我們在來看看 axis=1 的情況

>>> a
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])
>>> a.sum(axis=1)
array([[12, 15, 18, 21],
       [48, 51, 54, 57]])

# 列印切片
>>> a[:, 0, :]
array([[ 0,  1,  2,  3],
       [12, 13, 14, 15]])

>>> a[:, 1, :]
array([[ 4,  5,  6,  7],
       [16, 17, 18, 19]])

>>> a[:, 2, :]
array([[ 8,  9, 10, 11],
       [20, 21, 22, 23]])

>>> a[:,0,:] + a[:,1,:] + a[:,2,:] == a.sum(axis=1)
array([[ True,  True,  True,  True],
       [ True,  True,  True,  True]])

切片的和也確實等於 a.sum(axis=1), 回到陣列 a 本身, 第一個維度(也就是 axis=0) 不動, axis=1 的元素有 3 組, 需要遍歷它們求和

array([[  ------------------------> axis=1 的軸
        [ 0,  1,  2,  3],       |
                +               |
        [ 4,  5,  6,  7],       |   把這三組數加起來, 有
                +               |   [12, 15, 18, 21]
        [ 8,  9, 10, 11]],      V


       [
        [12, 13, 14, 15],       |
                +               |
        [16, 17, 18, 19],       |   同上操作, 有
                +               |   [48, 51, 54, 57]
        [20, 21, 22, 23]]])     V

可以看到, 切片出來元素再相加, 和我們用這種方式取出元素再相加, 它們本質上是相同的

現在再來看看 axis=2 的情況, 它們都是一致的

array([[
        [ 0,  1,  2,  3],  左---->右, 遍歷求和, 有 [6]
        [ 4,  5,  6,  7],                      有 [22]
        [ 8,  9, 10, 11]],                     有 [38]

       [
        [12, 13, 14, 15],       有 [54]
        [16, 17, 18, 19],       有 [70]
        [20, 21, 22, 23]        有 [86]
                        ]])
>>> a.sum(axis=2)
array([[ 6, 22, 38],
       [54, 70, 86]])

# 列印切片
>>> a[:,:,0]
array([[ 0,  4,  8],
       [12, 16, 20]])

>>> a[:,:,1]
array([[ 1,  5,  9],
       [13, 17, 21]])

>>> a[:,:,2]
array([[ 2,  6, 10],
       [14, 18, 22]])

>>> a[:,:,3]
array([[ 3,  7, 11],
       [15, 19, 23]])

>>> a[:,:,0] + a[:,:,1] + a[:,:,2] + a[:,:,3] == a.sum(axis=2)
array([[ True,  True,  True],
       [ True,  True,  True]])

切片相加, 也確實等於 sum(axis=2)

四維陣列

我們再來簡單驗證一下四維陣列

>>> a
array([[[  -----------------------------> axis=2 的軸
         [  0,   1,   2,   3,   4],   |
         [  5,   6,   7,   8,   9],   | 遍歷求和, 有
         [ 10,  11,  12,  13,  14],   | [ 30,  34,  38,  42,  46]
         [ 15,  16,  17,  18,  19]],  V

        [[ 20,  21,  22,  23,  24],   |
         [ 25,  26,  27,  28,  29],   | 同上, 有
         [ 30,  31,  32,  33,  34],   | [110, 114, 118, 122, 126]
         [ 35,  36,  37,  38,  39]],  V

        [[ 40,  41,  42,  43,  44],   |
         [ 45,  46,  47,  48,  49],   | 同上, 以下不再贅述
         [ 50,  51,  52,  53,  54],   |
         [ 55,  56,  57,  58,  59]]], V


       [[[ 60,  61,  62,  63,  64],
         [ 65,  66,  67,  68,  69],
         [ 70,  71,  72,  73,  74],
         [ 75,  76,  77,  78,  79]],

        [[ 80,  81,  82,  83,  84],
         [ 85,  86,  87,  88,  89],
         [ 90,  91,  92,  93,  94],
         [ 95,  96,  97,  98,  99]],

        [[100, 101, 102, 103, 104],
         [105, 106, 107, 108, 109],
         [110, 111, 112, 113, 114],
         [115, 116, 117, 118, 119]]]])
>>> a.sum(axis=2)
array([[[ 30,  34,  38,  42,  46],
        [110, 114, 118, 122, 126],
        [190, 194, 198, 202, 206]],

       [[270, 274, 278, 282, 286],
        [350, 354, 358, 362, 366],
        [430, 434, 438, 442, 446]]])

# 列印切片
>>> a[:,:,0,:]
array([[[  0,   1,   2,   3,   4],
        [ 20,  21,  22,  23,  24],
        [ 40,  41,  42,  43,  44]],

       [[ 60,  61,  62,  63,  64],
        [ 80,  81,  82,  83,  84],
        [100, 101, 102, 103, 104]]])
>>> a[:,:,1,:]
array([[[  5,   6,   7,   8,   9],
        [ 25,  26,  27,  28,  29],
        [ 45,  46,  47,  48,  49]],

       [[ 65,  66,  67,  68,  69],
        [ 85,  86,  87,  88,  89],
        [105, 106, 107, 108, 109]]])
>>> a[:,:,2,:]
array([[[ 10,  11,  12,  13,  14],
        [ 30,  31,  32,  33,  34],
        [ 50,  51,  52,  53,  54]],

       [[ 70,  71,  72,  73,  74],
        [ 90,  91,  92,  93,  94],
        [110, 111, 112, 113, 114]]])
>>> a[:,:,3,:]
array([[[ 15,  16,  17,  18,  19],
        [ 35,  36,  37,  38,  39],
        [ 55,  56,  57,  58,  59]],

       [[ 75,  76,  77,  78,  79],
        [ 95,  96,  97,  98,  99],
        [115, 116, 117, 118, 119]]])

>>> a[:,:,0,:] + a[:,:,1,:] + a[:,:,2,:] + a[:,:,3,:] == a.sum(axis=2)
array([[[ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True]],

       [[ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True]]])

可以看到結果仍然是正確的, 後續五維、六維乃至更高維度, 也是如此

高維陣列的切片

高維陣列的第 axis 維度切片過程, 需要將高於 axis 的維度保持不動, 將低於它的維度看作整體 (或者說是子矩陣), 我們還是以四維陣列為例, 切片 a[:,1:2, :, 2:3], axis=11:2, aixs=32:3

>>> a = np.arange(120).reshape(2,3,4,5)
array([ --------------------------------> axis=0, 全選
        [[[  0,   1,   2,   3,   4],
         [  5,   6,   7,   8,   9],
         [ 10,  11,  12,  13,  14],
         [ 15,  16,  17,  18,  19]],

        [  -----------------------------> axis=1, 選取
         [ 20,  21,  22,  23,  24],
         [ 25,  26,  27,  28,  29],
         [ 30,  31,  32,  33,  34],
         [ 35,  36,  37,  38,  39]],

        [[ 40,  41,  42,  43,  44],
         [ 45,  46,  47,  48,  49],
         [ 50,  51,  52,  53,  54],
         [ 55,  56,  57,  58,  59]]],

        # 下面是 axis=0 的第二個子矩陣
       [[[ 60,  61,  62,  63,  64],
         [ 65,  66,  67,  68,  69],
         [ 70,  71,  72,  73,  74],
         [ 75,  76,  77,  78,  79]],

        [  -----------------------------> axis=1, 選取
         [ 80,  81,  82,  83,  84],
         [ 85,  86,  87,  88,  89],
         [ 90,  91,  92,  93,  94],
         [ 95,  96,  97,  98,  99]],

        [[100, 101, 102, 103, 104],
         [105, 106, 107, 108, 109],
         [110, 111, 112, 113, 114],
         [115, 116, 117, 118, 119]]]])

axis=0 需要全選, 它的維度是 2, 所以這兩組子矩陣中, 都需要進行選取. axis=11:2, 需要選擇第 1 到第 2 個子矩陣, 但不包含第 2 個, 得到

[[[  -----------------------------> axis=1, 選取
    [ 20,  21,  22,  23,  24],  ----> 選擇 [22]
    [ 25,  26,  27,  28,  29],  ----> 選擇 [27]
    [ 30,  31,  32,  33,  34],  ----> 選擇 [32]
    [ 35,  36,  37,  38,  39]]],----> 選擇 [37]

[[  -----------------------------> axis=1, 選取
    [ 80,  81,  82,  83,  84],  ----> 同上, 不再贅述
    [ 85,  86,  87,  88,  89],
    [ 90,  91,  92,  93,  94],
    [ 95,  96,  97,  98,  99]]]],

axis=2 是全選, 所以上面每一行都需要保留, 進入下一層選擇. axis=32:3, 也就是選擇第 2 個元素, 最終的結果如下:

>>> a[:,1:2, :, 2:3]
array([[[[22],
         [27],
         [32],
         [37]]],


       [[[82],
         [87],
         [92],
         [97]]]])

可以看到, 切片後的陣列, 仍然有 4 個軸, 切片不會降維

相關文章