迴歸分析用於分析輸入變數和輸出變數之間的一種關係,其中線性迴歸是最簡單的一種。
設: Y=wX+b,現已知一組X(輸入)和Y(輸出)的值,要求出w和b的值。
舉個例子:快年底了,銷售部門要發年終獎了,銷售員小王想知道今年能拿多少年終獎,目前他大抵知道年終獎是和銷售額(特徵量)掛鉤的,具體什麼規則不清楚,那麼他大概有兩個方法解決這個問題:
1、去問老闆,今年的分配規則是什麼。【通過演算法解決問題】
2、去向同事打聽他們的銷售額和獎金情況,然後推算自己能拿多少。【通過資料解決問題】
我們當然選擇第二種方法了。通過收集資料,我們得到下面這個表格:
拿到這個資料,我們基本上很快就能推算出兩者的對應關係,如果推算不出來,我們也可以繪製下面這張圖表:
通過圖表,我們可以立即看出兩者的對應關係了。
以上就是一個典型的線性迴歸求解的問題,下面我們要用TensorFlow框架解決這個問題。
具體解決思路如下:
1、先設w=1,b=0
2、取得一批訓練資料,將X代入函式f(x)=wx+b,計算取得在當前條件下的預測值Y‘
3、計算預測值Y‘和實際值Y的誤差
4、根據梯度對w、b進行微調
5、重複上述步驟,直到誤差值足夠小。
先貼出全部程式碼,然後再逐一解釋。
public class LinearRegression { public void Run() { // Supper Parameters float learning_rate = 0.01f; var W = tf.Variable<float>(1); var b = tf.Variable<float>(0); int epochs = 30; int steps = 100; Tensor loss = null; for (int epoch = 0; epoch < epochs; epoch++) { for (int step = 0; step < steps; step++) { int batch_size = 10; (NDArray train_X, NDArray train_Y) = LoadBatchData(batch_size); using (var g = tf.GradientTape()) { //通過當前引數計算預測值 var pred_y = W * train_X + b; //計算預測值和實際值的誤差 loss = tf.reduce_sum(tf.pow(pred_y - train_Y, 2)) / batch_size; //計算梯度 var gradients = g.gradient(loss, (W, b)); //更新引數 W.assign_sub(learning_rate * gradients.Item1); b.assign_sub(learning_rate * gradients.Item2); } } Console.WriteLine($"Epoch{epoch + 1}: loss = {loss.numpy()}; W={W.numpy()},b={b.numpy()}"); } } public (NDArray, NDArray) LoadBatchData(int n_samples) { float w = 0.02f; float b = 1.0f; NDArray train_X = np.arange<float>(start: 1, end: n_samples + 1); NDArray train_Y = train_X * w + b; return (train_X, train_Y); } }
下面對程式碼進行簡單的解釋:
首先,我們要讀取一批(比如10組 )訓練資料,標記為:train_X和train_Y,然後通過現有的w和b值計算預測值:pred_Y=w*train_X_b,此時train_X、train_Y、pred_Y都是10個資料長度的陣列。
然後計算預測資料和時間資料之間的誤差,我們採用均方誤差公式來計算:
然後開始計算W、b對於loss函式的梯度,梯度表達的就是W、b的變化對計算結果的影響,比如將W增大一點,loss的計算結果是變大還是變小,我們的目標是希望loss的值最小,如果w變大時loss變大(梯度為正數),那麼我們下一次就將w變小一點,反之同理。
這裡的learning_rate表示學習率,表示每次引數進行調整的步進值,就是每次調整一大步,還是一小步。通過多次的迴圈調整,w和b的值將調整為一個合適的數字,此時loss的值將會很小,線性迴歸就完成了。以下是運算結果:
在上述過程中,最難理解的就是梯度,以及如何計算梯度的問題,想要進一步瞭解的話可以參閱相關參考資料。
【相關資源】
原始碼:Git: https://gitee.com/seabluescn/tf_not.git
專案名稱:LinearRegression
【參考資料】
《深度學習入門:基於Python的理論與實踐(齋藤康毅)》,網上可以找到電子版