下面這是opencv官方文件中的程式碼,我加了一部分註釋:
1 #include "stdafx.h" 2 #include "opencv2/core/core.hpp" 3 #include "highgui.h" 4 #include "ml.h" 5 6 using namespace cv; 7 8 int _tmain(int argc, _TCHAR* argv[]) 9 { 10 // 11 int width = 512, height = 512; 12 Mat image = Mat::zeros(height, width, CV_8UC3); 13 14 // set up training data 15 float labels[4] = {1.0, 1.0, -1.0, -1.0}; 16 Mat labelsMat(4, 1, CV_32FC1, labels); 17 18 float trainingData[4][2] = { {501, 10}, {255, 10}, {501, 255}, {10, 501} }; 19 Mat trainingDataMat(4, 2, CV_32FC1, trainingData); 20 21 // set up SVM's parameters,具體引數設定請看下文 22 CvSVMParams params; 23 params.svm_type = CvSVM::C_SVC; 24 params.kernel_type = CvSVM::LINEAR; 25 params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 100, 1e-6); 26 27 // train the svm 28 CvSVM SVM; 29 SVM.train(trainingDataMat, labelsMat, Mat(), Mat(), params); 30 31 Vec3b green(0,255,0), blue(255,0,0); 32 33 // show the decision region given by the SVM 34 for (int i = 0; i < image.rows; ++ i) 35 { 36 for (int j = 0; j < image.cols; ++ j) 37 { 38 Mat sampleMat = (Mat_<float>(1,2) << i,j); 39 40 // predict 函式使用訓練好的SVM模型對一個輸入的樣本進行分類 41 float response = SVM.predict(sampleMat); 42 43 if (response == 1) 44 { 45 // 注意這裡是(j,i),不是(i,j) 46 image.at<Vec3b>(j,i) = green; 47 } 48 else 49 { 50 // 同上 51 image.at<Vec3b>(j,i) = blue; 52 } 53 } 54 } 55 56 int thickness = -1; 57 int lineType = 8; 58 59 circle(image, Point(501, 10), 5, Scalar( 0, 0, 0), thickness, lineType); 60 circle(image, Point(255, 10), 5, Scalar( 0, 0, 0), thickness, lineType); 61 circle(image, Point(501, 255), 5, Scalar(255,255,255), thickness, lineType); 62 circle(image, Point( 10, 501), 5, Scalar(255,255,255), thickness, lineType); 63 64 // show support vectors 65 thickness = 2; 66 lineType = 8; 67 68 // 獲得當前的支援向量的個數 69 int c = SVM.get_support_vector_count(); 70 71 for (int i = 0; i < c; ++ i) 72 { 73 const float* v = SVM.get_support_vector(i); 74 circle( image, Point( (int) v[0], (int) v[1]), 6, Scalar(128, 128, 128), thickness, lineType); 75 } 76 77 imwrite("result.png", image); // save the image 78 79 imshow("SVM Simple Example", image); // show it to the user 80 waitKey(0); 81 return 0; 82 }
這裡說一下CvSVMParams中的引數設定
1 CV_SVM 中的引數設定 2 3 svm_type: 4 CvSVM::C_SVC C-SVC 5 CvSVM::NU_SVC v-SVC 6 SvSVM::ONE_CLASS 一類SVM 7 CvSVM::EPS_SVR e-SVR 8 CvSVM::NU_SVR v-SVR 9 10 kernel_type: 11 CvSVM::LINEAR 線性:u*v 12 CvSVM::POLY 多項式(r*u'v + coef0)^degree 13 CvSVM::RBF RBF函式: exp(-r|u-v|^2) 14 CvSVM::SIGMOID sigmoid函式: tanh(r*u'v + coef0) 15 16 成員變數 17 degree: 針對多項式核函式degree的設定 18 gamma: 針對多項式/rbf/sigmoid核函式的設定 19 coef0: 針對多項式/sigmoid核函式的設定 20 Cvalue: 為損失函式,在C-SVC、e-SVR、v-SVR中有效 21 nu: 設定v-SVC、一類SVM和v-SVR引數 22 p: 為設定e-SVR中損失函式的值 23 class_weights: C_SVC的權重 24 term_crit: 為SVM訓練過程的終止條件。 25 其中預設值 degree = 0, 26 gamma = 1, 27 coef0 = 0, 28 Cvalue = 1, 29 nu = 0, 30 p = 0, 31 class_weights = 0