章節安排
- 背景介紹
- 均方根誤差MSE
- 最小二乘法
- 梯度下降
- 程式設計實現
背景
生活中大多數系統的輸入輸出關係為線性函式,或者在一定範圍內可以近似為線性函式。在一些情形下,直接推斷輸入與輸出的關係是較為困難的。因此,我們會從大量的取樣資料中推導系統的輸入輸出關係。典型的單輸入單輸出線性系統可以用符號表示為:
其中,\(k\)為斜率,反應了當輸入量\(x\)變化時,輸出\(y\)的變化與輸入\(x\)變化的比值;\(b\)反應了當系統沒有輸入(或輸入為\(0\))時,系統的輸出值。
資料一般稱觀測資料或取樣資料,這兩種說法具有一定的側重點,觀測傾向於客觀系統,例如每天的漲潮水深;取樣傾向於主觀系統,例如,對彈簧施加10N的壓力,觀察彈簧的形變數。
對於但輸入單輸出系統,資料可以表示為:
或
其中符號\(O\)對應observation(觀測)、符號\(S\)對應sampling(取樣),\(\{o_i\}_N\)中\(o_i\)表示取樣序列中的每一個元素,\(N\)表示序列中元素的個數,\(x_i\)表示系統輸入,\(y_i\)表示系統輸出
在系統的推導過程中,一般稱推導的結果為對實際系統的估計或近似,用符號記為\(\hat{y}=\hat{f}(x)\)。對於單個取樣點,系統的誤差定義為:對該取樣輸入,輸出的真實值與輸出的預測值的差為誤差。用資料公式表示為:
對於整體取樣序列,一種經典的誤差是均方根誤差(Mean Squared Error, MSE),其數學公式為:
在推導系統輸入輸出關係,通常有兩種方法,一種是基於數值推導的方法,一種是基於學習的方法。本文分別以最小二乘法和梯度下降為例講解兩種方法。
MSE
對於單個取樣點的情形,MSE退化為方差的平方,即:
假定引數\(b\)為常量,僅考慮MSE與引數的關係,有
易得,MSE是關於\(k\)的二次函式,且該二次函式有唯一的零點:\(k_0=-(b-y)/x\)
對於多個點的情形,對每個點\(\{s_i\}=\{x_i,y_i\}\),\(\varepsilon_i^2\)均可表示為關於\(k\)的二次函式,有:
即:序列的MSE也為關於引數\(k\)的二次函式,並且,\(MSE\geq0\),當且僅當\((b-y_i)/x_i=M\)為常數時不等式取等。
可以很容易證明MSE也是關於引數\(b\)的二次函式
開口向上的二次函式有兩個重要的性質:
- 導數為\(0\)的點,為其最小值點。
- 任意點距離最小值點的距離與其導數值成正比,方向為導數方向的反方向
性質1、2分別是最小二乘法、梯度下降法的理論基礎/依據。
最小二乘法
最小二乘法基於MSE進行設計,其思想為,找到一組引數,使得MSE關於每個引數的偏導為0,對於一元輸入的情形,即:
首先化簡公式\((3.2)\)
由公式\((3.2)\)有:
其次化簡公式\(3.1\)
代入公式\((3.1),(3.3)\)有:
公式\((3.3),(3.4)\)即為最小二乘法的引數公式
梯度下降
對於學習機器學習的初學者,我們首先討論最簡單的情形:基於單個取樣點的學習。
二次函式具有重要性質:任意點距離最小值點的距離與其導數值成正比
基於該性質,我們可以可以設計引數更新公式如下
故有引數更新公式:
其中\(\lambda\)為學習率,一般取\(0.1\sim10^{-6}\)
常數\(2\)是可以預設的,可以視為學習率放大了兩倍。
程式設計實現
建議讀者按照如下方法建立標頭檔案、定義函式
typedef.h
:定義變數型別
random_point.h
:生成隨機點
least_square.h
:最小二乘法的實現
gradient_descent.h
:梯度下降方法的實現
型別定義
首先我們需要定義取樣點,以及取樣點序列型別。
取樣點是包含\(x\)、\(y\)兩個值的資料型別。同時,為方便使用,定義別名Point
取樣點序列,或者稱資料,可以儲存為型別為Point
的vector
struct SamplePoint{
float x;
float y;
}
using Point = SamplePoint;
using Data = std::vector<Point>;
對於直線,其包含\(k\),\(b\)兩個引數,同時,為了方便呼叫,定義括號運算子()
過載
struct LinearFunc{
float k;
float b;
float operator()(float x){
return k*x+b;
}
}
using Line = LinearFunc;
using Func = LinearFunc;
資料生成
採用random
庫中的normal_distribution
隨機數引擎
#include <random>
#include <cmath>
#include "typedef.h"
Data generatePoints(const Func& func, float sigma, float a, float b, int numPoints) {
Data points;
std::random_device rd;
std::mt19937 gen(rd());
// std::uniform_real_distribution<> distX(a, b); // 均勻分佈
std::normal_distribution<> distX((a + b) / 2, (b - a) / 2.8); // 正態分佈
std::normal_distribution<> distY(0, sigma);
for (int i = 0; i < numPoints; ++i) {
float x = distX(gen);
float y = func(x) + distY(gen);
points.push_back({ x, y });
}
return points;
}
該方法接受五個輸入,分別是:
func
:函式,自變數\(x\)與自變數\(y\)的關係sigma
:\(y\)的觀測值與真實值的誤差的方差a
、b
:生成的資料範圍的參考上下界,決定了生成資料的寬度,同時,絕大多數資料將位於此區間numPoints
:點的個數
最小二乘法
最小二乘法僅需接受一個輸入:資料Data
,同時返回資料。
在實現中,需要遍歷取樣資料,並分別進行累加計算\(\sum x_i\)、\(\sum y_i\)、\(\sum x_i^2\)和\(\sum x_iy_i\)
Line Least_Square(const Data& data) {
Line line;
float s_x = 0.0f;
float s_y = 0.0f;
float s_xx = 0.0f;
float s_xy = 0.0f;
float n = static_cast<float>(data.size());
for (const auto& p : data) {
s_x += p.x;
s_y += p.y;
s_xx += p.x * p.x;
s_xy += p.x * p.y;
}
line.k = (n * s_xy - s_x * s_y) / (n * s_xx - s_x * s_x);
line.b = (s_y - line.k * s_x) / n;
return line;
}
梯度下降
梯度下降法是一種學習方法。對引數的估計逐漸向最優估計靠近。在本例中表現為,MSE逐漸降低。
首先實現單步的迭代,在該過程中,遍歷所有的取樣資料,依據引數更新公式對引數進行修正。
梯度下降法需要一個給定的初值,對於線性函式,除了人工生成、隨機初值外,一種方式是,假定為正比例函式,以估計\(k\),假定為常函式,以估計\(b\),公式如下:
在本例中,設定為對初值進行100次迭代後得到最終估計,讀者可根據實際情況調整,在學習度設計的合適的情況下,一般迭代次數在\(50\sim200\)次
#include "typedef.h"
constexpr float eps = 1e-1;
constexpr float lambda = 1e-5;
void GD_step(Func& func, const Data& data) {
for (const auto& p : data) {
float error = func(p.x) - p.y;
func.k -= lambda * error * p.x;
func.b -= lambda * error;
}
}
Func Gradient_Descent(Func& func, const Data& data) {
float s_x = 0, s_y = 0;
for (const auto& p : data) {
s_x += p.x;
s_y += p.y;
}
Line line;
line.k = s_y / s_x;
line.b = s_y / data.size();
float lambda = 1e-5f;
for (size_t _ = 0; _ < 100; _++) {
GD_step(line, data);
}
return line;
}
附錄
nan問題
該問題有兩種產生的原因,引數更新符號錯誤及學習率過高。
引數更新符號錯誤
在更新公式中,如果錯誤的使用+號,或者採用\(\hat y-y\)計算\(\varepsilon_i\),都將會導致引數向誤差更大的方向更新,經過了數次迭代後,與真實值的距離越來越遠,最終產生nan。
學習率過高
如下圖,當學習率設定的過高時,新的引數組\(\{k_{t+1},b_{t+1}\}\)將比舊引數\(\{k_{t},b_{t}\}\)帶來更大的估計誤差(紅色箭頭),而良好的學習率是使得估計誤差逐漸下降的