神經網路之反向傳播訓練(8行程式碼)

swensun發表於2018-05-23

多少行程式碼可以實現神經網路,並且瞭解反向傳播演算法。

import numpy as np
X = np.array([[0,0,1], [0,1,1], [1,0,1], [1,1,1]])
y = np.array([[0, 0, 1, 1]]).T
syn = 2 * np.random.random((3, 1)) - 1
for i in range(10000):
    l = 1 / (1 + np.exp(-np.dot(X, syn)))
    l_delta = (y - l) * (l * (1 - l))
    syn += np.dot(X.T, l_delta)
複製程式碼

如上,通過給定的輸入的輸出進行訓練,最後根據給定輸入預測輸出。

x1 x2 x3 y
0 0 1 0
1 1 1 1
1 0 1 1
0 1 1 0

訓練資料如上。


下面對程式碼進行解釋:
X:表示輸出資料
y: 表述輸出資料
syn: 3 * 1維的初始化隨機權重
網路結構如下所示:

神經網路之反向傳播訓練(8行程式碼)
如上,上部分為網路結構, 下部分為後面的輔助鏈式求導過程。


其過程大概為:
輸入資料x乘以對應的權重w,得到的值z經過非線性函式轉換(sigmoid)得到預測值l
計算l的目標值y的誤差。訓練的目標為將誤差最小化,因此可以通過調節輸入x和權重w。因為x為輸入資料,不可以調節。因此目標為:通過不斷的調節權重w,使其得到的預測值與目標的值的誤差最小(或者小於某個閾值),完成訓練。可以使用權重w對新的輸入進行預測。
其迭代過程為最後3行程式碼,稍作解釋:
第6行:輸入x乘以權重w,經過sigmoid函式變化得到l。 第7行:將誤差最小化的過程就是不斷調節w的過程。根據鏈式法則求loss關於w的導數,找到損失函式最快降低的方法。如上圖片所示:
loss關於w的導數 = loss關於l的導數 * l 關於z 的導數 * z 關於w的導數, 也即是最後兩行的內容。


下面我將擴充套件一下程式碼, 以便更好的理解:

import numpy as np
def sigmoid(z):
    return 1 / (1 + np.exp(-z))
def sigmoid_derivative(x):
    return x * (1 - x)

X = np.array([[0,0,1], [0,1,1], [1,0,1], [1,1,1]])
y = np.array([[0, 0, 1, 1]]).T

syn = 2 * np.random.random((3, 1)) - 1
for i in range(10000):
    l = sigmoid(np.dot(X, syn))
    loss = (y - l) ** 2
    if i % 1000 == 0:
        print("Loss: ", np.sum(loss))
    loss_derivative = 2  * (y - l)
    l_delta = loss_derivative * sigmoid_derivative(l)
    syn += np.dot(X.T, l_delta)
print("syn:\n" ,syn)
print("l:\n", l)
複製程式碼

下面是列印結果:

Loss:  0.8505452956411994
Loss:  0.0013461667714825178
Loss:  0.0006585827213748538
Loss:  0.0004348364571153864
Loss:  0.00032425670465056856
Loss:  0.0002583897080090336
Loss:  0.00021470251268552865
Loss:  0.00018361751185635365
Loss:  0.000160374621766103
Loss:  0.00014234163846899023
syn:
 [[10.38061079]
 [-0.20679655]
 [-4.98439294]]
l:
 [[0.00679775]
 [0.00553486]
 [0.99548654]
 [0.9944554 ]]
複製程式碼

可以看到隨著迭代次數的增加,loss變的越來越小,說明計算得到的值越接近真實值。 最後列印的權重syn和預測值l也證明了這一點。

後面打算加入隱藏層,提高神經網路的泛化能力。

參考資料:
A Neural Network in 11 lines of Python (Part 1)
深度學習之反向傳播演算法 上/下 Part 3 ver 0.9 beta

github原始碼

相關文章