統計學習方法c++實現之一 感知機

bobxxxl發表於2018-12-14

感知機

2018/12/17 程式碼結構更新,詳見https://github.com/bBobxx/statistical-learning

前言

最近學習了c++,俗話說‘光說不練假把式’,所以決定用c++將《統計學習方法》裡面的經典模型全部實現一下,程式碼在這裡,請大家多多指教。

感知機雖然簡單,但是他可以為學習其他模型提供基礎,現在先簡單回顧一下基礎知識。

感知機模型

感知機

首先,感知機是用來分類的模型,上圖就是簡單的感知機模型,其中(f) 我們一般取符號函式

[sign(x)=egin{cases} -1,quad x<0 \\ +1,quad xgeq0 end{cases} ]

所以感知機的數學形式就是

[y=sign(wx+b)]

其中w和x都是n維的向量。當n為2時,(sign)裡面的公式有沒有特別熟悉?就是直線的公式,n>2就是超平面,用一下課本里面的圖就是如下圖

統計學習方法c++實現之一 感知機

這就是分類的根據,必須要注意,感知機只能分離線性可分資料,非線性的不行。

感知機學習策略

提到學習就不得不提到梯度下降演算法。感知機的學習策略就是隨機梯度下降演算法。

具體的在書中講的很詳細,我這裡就不贅述了,直接看學習演算法吧:

(1) 選取初值w,b。

(2) 選取一組訓練資料(x, y)。

(3) 如果(y(wx+b)leq0),則

[ w += lr*yx]

[b+=lr*y]

(4)轉至(2)直到沒有誤分類點。

c++實現感知機

程式碼結構

統計學習方法c++實現之一 感知機

實現

首先我有一個基類Base,為了以後的演算法繼承用的,它包含一個run()的純虛擬函式,這樣以後就可以在main裡面實現多型。

我的資料都儲存在私有成員裡:

    std::vector<std::vector<double>> inData;//從檔案都的資料
    std::vector<std::vector<double>> trainData;//分割後的訓練資料,裡面包含真值
    std::vector<std::vector<double>> testData;
    unsigned long indim = 0;
    std::vector<double> w;
    double b;
    std::vector<std::vector<double>> trainDataF;//真正的訓練資料,特徵
    std::vector<std::vector<double>> testDataF;
    std::vector<double> trainDataGT;//真值
    std::vector<double> testDataGT;

在main函式裡只需要呼叫每個模型的run()方法,宣告的是基類指標:

int main() {
    Base* obj = new Perceptron();
    obj->run();
    delete obj;
    return 0;
}

第一步,讀取資料並分割。這裡用的vector儲存。

    getData("../data/perceptrondata.txt");
    splitData(0.6);//below is split data , and store it in trainData, testData

第二步初始化

    std::vector<double> init = {1.0,1.0,1.0};
    initialize(init);

第三步進行訓練。

在訓練時,函式呼叫順序如下:

  • 呼叫computeGradient,進行梯度的計算。對於滿足(y(wx+b)>0)的資料我們把梯度設為0。

    std::pair<std::vector<double>, double> Perceptron::computeGradient(const std::vector<double>& inputData, const double& groundTruth) {
        double lossVal = loss(inputData, groundTruth);
        std::vector<double> w;
        double b;
        if (lossVal > 0.0)
        {
            for(auto indata:inputData) {
                w.push_back(indata*groundTruth);
            }
            b = groundTruth;
        }
        else{
            for(auto indata:inputData) {
                w.push_back(0.0);
            }
            b = 0.0;
        }
        return std::pair<std::vector<double>, double>(w, b);//here, for understandable, we use pair to represent w and b.
                               //you also could return a vector which contains w and b.
    }

    在呼叫computeGradient時又呼叫了loss,即計算(-y(wx+b)),loss裡呼叫了inference,用來計算(wx+b),看起來有點多餘對吧,inference函式存在的目的是為了後面預測時候用的。

    double Perceptron::loss(const std::vector<double>& inputData, const double& groundTruth){
        double loss = -1.0 * groundTruth * inference(inputData);
        std::cout<<"loss is "<< loss <<std::endl;
        return loss;
    }
    double Perceptron::inference(const std::vector<double>& inputData){
      //just compute wx+b , for compute loss and predict.
      if (inputData.size()!=indim){
          std::cout<<"input dimension is incorrect. "<<std::endl;
          throw inputData.size();
      }
      double sum_tem = 0.0;
      sum_tem = inputData * w;
      sum_tem += b;
      return sum_tem;
    }
    
  • 根據計算的梯度更新w, b

    void Perceptron::train(const int & step, const float & lr) {
        int count = 0;
        createFeatureGt();
        for(int i=0; i<step; ++i){
            if (count==trainDataF.size()-1)
                count = 0;
            count++;
            std::vector<double> inputData = trainDataF[count];
            double groundTruth = trainDataGT[count];
            auto grad = computeGradient(inputData, groundTruth);
            auto grad_w = grad.first;
            double grad_b = grad.second;
            for (int j=0; j<indim;++j){//這裡更新引數
                w[j] += lr * (grad_w[j]);
            }
            b += lr * (grad_b);
        }
    }
    
  • 預測用的資料也是之前就分割好的,注意這裡的引數始終存在

std::vector<double> paraData; 

進行預測的程式碼

int Perceptron::predict(const std::vector<double>& inputData, const double& GT) {

    double out = inference(inputData);
    std::cout<<"The right class is "<<GT<<std::endl;
    if(out>=0.0){
        std::cout<<"The predict class is 1"<<std::endl;
        return 1;
    }
    else{
        std::cout<<"The right class is -1"<<std::endl;
        return -1;
    }

相關文章