opencv SVM 使用

我讓你懂懂發表於2017-12-26

SVM是一種分類器,下面通過手寫0-9數字識別對其進行以下介紹。
1.首先準備訓練使用的手寫字型
這裡寫圖片描述
這裡寫圖片描述
如圖所示,將手寫字型分類放在不同的資料夾。
2.讀取圖片

//每種數字個數
const int count[10] = {5923,6742,5958,6131,5842,5421,5918,6265,5851,5949};
    string filename = "shouxieziti/";
    vector<Mat> imgin;
    vector<int> number;
    int sum = 0;
    for(int i = 0; i < 10; i++){
        string s;
        stringstream ss;
        ss<<i;
        ss>>s;
        for(int j = 1; j < count[i]+1; j++){
            string s1;
            stringstream ss1;
            ss1<<j;
            ss1>>s1;
            if(j<10){
                s1 = s+"_0000"+s1;
            }else if(j < 100){
                s1 = s+"_000"+s1;
            }else if(j < 1000){
                s1 = s+"_00"+s1;
            }else{
                s1 = s+"_0"+s1;
            }
            string in = filename + s + "/" + s1 +".jpg";
            Mat img = imread(in,IMREAD_GRAYSCALE);
//            imshow(in,img);
            imgin.push_back(img);
            number.push_back(i);
            cout<<in<<"  ok"<<" "<<img.channels()<<" "<<number[sum + j - 1]<<" "<<sum<<endl;

        }
        sum += count[i];
    }
    cout<<imgin.size()<<" "<<imgin[0].size()<<"have been read"<<endl;

圖片資訊的讀取由自己的儲存方式進行。
3.生成opencv中SVM需要的形式

    Mat imgtrain((int)imgin.size(), 28*28, CV_32FC1);
    Mat imglabel((int)imgin.size(), 1, CV_32SC1);
//    cout<<imgtrain.channels()<<" "<<imglabel.channels()<<endl;
    cout<<"creat train data..."<<endl;
    for(int i = 0; i < (int)imgin.size(); i++){
        Mat_<float>::iterator trainbegin = imgtrain.begin<float>() + 28*28*i;
        Mat_<int>::iterator labelbegin = imglabel.begin<int>();
        Mat_<uchar>::iterator inbegin = imgin[i].begin<uchar>();
        for(int j = 0; j < 28*28; j++){
            float data = (float)*(inbegin+j);
            *(trainbegin+j) = (data+0.0)/255.0;
//            if(data > 200){
//                cout<<*(trainbegin+j)<<" "<<*(labelbegin+j);
//            }
        }
        *(labelbegin+i) = number[i];
        cout<<*(labelbegin+i)<<" ";
    }

其中訓練資料是CV_32FC1型別;label資料是CV_32SC1型別。
另外,需要將資料進行歸一化,因為讀取的是灰度圖0-255範圍之內,所以我們將每個資料除以255就可以得到0-1之間的資料。
4.利用SVM進行訓練

    //設定SVM引數
    Ptr<ml::SVM> svm = ml::SVM::create();
    svm->setType(ml::SVM::C_SVC);
    svm->setKernel(ml::SVM::RBF);
    svm->setGamma(0.01);
    svm->setC(10.0);
    svm->setTermCriteria(TermCriteria(CV_TERMCRIT_ITER, 1000,FLT_EPSILON));
    //進行訓練
        cout<<"trainning..."<<endl;
    bool f = svm->train(imgtrain,ml::ROW_SAMPLE,imglabel);

//    Ptr<ml::TrainData> traindata = ml::TrainData::create(imgtrain,ml::ROW_SAMPLE,imglabel);
//    bool f = svm->trainAuto(traindata, 10);
//    cout<<f<<endl;
    //儲存訓練好的資料
    cout<<"saving..."<<endl;
    svm->save("train1.xml");
    cout<<"save done..."<<endl;

5.讀取生成的train1.xml進行預測

Ptr<ml::SVM> svm = ml::StatModel::load<ml::SVM>("train1.xml");
cout<<"predicting..."<<endl;
    vector<float> result;
    int right = 0, wrong = 0;
    Mat_<int>::iterator labelbegin = imglabel.begin<int>();
    for(int i = 0; i < (int)imgtrain.rows; i++){
        Mat sample = imgtrain.row(i);
        result.push_back(svm->predict(sample));
        cout<<result[i]<<endl;
        if(abs(result[i] - *(labelbegin+i)) < 0.001){
            right++;
        }else{
            wrong++;
        }
    }
    cout<<"predict done... "<<right<<" right "<<wrong<<" wrong"<<endl;
    cout<<"right rate "<<(float)right/(float)(right+wrong)<<endl;
    cout<<"wrong rate "<<(float)wrong/(float)(right+wrong)<<endl;

6.通過訓練60000個樣本,能實現非常高的正確率。下圖是識別了10000個測試資料的結果
這裡寫圖片描述
7.補充
學習過程中主要參考瞭如下連結:
https://www.cnblogs.com/cheermyang/p/5624333.html
手寫字型是由mnist手寫字型影像資料庫生成的,參考下列連結:
http://m.blog.csdn.net/fengbingchun/article/details/49611549

相關文章