利用Hog特徵和SVM分類器進行行人檢測

馬衛飛發表於2017-12-12

之前介紹過Hog特徵(http://blog.csdn.net/carson2005/article/details/7782726),也介紹過SVM分類器(http://blog.csdn.net/carson2005/article/details/6453502 );而本文的目的在於介紹利用Hog特徵和SVM分類器來進行行人檢測。

        在2005CVPR上,來自法國的研究人員Navneet Dalal Bill Triggs提出利用Hog進行特徵提取,利用線性SVM作為分類器,從而實現行人檢測。而這兩位也通過大量的測試發現,Hog+SVM是速度和效果綜合平衡效能較好的一種行人檢測方法。後來,雖然很多研究人員也提出了很多改進的行人檢測演算法,但基本都以該演算法為基礎框架。因此,Hog+SVM也成為一個里程錶式的演算法被寫入到OpenCV中。在OpenCV2.0之後的版本,都有Hog特徵描述運算元的API,而至於SVM,早在OpenCV1.0版本就已經整合進去了;OpenCV雖然提供了HogSVMAPI,也提供了行人檢測的sample,遺憾的是,OpenCV並沒有提供樣本訓練的sample。這也就意味著,很多人只能用OpenCV自帶的已經訓練好的分類器來進行行人檢測。然而,OpenCV自帶的分類器是利用Navneet DalalBill Triggs提供的樣本進行訓練的,不見得能適用於你的應用場合。因此,針對你的特定應用場景,很有必要進行重新訓練得到適合你的分類器。本文的目的,正在於此。

重新訓練行人檢測的流程:

(1)準備訓練樣本集合;包括正樣本集和負樣本集;根據機器學習的基礎知識我們知道,要利用機器學習演算法進行樣本訓練,從而得到一個效能優良的分類器,訓練樣本應該是無限多的,而且訓練樣本應該覆蓋實際應用過程中可能發生的各種情況。(很多朋友,用10來個正樣本,10來個負樣本進行訓練,之後,就進行測試,發現效果沒有想象中的那麼好,就開始發牢騷,抱怨。。。對於這些人,我只能抱歉的說,對於機器學習、模式識別的認識,你還處於沒有入門的階段);實際應用過程中,訓練樣本不可能無限多,但無論如何,三五千個正樣本,三五千個負樣本,應該不是什麼難事吧?(如果連這個都做不到,建議你別搞機器學習,模式識別了;訓練素材都沒有,怎麼讓機器學習到足夠的資訊呢?)

(2)收集到足夠的訓練樣本之後,你需要手動裁剪樣本。例如,你想用Hog+SVM來對商業步行街的監控畫面中進行行人檢測,那麼,你就應該用收集到的訓練樣本集合,手動裁剪畫面中的行人(可以寫個簡單程式,只需要滑鼠框選一下,就將框選區域儲存下來)。

(3)裁剪得到訓練樣本之後,將所有正樣本放在一個資料夾中;將所有負樣本放在另一個資料夾中;並將所有訓練樣本縮放到同樣的尺寸大小。OpenCV自帶的例子在訓練時,就是將樣本縮放為64*128進行訓練的;

(4)提取所有正樣本的Hog特徵;

(5)提取所有負樣本的Hog特徵;

(6)對所有正負樣本賦予樣本標籤;例如,所有正樣本標記為1,所有負樣本標記為0

(7)將正負樣本的Hog特徵,正負樣本的標籤,都輸入到SVM中進行訓練;Dalal在論文中考慮到速度問題,建議採用線性SVM進行訓練。這裡,不妨也採用線性SVM

(8)SVM訓練之後,將結果儲存為文字檔案。

(9)線性SVM進行訓練之後得到的文字檔案裡面,有一個陣列,叫做support vector,還有一個陣列,叫做alpha,有一個浮點數,叫做rho;alpha矩陣同support vector相乘,注意,alpha*supportVector,將得到一個列向量。之後,再該列向量的最後新增一個元素rho。如此,變得到了一個分類器,利用該分類器,直接替換opencv中行人檢測預設的那個分類器(cv::HOGDescriptor::setSVMDetector()),就可以利用你的訓練樣本訓練出來的分類器進行行人檢測了。

下面給出樣本訓練的參考程式碼:

[cpp] view plain copy
  1. class Mysvm: public CvSVM  
  2. {  
  3. public:  
  4.     int get_alpha_count()  
  5.     {  
  6.         return this->sv_total;  
  7.     }  
  8.   
  9.     int get_sv_dim()  
  10.     {  
  11.         return this->var_all;  
  12.     }  
  13.   
  14.     int get_sv_count()  
  15.     {  
  16.         return this->decision_func->sv_count;  
  17.     }  
  18.   
  19.     double* get_alpha()  
  20.     {  
  21.         return this->decision_func->alpha;  
  22.     }  
  23.   
  24.     float** get_sv()  
  25.     {  
  26.         return this->sv;  
  27.     }  
  28.   
  29.     float get_rho()  
  30.     {  
  31.         return this->decision_func->rho;  
  32.     }  
  33. };  
  34.   
  35. void Train()  
  36. {  
  37.     char classifierSavePath[256] = "c:/pedestrianDetect-peopleFlow.txt";  
  38.   
  39.     string positivePath = "E:\\pictures\\train1\\pos\\";  
  40.     string negativePath = "E:\\pictures\\train1\\neg\\";  
  41.   
  42.     int positiveSampleCount = 4900;  
  43.     int negativeSampleCount = 6192;  
  44.     int totalSampleCount = positiveSampleCount + negativeSampleCount;  
  45.   
  46.     cout<<"//////////////////////////////////////////////////////////////////"<<endl;  
  47.     cout<<"totalSampleCount: "<<totalSampleCount<<endl;  
  48.     cout<<"positiveSampleCount: "<<positiveSampleCount<<endl;  
  49.     cout<<"negativeSampleCount: "<<negativeSampleCount<<endl;  
  50.   
  51.     CvMat *sampleFeaturesMat = cvCreateMat(totalSampleCount , 1764, CV_32FC1);  
  52.     //64*128的訓練樣本,該矩陣將是totalSample*3780,64*64的訓練樣本,該矩陣將是totalSample*1764  
  53.     cvSetZero(sampleFeaturesMat);    
  54.     CvMat *sampleLabelMat = cvCreateMat(totalSampleCount, 1, CV_32FC1);//樣本標識    
  55.     cvSetZero(sampleLabelMat);    
  56.   
  57.     cout<<"************************************************************"<<endl;  
  58.     cout<<"start to training positive samples..."<<endl;  
  59.   
  60.     char positiveImgName[256];  
  61.     string path;  
  62.     for(int i=0; i<positiveSampleCount; i++)    
  63.     {    
  64.         memset(positiveImgName, '\0', 256*sizeof(char));  
  65.         sprintf(positiveImgName, "%d.jpg", i);  
  66.         int len = strlen(positiveImgName);  
  67.         string tempStr = positiveImgName;  
  68.         path = positivePath + tempStr;  
  69.   
  70.         cv::Mat img = cv::imread(path);  
  71.         if( img.data == NULL )  
  72.         {  
  73.             cout<<"positive image sample load error: "<<i<<" "<<path<<endl;  
  74.             system("pause");  
  75.             continue;  
  76.         }  
  77.   
  78.         cv::HOGDescriptor hog(cv::Size(64,64), cv::Size(16,16), cv::Size(8,8), cv::Size(8,8), 9);  
  79.         vector<float> featureVec;   
  80.   
  81.         hog.compute(img, featureVec, cv::Size(8,8));    
  82.         int featureVecSize = featureVec.size();  
  83.   
  84.         for (int j=0; j<featureVecSize; j++)    
  85.         {         
  86.             CV_MAT_ELEM( *sampleFeaturesMat, float, i, j ) = featureVec[j];   
  87.         }    
  88.         sampleLabelMat->data.fl[i] = 1;  
  89.     }  
  90.     cout<<"end of training for positive samples..."<<endl;  
  91.   
  92.     cout<<"*********************************************************"<<endl;  
  93.     cout<<"start to train negative samples..."<<endl;  
  94.   
  95.     char negativeImgName[256];  
  96.     for (int i=0; i<negativeSampleCount; i++)  
  97.     {    
  98.         memset(negativeImgName, '\0', 256*sizeof(char));  
  99.         sprintf(negativeImgName, "%d.jpg", i);  
  100.         path = negativePath + negativeImgName;  
  101.         cv::Mat img = cv::imread(path);  
  102.         if(img.data == NULL)  
  103.         {  
  104.             cout<<"negative image sample load error: "<<path<<endl;  
  105.             continue;  
  106.         }  
  107.   
  108.         cv::HOGDescriptor hog(cv::Size(64,64), cv::Size(16,16), cv::Size(8,8), cv::Size(8,8), 9);    
  109.         vector<float> featureVec;   
  110.   
  111.         hog.compute(img,featureVec,cv::Size(8,8));//計算HOG特徵  
  112.         int featureVecSize = featureVec.size();    
  113.   
  114.         for ( int j=0; j<featureVecSize; j ++)    
  115.         {    
  116.             CV_MAT_ELEM( *sampleFeaturesMat, float, i + positiveSampleCount, j ) = featureVec[ j ];  
  117.         }    
  118.   
  119.         sampleLabelMat->data.fl[ i + positiveSampleCount ] = -1;  
  120.     }    
  121.   
  122.     cout<<"end of training for negative samples..."<<endl;  
  123.     cout<<"********************************************************"<<endl;  
  124.     cout<<"start to train for SVM classifier..."<<endl;  
  125.   
  126.     CvSVMParams params;    
  127.     params.svm_type = CvSVM::C_SVC;    
  128.     params.kernel_type = CvSVM::LINEAR;    
  129.     params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 1000, FLT_EPSILON);  
  130.     params.C = 0.01;  
  131.   
  132.     Mysvm svm;  
  133.     svm.train( sampleFeaturesMat, sampleLabelMat, NULL, NULL, params ); //用SVM線性分類器訓練  
  134.     svm.save(classifierSavePath);  
  135.   
  136.     cvReleaseMat(&sampleFeaturesMat);  
  137.     cvReleaseMat(&sampleLabelMat);  
  138.   
  139.     int supportVectorSize = svm.get_support_vector_count();  
  140.     cout<<"support vector size of SVM:"<<supportVectorSize<<endl;  
  141.     cout<<"************************ end of training for SVM ******************"<<endl;  
  142.   
  143.     CvMat *sv,*alp,*re;//所有樣本特徵向量   
  144.     sv  = cvCreateMat(supportVectorSize , 1764, CV_32FC1);  
  145.     alp = cvCreateMat(1 , supportVectorSize, CV_32FC1);  
  146.     re  = cvCreateMat(1 , 1764, CV_32FC1);  
  147.     CvMat *res  = cvCreateMat(1 , 1, CV_32FC1);  
  148.   
  149.     cvSetZero(sv);  
  150.     cvSetZero(re);  
  151.     
  152.     for(int i=0; i<supportVectorSize; i++)  
  153.     {  
  154.         memcpy( (float*)(sv->data.fl+i*1764), svm.get_support_vector(i), 1764*sizeof(float));      
  155.     }  
  156.   
  157.     double* alphaArr = svm.get_alpha();  
  158.     int alphaCount = svm.get_alpha_count();  
  159.   
  160.     for(int i=0; i<supportVectorSize; i++)  
  161.     {  
  162.         alp->data.fl[i] = alphaArr[i];  
  163.     }  
  164.     cvMatMul(alp, sv, re);  
  165.   
  166.     int posCount = 0;  
  167.     for (int i=0; i<1764; i++)  
  168.     {  
  169.         re->data.fl[i] *= -1;  
  170.     }  
  171.   
  172.     FILE* fp = fopen("c:/hogSVMDetector-peopleFlow.txt","wb");  
  173.     if( NULL == fp )  
  174.     {  
  175.         return 1;  
  176.     }  
  177.     for(int i=0; i<1764; i++)  
  178.     {  
  179.         fprintf(fp,"%f \n",re->data.fl[i]);  
  180.     }  
  181.     float rho = svm.get_rho();  
  182.     fprintf(fp, "%f", rho);  
  183.     cout<<"c:/hogSVMDetector.txt 儲存完畢"<<endl;//儲存HOG能識別的分類器  
  184.     fclose(fp);  
  185.   
  186.     return 1;  
  187. }  
接著,再給出利用訓練好的分類器進行行人檢測的參考程式碼:

[cpp] view plain copy
  1. void Detect()  
  2. {  
  3.     CvCapture* cap = cvCreateFileCapture("E:\\02.avi");  
  4.     if (!cap)  
  5.     {  
  6.         cout<<"avi file load error..."<<endl;  
  7.         system("pause");  
  8.         exit(-1);  
  9.     }  
  10.   
  11.     vector<float> x;  
  12.     ifstream fileIn("c:/hogSVMDetector-peopleFlow.txt", ios::in);  
  13.     float val = 0.0f;  
  14.     while(!fileIn.eof())  
  15.     {  
  16.         fileIn>>val;  
  17.         x.push_back(val);  
  18.     }  
  19.     fileIn.close();  
  20.   
  21.     vector<cv::Rect>  found;  
  22.     cv::HOGDescriptor hog(cv::Size(64,64), cv::Size(16,16), cv::Size(8,8), cv::Size(8,8), 9);  
  23.     hog.setSVMDetector(x);  
  24.   
  25.     IplImage* img = NULL;  
  26.     cvNamedWindow("img", 0);  
  27.     while(img=cvQueryFrame(cap))  
  28.     {  
  29.         hog.detectMultiScale(img, found, 0, cv::Size(8,8), cv::Size(32,32), 1.05, 2);  
  30.         if (found.size() > 0)  
  31.         {  
  32.   
  33.             for (int i=0; i<found.size(); i++)  
  34.             {  
  35.                 CvRect tempRect = cvRect(found[i].x, found[i].y, found[i].width, found[i].height);  
  36.   
  37.                 cvRectangle(img, cvPoint(tempRect.x,tempRect.y),  
  38.                     cvPoint(tempRect.x+tempRect.width,tempRect.y+tempRect.height),CV_RGB(255,0,0), 2);  
  39.             }  
  40.         }  
  41.     }  
  42.     cvReleaseCapture(&cap);  
  43. }  

相關文章