從手寫三層迴圈到標準實現,矩陣相乘執行效率提高三萬六千倍之路

MCTW發表於2021-01-22

前言

矩陣乘法可以說是最常見的運算之一。

本文介紹不同的方式實現的矩陣乘法,並比較它們執行速度的差異。

表示矩陣的方式有很多種,完善的矩陣類應該實現切片取值,獲得矩陣形狀等操作,但本文並不打算直接從原生Python實現一個矩陣類,而是直接用 Pytorch中的tensor表示矩陣。

開始: 三層迴圈

根據矩陣相乘定義,可通過三層迴圈實現該運算。

def matmul(a, b):
    r1, c1 = a.shape
    r2, c2 = b.shape
    
    assert c1 == r2
    
    rst = torch.zeros(r1, c2)
    
    for i in range(r1):
        for j in range(c2):
            for k in range(c1):
                rst[i][j] += a[i][k] * b[k][j]
    return rst

那麼這個函式的執行效率如何呢?讓我們嘗試兩個較大的矩陣相乘,測試一下執行時間。

m1 = torch.randn(5, 784)
m2 = torch.randn(784, 10)

%timeit -n 10 matmul(m1, m2)

得到結果如下:

624 ms ± 3.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

即每次矩陣相乘, 需要耗時 600ms 左右,這是一個非常非常慢的速度,慢到兩次矩陣乘法居然要耗時1秒多,這是不可能被接受的。

相同形狀的張量進行運算

如果兩個張量的形狀相同,則他們的運算為同一位置的數字進行運算。

a = torch.tensor([1., 2, 3])
b = torch.tensor([4., 5, 6])

a + b  # tensor([5., 7., 9.])
a * b  # tensor([ 4., 10., 18.])

康康之前用三層迴圈實現的矩陣相乘,發現最裡面一層迴圈的本質就是兩個同樣大小的張量相乘,再進行求和。
即第一個矩陣中的一行 跟 第二個矩陣中的一列 進行運算,且這行和列中的元素個數相同,則我們可以通過同樣形狀的張量運算改寫最內層迴圈:

def matmul(a, b):
    r1, c1 = a.shape
    r2, c2 = b.shape
    
    assert c1 == r2
    
    rst = torch.zeros(r1, c2)
    
    for i in range(r1):
        for j in range(c2):
            rst[i][j] = (a[i,:] * b[:,j]).sum()  # 改了這裡
    return rst

%timeit -n 10 matmul(m1, m2)

得到結果如下

1.4 ms ± 92.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

624 / 1.4=445,只改寫了一下最內層迴圈,就使得矩陣乘法快了445倍!

廣播機制

廣播機制使得不同形狀的張量間可以進行運算:

  1. 兩個張量擴充成同樣的形狀
  2. 再按相同形狀的張量進行運算
# shape: [2, 3]
a = torch.tensor([
    [1, 2, 3],
    [4, 5, 6],
])

# shape: [1]
b = torch.tensor([1])

# shape: [3]
c = torch.tensor([10, 20, 30])

形狀為 [2, 3] 和 [1] 的兩個張量相加:

a + b

"""輸出:
tensor([[2, 3, 4],
        [5, 6, 7]])
"""

形狀為 [2, 3] 和 [3] 的兩個張量相加:

b + c

"""輸出:
tensor([[11, 22, 33],
        [14, 25, 36]])
"""

這兩個例子中,維度低的張量都是暗地裡先擴充成了維度高的張量,然後再參與的運算。

那麼如何檢視擴充後的張量是啥呢?用 expand_as 函式就可以檢視:

b.expand_as(a)

"""輸出
tensor([[1, 1, 1],
        [1, 1, 1]])
"""
b.expand_as(a)

"""輸出
tensor([[10, 20, 30],
        [10, 20, 30]])
"""

這就一目瞭然了,形狀不同的張量可以通過廣播機制擴充成形狀一致的張量再進行運算。

那麼任意形狀的兩個張量都可以運算嗎?當然不是了,判斷兩個張量是否能運算的規則如下:

先從兩個張量的最後一個維度看起,如果維度的維數相同,或者其中一個維數為1,則可以繼續判斷,否則就失敗。
然後看倒數第二個維度,倒數第三個維數,一直到遍歷完某個張量的維數為止,一直沒有失敗則這兩個張量可以通過廣播機制進行運算。

那麼這個廣播機制和矩陣乘法有什麼關係呢?答案就是它可以幫我們再去掉一層迴圈。

現在的最記憶體迴圈的本質是 一個形狀為 [c1] 的張量 和 一個形狀為 [c1, c2] 的張量做運算,最終生成一個形狀為 [c2] 的張量。

則我們可以把矩陣運算改寫為:

def matmul(a, b):
    r1, c1 = a.shape
    r2, c2 = b.shape
    
    assert c1 == r2
    
    rst = torch.zeros(r1, c2)
    
    for i in range(r1):
        rst[i] = (a[i, :].unsqueeze(-1) * b).sum(0)
    return rst

%timeit -n 10 matmul(m1, m2)

"""輸出
249 µs ± 66.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
"""

現在已經把每次矩陣運算的時間壓縮到了 249 µs!!!,比最開始的 624ms 快了 2500倍!

對於 unsqueeze 操作不太熟悉的小夥伴請看我的另一篇文件: Pytorch 中張量的理解

但是還沒結束。。。因為兩個矩陣的相乘,就是 [r1, c1] 和 [c1, c2] 兩個張量的運算,我們可以直接把它用廣播機制一次到位的算出結果,連唯一的那層迴圈也可以省去:

def matmul(a, b):
    r1, c1 = a.shape
    r2, c2 = b.shape
    
    assert c1 == r2
    
    return (a.unsqueeze(-1) * b.unsqueeze(0)).sum(1)

%timeit -n 10 matmul(m1, m2)

"""輸出:
169 µs ± 41.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
"""

這個 169µs 已經是最開始矩陣相乘版本的 3700 倍了。。( Ĭ ^ Ĭ )淚目,果然知識是第一生產力。

愛因斯坦求和

接下來就是 pytorch 自帶的矩陣運算工具了,其中一個是愛因斯坦求和,貌似知道這個的同學不多。。
簡單來說,它能讓我們幾乎不編寫程式碼就能進行矩陣運算,只需要確定輸入和輸出矩陣的形狀即可:

def matmul(a, b):
    return torch.einsum("ik,kj->ij", a, b)

%timeit -n 10 matmul(a, b)

"""輸出
74 µs ± 25.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
"""

74µs 這個速度已經是原始版本的 8000 多倍了。。。但是對於工業級別的要求似乎仍然不夠快~

pytorch 的矩陣相乘標準實現

最後祭出 pytorch 的矩陣相乘官方版本:

def matmul(a, b):
    return a @ b

%timeit -n 10 matmul(m1, m2)

"""輸出
17.1 µs ± 28.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
"""

17.1 µs 是原始三層迴圈版本的 36000 倍,官方實現就是這麼簡單枯燥,樸實無華~

相關文章