opencv SVM的使用

brooknew發表於2019-10-30

參考 :https://blog.csdn.net/qq_35054151/article/details/81840935
https://blog.csdn.net/weixin_33698823/article/details/94496434 
C++ opencv SVM: https://www.cppentry.com/bencandy.php?fid=49&aid=153521&page=2 
引數優化:https://blog.csdn.net/computerme/article/details/38677599

講得透徹:https://blog.csdn.net/b285795298/article/details/81977271

 

/* SVM識別手寫數字mnist 0~9 */
/* 每個數字取80個樣本訓練(總80*10=800個),每個數字20個驗證(總20*10=200個),準確率為84%。*/
/* 每個數字取800個樣本訓練(總800*10=8000個),每個數字200個驗證(總200*10=2000個),準確率為91%。*/
/* alpha(get_data 中傳給 convertTo 的像素縮放係數)取 1.0 或 1.0/255,對學習效果無影響 */
/* 樣本資料在 https://download.csdn.net/download/brooknew/11949332 */



#include "opencv2/highgui.hpp"
#include "opencv2/imgproc.hpp"
#include <string>
#include <iostream>
#include <fstream>
#include <vector>
#include "opencv2/ml.hpp"

using namespace std ;
using namespace cv ;
using namespace cv::ml ; 

// Scale factor that get_data() passes to Mat::convertTo when the samples are
// converted to CV_32F (1/255 maps 8-bit pixel values into [0, 1]).
double alpha{1.0 / 255};

// Text files listing the training and the validation samples.
std::string trainpath{"D:\\python\\tensorflow\\ministRecognize\\trainImgByDigit\\train_list.txt"};
std::string testpath{"D:\\python\\tensorflow\\ministRecognize\\trainImgByDigit\\test_list.txt"};

// File the trained model is saved to and later loaded from.
const std::string modelFileName{"svm_model.xml"};
 
/*
 * Read the samples named in a list file and build the matrices used by the SVM.
 *
 * The list file holds whitespace-separated "<image path> <label>" pairs. Each
 * image is read with imread flag 0 (grayscale), flattened into a single row and
 * appended to trainData. The first character of the label token must be a
 * decimal digit; it is appended to trainLabels as one CV_32S element. After all
 * entries are read, trainData is converted to CV_32F and scaled by the global
 * `alpha`.
 *
 * path        - list file to read.
 * trainData   - receives one row per sample.
 * trainLabels - receives one integer label per sample.
 *
 * Prints a message and terminates the process with a non-zero status when the
 * list file cannot be opened, a label is not a digit, an image cannot be read,
 * or an image has a different number of elements than the rows already stored.
 */
void get_data(string path, Mat &trainData, Mat &trainLabels)
{
    fstream io(path, ios::in);
    if (!io.is_open()){
        cout << "file open error in path : " << path << endl;
        exit(1);
    }

    string imgPath;
    string labelToken;
    // Loop on extraction success instead of eof(): eof() only becomes true after
    // a read has already failed, so a trailing newline in the list file used to
    // produce one bogus iteration with an empty path and an empty label.
    while (io >> imgPath >> labelToken)
    {
        // A successfully extracted token is never empty, so [0] is safe here.
        const char digit = labelToken[0];
        if (digit < '0' || digit > '9'){
            cout << "invalid label \"" << labelToken << "\" for image : " << imgPath << endl;
            exit(1);
        }

        // imread returns an empty Mat when the file is missing or unreadable.
        Mat img = imread(imgPath, 0);
        if (img.empty()){
            cout << "image read error in path : " << imgPath << endl;
            exit(1);
        }

        // Flatten to one row, keeping the channel count.
        Mat sample = img.reshape(0, 1);
        // Every row appended to trainData must have the same width.
        if (!trainData.empty() && sample.cols != trainData.cols){
            cout << "image size mismatch in path : " << imgPath << endl;
            exit(1);
        }
        trainData.push_back(sample);

        int idx = digit - '0';
        // push_back copies the element, so wrapping the local is safe.
        trainLabels.push_back(Mat(1, 1, CV_32S, &idx));
    }

    trainData.convertTo(trainData, CV_32F , alpha );
}
 
/*
 * Configure and train the SVM.
 *
 * The model is set up as a C_SVC classifier with a linear kernel, the samples
 * are wrapped in a TrainData object with one sample per row, and training is
 * delegated to SVM::trainAuto.
 *
 * model       - SVM instance to configure and train; must not be null.
 * trainData   - sample matrix, one sample per row.
 * trainLabels - one label per row of trainData.
 *
 * The outcome reported by trainAuto is printed; callers that need to know
 * whether training succeeded can query model->isTrained() afterwards.
 */
void svm_train(Ptr<SVM> &model, Mat &trainData, Mat &trainLabels)
{
    model->setType(SVM::C_SVC);     // SVM type
    model->setKernel(SVM::LINEAR);  // kernel function: linear kernel
    //model->setKernel(SVM::POLY) ;
    Ptr<TrainData> tData = TrainData::create(trainData, ROW_SAMPLE, trainLabels);

    cout << "SVM: start train ..." << endl;
    // trainAuto returns false when training did not succeed; do not claim
    // success unconditionally.
    if (model->trainAuto(tData)) {
        cout << "SVM: train success ..." << endl;
    } else {
        cout << "SVM: train failed ..." << endl;
    }
}
 
/*
 * Run the trained SVM on the test samples and print the accuracy.
 *
 * Every predicted label is printed, ten per line, followed by the number of
 * correct predictions and the accuracy as a truncated integer percentage.
 *
 * model      - trained SVM; if it is null a message is printed and nothing
 *              else happens.
 * test       - sample matrix, one sample per row.
 * testLabels - expected labels, read as int, one per row of `test`.
 *
 * If there are fewer labels than predictions a message is printed and no
 * comparison is made. With zero predictions the accuracy is reported as 0%.
 */
void svm_predict(Ptr<SVM> &model, Mat test, Mat testLabels )
{
    // SVM::load can hand back a null pointer; do not dereference it.
    if (model.empty()) {
        cout << "SVM: model is empty, cannot predict" << endl;
        return;
    }

    Mat result;
    model->predict(test, result);

    // Reading labels past the end of testLabels would be out of bounds.
    if (testLabels.rows < result.rows) {
        cout << "SVM: label count " << testLabels.rows
             << " is smaller than prediction count " << result.rows << endl;
        return;
    }

    int good = 0;  // number of correct predictions
    for (int i = 0; i < result.rows; i++) {
        const float predicted = result.at<float>(i, 0);
        if (static_cast<int>(predicted) == testLabels.at<int>(i, 0)) {
            good++;
        }
        cout << predicted << "  ";
        if ((i + 1) % 10 == 0)
            cout << endl;
    }

    // Guard the division: 0/0 would give NaN, and converting NaN to int is
    // undefined behaviour.
    const int percent = (result.rows > 0)
        ? static_cast<int>(static_cast<float>(good) / result.rows * 100)
        : 0;
    cout << "Right:" << good << "  Accurary rate: " << percent << "%" << endl;
}
 
int usingSvm_main(int argc, char* argv[])
{
    string test_path = testpath; 
    string train_path = trainpath ;
 
    Ptr<SVM> model = SVM::create();
    Mat trainData, trainLabels;
    get_data(train_path, trainData, trainLabels);
    svm_train(model, trainData, trainLabels);
	model->save( modelFileName  ) ;//儲存模型
 
    Mat testData , testLabels;
    get_data(test_path, testData , testLabels );
    Ptr<SVM> modelV = SVM::load<SVM>( modelFileName ) ; //載入模型
	svm_predict(modelV, testData , testLabels  );
	while( true ) ;
	return 0;
}

建立訓練與測試列表(train_list.txt、test_list.txt)的 Python 程式碼:
 

import os
import shutil
# Number of class sub-directories to scan; main() visits the directories named '0' .. str(KIND - 1).
KIND = 10
# Upper bound on how many files main() takes from each class sub-directory.
filesInEachSubDir = 1000 

def main(base_dir=None, kinds=None, files_per_dir=None, train_ratio=0.8):
    """Write train_list.txt and test_list.txt for the per-digit image folders.

    For every class i in range(kinds) the files in ``base_dir + str(i) + '/'``
    are listed in sorted order and at most ``files_per_dir`` of them are used.
    The first ``int(n * train_ratio)`` go to the training list and the rest to
    the test list. Each line has the form ``<path> <label>``.

    Lines are joined with a newline and neither file ends with one: the C++
    reader in this file loops on ``eof()`` and would otherwise process an
    extra, empty entry.

    Args:
        base_dir: Directory (with trailing '/') holding the class folders and
            receiving the two list files. Defaults to the original
            hard-coded location.
        kinds: Number of classes. Defaults to the module constant ``KIND``.
        files_per_dir: Maximum files taken per class. Defaults to the module
            constant ``filesInEachSubDir``.
        train_ratio: Fraction of each class's files used for training.

    Raises:
        OSError: If a class directory cannot be listed or a list file cannot
            be written.
    """
    if base_dir is None:
        base_dir = 'D:/python/tensorflow/ministRecognize/trainImgByDigit/'
    if kinds is None:
        kinds = KIND
    if files_per_dir is None:
        files_per_dir = filesInEachSubDir

    train_lines = []
    test_lines = []
    for label in range(kinds):
        subdir = base_dir + str(label) + '/'
        # os.listdir returns entries in arbitrary order; sort so the
        # train/test split is the same on every run and every machine.
        names = sorted(os.listdir(subdir))
        n_used = min(files_per_dir, len(names))
        n_train = int(n_used * train_ratio)
        for pos, name in enumerate(names[:n_used]):
            line = subdir + name + ' ' + str(label)
            if pos < n_train:
                train_lines.append(line)
            else:
                test_lines.append(line)

    # Joining guarantees there is no trailing newline, even when the last
    # class contributes no entry to one of the lists.
    with open(base_dir + 'train_list.txt', 'wt') as f:
        f.write('\n'.join(train_lines))
    with open(base_dir + 'test_list.txt', 'wt') as f1:
        f1.write('\n'.join(test_lines))

# Generate train_list.txt and test_list.txt as soon as this script is run.
main() 

 

 

相關文章