關於CS231N-Assignment1-KNN中no-loop矩陣乘法程式碼的講解

IAMoldpan發表於2017-10-26

在使用無迴圈的演算法進行計算距離的效率是很高的
可以看到No loop演算法使用的時間遠遠小於之前兩種演算法

Two loop version took 56.785069 seconds
One loop version took 136.449761 seconds
No loop version took 0.591535 seconds   #很快!

實現程式碼主要為以下這一段:
其中X為500×3072的矩陣(測試矩陣)
X_train為5000×3072的矩陣(訓練矩陣)
dists 為500×5000的矩陣(距離矩陣)
題中的目的就是將X中每一行的畫素數值與X_train中每一行的畫素數值(3072個)進行距離運算得出歐氏距離(L2)再儲存到dists中
核心公式

test_sum = np.sum(np.square(X), axis=1)  # num_test x 1
train_sum = np.sum(np.square(self.X_train), axis=1)  # num_train x 1
inner_product = np.dot(X, self.X_train.T)  # num_test x num_train
dists = np.sqrt(-2 * inner_product + test_sum.reshape(-1, 1) + train_sum)  # broadcast

公式講解:
假設現在有三個矩陣:A(X)、B(X_train)、C(dists )
將維數縮小以方便操作,稍微進行推導,就可以得出上面的公式了
推導過程如下:
這裡寫圖片描述

相關文章