[C++ & AdaBoost] 傻陳帶你用C++實現AdaBoost

IcyLeaves發表於2018-11-28

前言

人工智慧的演算法需要許多預備知識,但時間比較緊,所以我只會對"對最後演算法實現有幫助的資料"感興趣,試著在不完全瞭解的情況下將這次實驗完成。

資料集

這次要學習的資料集是UCI上的Polish companies bankruptcy data Data Set
資料集就是一堆資料的集合,而這次實驗就是要對這個“破產公司”資料集進行分類。
這個資料集包含了5個.arff格式的檔案,分別對應五年的資料。
我借鑑《Matlab讀取.arff檔案》的教程,用Excel檢視了裡面的內容,一個檔案大概有七千條公司的記錄,每條記錄都有65個屬性。其中前64個屬性是實數格式,個別可能出現’?’,應該是‘不清楚’的意思;最後一個屬性是0或1,代表了這條記錄的類別(我猜測是破產/不破產)。
我要做的就是,用這些記錄的屬性和類別去給演算法學習,達到能只憑屬性來預測類別的程度。

AdaBoost

機器學習實戰教程(十):提升分類器效能利器-AdaBoost
網上有大把的參考資料還算慶幸,這裡舉了個例子。當然要短時間完全理解還是天方夜譚的。

思路

AdaBoost大致是靠調整弱分類器的權重來運作的。
弱分類器可以暫時理解為對某個屬性設定一個閾值,如果達到閾值就直接歸為類1,未達到就歸為類0。
一開始Adaboost匯入作為訓練資料的記錄,根據一個弱分類器來分類所有的記錄。然後對比判斷結果與每條記錄的實際類別的差距,來糾正這個弱分類器的權重和各個記錄的權重,最後就智慧的學習出了由多個帶權重的弱分類器合併的強分類器,同時也可以通過記錄的權重來排除不必要的訓練記錄。

設計

  • 訓練模式
    • 匯入訓練資料
    • 對訓練資料定權重D={k,k,k,…,k,k},暫時先不知道k為多少合適,總之全員相等。
    • 載入訓練資料,使用弱分類器來預測分類,最後選定誤差最小的一個弱分類器開始訓練
    • 給分錯的資料和分對的資料分別調整資料的權重
    • 對比預測和實際,然後通過某種方式來設定分類器權重α。
    • 將帶權重的弱分類器加入到強分類器中,得出強分類器的當前指標。
    • 迴圈迭代進行下一次訓練。

難點

1.如何將.arff裡的內容輸入到C++程式呢?
2.k應該是多少呢?
3.如何建立初始的弱分類器
4.某種方式是什麼呢?
這些純屬程式碼之外的知識盲區,還是優先解決的好。

Solve 1

在前面我已經成功的用Excel檢視了裡面的內容,如果用記事本開啟的話,資料部分其實就是用逗號分隔的記錄們。
.arff檔案是無法識別的,但我可以單純的將資料部分作為新的輸入檔案,然後以逗號分割整個檔案,得到的集合就都是數字為元素的,再以每65個為一條記錄即可。

Solve 2

初始權重在AdaBoost的教程中其實已經規定,為w=1/N,作為第一次迭代時各個記錄的權重。可以說非常標準了。

Solve 3

雖然AdaBoost有加強分類器的能力,但是儘量還是讓弱分類器本身的誤差率最小比較好。所以在AdaBoost之前,弱分類器也是要訓練出來的。
因為弱分類器本身是很簡單的,類似
H1={1,X35>9.81,X359.8 H_{1}=\left\{\begin{matrix} 1,X_{35}>9.8\\ -1,X_{35}\leq 9.8 \end{matrix}\right. 表示若記錄的第35個屬性>9.8,那麼就取1,反之就取-1。
分類預測時,就看分類器的結果是正還是負,正為類1,負為類0。
所以誤差率就是分錯的記錄佔總記錄數的比例。
因此很容易想到,只要在某個屬性的資料範圍內,遍歷各種可能的閾值和方向(大還是小),取其中誤差最小的即可。
例如:資料分佈在-10~90之間,我可以以10為步長遍歷,也就是從-10,0, … ,80,90中挑出一個分類結果最準確的閾值即可。

Solve 4

Solve 4是Adaboost的核心計算部分,先要從原理上了解才能寫出相應的程式碼。
Adaboost演算法原理分析和例項+程式碼(簡明易懂)
1.選取一個誤差率最小的弱分類器H,對所有資料分類,計算當前弱分類器的誤差度。
誤差度ε等於分錯的樣本的權值之和。
2.計算該弱分類器的權重
α=12ln(1εε) \alpha=\frac{1}{2}ln(\frac{1-\varepsilon }{\varepsilon })
3.計算分對記錄和分錯記錄的權重
當記錄被分對時,該記錄新權重為:
Dnew=Dold2(1ε) D_{new}=\frac{D_{old}}{2(1-\varepsilon )}
當記錄被分錯時,該記錄新權重為:
Dnew=Dold2ε D_{new}=\frac{D_{old}}{2\varepsilon }

程式設計

因為想以實現預期效果為目標,所以執行效率和最終結果會有點不盡人意。

輸入輸出環境配置

資料集是以檔案的形式存在的,想到輸入輸出都走檔案渠道

#include <iostream>
#include <string>
#include <cstring>
#include <cmath>
#include <sstream>
#include <cstdlib>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <map>
using namespace std;

#define LOCAL

int main()
{
#ifdef LOCAL  
	freopen("input.txt", "r", stdin);
	freopen("output.txt", "w", stdout);
#endif  
	return 0;
}

程式碼說明
1.若Visual Studio出現報錯"freopen:This function…",那麼就在解決方案資源管理器裡,右鍵專案名>屬性>C/C++>常規>SDL檢查>“是"改成"否”。
2.這段程式碼會使程式以input.txt作為標準輸入流,然後輸出會在output.txt中顯示

輸入資料集

//訓練集記錄的結構體
struct Record
{
	bool hasValue[64];//64個屬性是否存在
	double attr[64];//64個屬性的值
	double weight;//記錄的權重
	int type;//類別,0或1
};

vector<Record> recs;//訓練集

//功能:遍歷字串str,將裡面的','替換為空格
void replaceComma(string& str)
{
	for (int i = 0; i < str.length(); i++)
	{
		if (str[i] == ',') 
			str[i] = ' ';
	}
}

//功能:輸入訓練集檔案,將訓練集的資料儲存在recs之中
void inputRecords()
{
	Record tempRec;
	string line;
	while (!cin.eof())
	{
		//一行作為一條完整記錄
		getline(cin, line);
		if (line == "") break;
		replaceComma(line);

		stringstream ss(line);//一行的字串作為輸入流
		for (int i = 0; i < 64; i++)
		{
			//排除值為'?'的屬性
			string tempStr;
			ss >> tempStr;
			if (tempStr != "?")
			{
				tempRec.hasValue[i] = true;
				//把string型別轉化為double型別
				stringstream tempSS(tempStr);
				tempSS >> tempRec.attr[i];
			}
			else
				tempRec.hasValue[i] = false;
		}
		ss >> tempRec.type;//記錄的最後一個數字代表類別

		//記錄+1
		recs.push_back(tempRec);
	}
	//給所有記錄賦初始權重
	double numOfRecords = recs.size();
	for (vector<Record>::iterator it = recs.begin(); it != recs.end(); it++)
	{
		(*it).weight = 1.0 / numOfRecords;
	}
}

程式碼說明
1.vector<Record>是宣告一個Record結構體的動態陣列,與之配套的是迭代器vector<Record>::iterator,用來指向動態陣列中的元素。
recs.size():表示動態陣列的大小,也就是訓練集的記錄個數。
recs.push_back(Record):將一個Record結構體插入到陣列的最後(自動開闢空間)。
recs.begin()->recs.end():迭代器從開始到結束,就相當於遍歷了整個recs。
(*it):迭代器所指向的Record結構體元素。
2.字串流stringstream能將一個字串轉變成同樣內容的一段輸入,這樣做的好處是能先提取單行的一整條記錄,然後以空格為分隔符,將一個個屬性值直接輸入到recs.attr中。

初始弱分類器

//弱分類器的結構體
struct WeakClassfier
{
	int attrIdx;//哪個屬性上的弱分類器
	double cap;//閾值大小
	bool isHigher;//方向
	double weight;//分類器權重
	WeakClassfier(int a, double c, bool h, double w)
	{
		attrIdx = a;
		cap = c;
		isHigher = h;
		weight = w;
	}
};

vector<WeakClassfier> weaks;//弱分類器們//功能:在屬性attrIdx(0~63)上訓練一個弱分類器
void trainWeakClassifier(int attrIdx)
{
	//在所有記錄的該屬性上找到最大值和最小值,並且統計有效值的個數
	double Max = -9999;
	double Min = 9999;
	int nums = 0;
	for (vector<Record>::iterator it = recs.begin(); it != recs.end(); it++)
	{
		if ((*it).hasValue[attrIdx] == true)
		{
			double currentValue = (*it).attr[attrIdx];
			nums++;
			if (currentValue> Max)
			{
				Max = currentValue;
			}
			else if (currentValue < Min)
			{
				Min = currentValue;
			}
		}
	}

	//確定合適的步長
	double step = (Max-Min)/10.0;

	//尋找準確率最大的閾值和方向
	double bestCap = -9999;
	double bestAcc = 0;
	bool isHigherBetter = true;
	for (double cap = Min + step; cap < Max; cap += step)
	{
		//用於計算準確率的變數
		int correctHigher = 0;
		int correctLower = 0;
		//統計分類準確率
		for (vector<Record>::iterator it = recs.begin(); it != recs.end(); it++)
		{
			Record currentRecord = (*it);
			//Higher:  a>=cap 為 類1
			//Lower:  a<cap 為 類1
			if ((*it).attr[attrIdx] >= cap)
			{
				correctHigher += currentRecord.type == 1 ? 1 : 0;
				correctLower += currentRecord.type == 0 ? 1 : 0;
			}
			else
			{
				correctHigher += currentRecord.type == 0 ? 1 : 0;
				correctLower += currentRecord.type == 1 ? 1 : 0;
			}
		}
		//是否比當前的弱分類器準確率更高
		if (correctHigher >= correctLower)
		{
			double tempAcc = correctHigher*1.0 / nums;
			if (tempAcc > bestAcc)
			{
				bestCap = cap;
				bestAcc = tempAcc;
				isHigherBetter = true;
			}
		}
		else
		{
			double tempAcc = correctLower*1.0 / nums;
			if (tempAcc > bestAcc)
			{
				bestCap = cap;
				bestAcc = tempAcc;
				isHigherBetter = false;
			}
		}
	}
	weaks.push_back(WeakClassfier(attrIdx, bestCap, isHigherBetter, 1.0));
}

程式碼說明
1.9999的取值非常隨意,請不要模仿。
2.這樣訓練出的分類器相當於如下函式:
if(H.isHigher==true),H={1,attr[attrIdx]&gt;=cap1,attr[attrIdx]&lt;cap if\left (H.isHigher==true \right ),H=\left\{\begin{matrix} 1, attr[attrIdx]&gt;=cap\\ -1,attr[attrIdx]&lt;cap \end{matrix}\right.
if(H.isHigher==false),H={1,attr[attrIdx]&lt;cap1,attr[attrIdx]&gt;=cap if\left (H.isHigher==false \right ),H=\left\{\begin{matrix} 1, attr[attrIdx]&lt;cap\\ -1,attr[attrIdx]&gt;=cap \end{matrix}\right. 3.因為這裡只平均取了10個閾值,所以對於資料分佈範圍廣又在中間密集的屬性,比如-80萬~400萬的42號屬性,這個分類器容易表現不佳。

選擇最好弱分類器

//功能:弱分類器函式H,可以將傳入的分類器和屬性值進行計算
int funH(int weakIdx, double attrVal)
{
	WeakClassfier weak = weaks[weakIdx];
	if (weak.isHigher)
	{
		if (attrVal >= weak.cap) return 1;
		else return -1;
	}
	else
	{
		if (attrVal < weak.cap) return 1;
		else return -1;
	}
}

//功能:符號判別函式
int sign(double a)
{
	if (a >= 0) return 1;
	else return 0;
}

//功能:取誤差率最小的一個分類器(編號)
int getBestWeakClassfier()
{
	//初始化
	for (vector<WeakClassfier>::iterator wit = weaks.begin(); wit != weaks.end(); wit++)
	{
		(*wit).errorRate = 0;
	}
	//遍歷所有記錄和弱分類器
	for (vector<Record>::iterator rit = recs.begin(); rit != recs.end(); rit++)
	{
		Record rec = (*rit);
		double recWeight = (*rit).weight;
		for (vector<WeakClassfier>::iterator wit = weaks.begin(); wit != weaks.end(); wit++)
		{
			int attrIdx = (*wit).attrIdx;
			
			if (!rec.hasValue[attrIdx])//若記錄rec的attrIdx屬性值是'?',就認定分類錯誤
			{
				(*wit).errorRate += recWeight;
			}
			else
			{
				if (sign(funH((*wit).attrIdx, rec.attr[attrIdx])) != rec.type)
				{
					(*wit).errorRate += recWeight;
				}
			}
		}
	}
	//找出誤差率errorRate最小的分類器返回
	int bestIdx = 0;
	double minErrorRate = 2;
	for (vector<WeakClassfier>::iterator wit = weaks.begin(); wit != weaks.end(); wit++)
	{
		if ((*wit).errorRate < minErrorRate)
		{
			minErrorRate = (*wit).errorRate;
			bestIdx = (*wit).attrIdx;
		}
	}
	return bestIdx;
}

程式碼說明
1.按照之前的演算法,誤差率可以通過把分錯樣本的權重相加得到。

調整分類器權重

//2.調整該分類器的權重
 weak.weight = 0.5*log((1 - weak.errorRate) / weak.errorRate);

程式碼說明
1.log(x)相當於ln(x),即C++中預設為自然對數。

調整記錄權重

//功能:用該分類器再次進行分類,來調整記錄的權重
void adjustRecordWeight(int attrIdx)
{
	WeakClassfier weak = weaks[attrIdx];
	for (vector<Record>::iterator it = recs.begin(); it != recs.end(); it++)
	{
		if (!(*it).hasValue[attrIdx])//若記錄rec的attrIdx屬性值是'?',就認定分類錯誤
		{
			(*it).weight /= 2 * weak.errorRate;
		}
		else
		{
			if (sign(funH(attrIdx, (*it).attr[attrIdx])) != (*it).type)//分錯
			{
				(*it).weight /= 2 * weak.errorRate;
			}
			else//分對
			{
				(*it).weight /= 2 * (1 - weak.errorRate);
			}
		}
	}
}

重新定義強分類器

vector<WeakClassfier> strong;//一個強分類器

{
//4.強分類器併入此弱分類器
strong.push_back(WeakClassfier(weak));
}

程式碼說明
1.強分類器是將各個帶權重的弱分類器相加所得,所以用一個可以儲存多個弱分類器結構體的陣列來表示強分類器。最後計算時,依次呼叫其中的弱分類器即可。

強分類器分類指標

//功能:用強分類器進行分類,得到當前強分類器各個指標
void strongClassfier()
{
	int TP = 0;//True Positive(TP):實際為1,判定結果為1
	int FP = 0;//False Positive(FP):實際為0,判定結果為1
	int TN = 0;//True Negative(TN):實際為0,判定結果為0
	int FN = 0;//False Negative(FN):實際為1,判定結果為0
	for (vector<Record>::iterator it = recs.begin(); it != recs.end(); it++)
	{
		Record rec = (*it);
		double sum = 0;
		for (vector<WeakClassfier>::iterator sit = strong.begin(); sit != strong.end(); sit++)
		{
			WeakClassfier weak = (*sit);
			if (!rec.hasValue[weak.attrIdx])//若記錄rec的attrIdx屬性值是'?',就認定分類錯誤
			{
				sum -= weak.weight;
			}
			else
			{
				sum += funH(weak.attrIdx, rec.attr[weak.attrIdx])*weak.weight;
			}
		}
		if (sign(sum) == rec.type)//判定結果和實際情況相同
		{
			if (rec.type == 0)
			{
				TN++;
			}
			else
			{
				TP++;
			}
		}
		else//判定結果和實際情況不同
		{
			if (rec.type == 0)
			{
				FP++;
			}
			else
			{
				FN++;
			}
		}
	}
	//計算並輸出指標
	cout << "------T=" << strong.size() << "------" << endl;
	//精確率
	cout << "Precision:" << TP*1.0 / (TP + FP) << endl;
	//召回率
	cout << "Recall:" << TP*1.0 / (TP + FN) << endl;
	//準確率
	cout << "Accuracy:" << (TP + TN)*1.0 / (TP + TN + FP + FN) << endl<<endl;
}

程式碼說明
1.指標介紹請見[AI Algorithm] 評判預測效能的四個指標

Adaboost總流程

//功能:用所有弱分類器weaks和訓練集recs,訓練出一個強分類器
void Adaboost()
{
	for (int i = 0; i < TIMES; i++)
	{
		//1.取誤差率最小的一個分類器
		WeakClassfier& weak = weaks[getBestWeakClassfier()];
		//2.調整該分類器的權重
		weak.weight = 0.5*log((1 - weak.errorRate) / weak.errorRate);
		//3.用該分類器再次進行分類,來調整記錄的權重
		adjustRecordWeight(weak.attrIdx);
		//4.強分類器併入此弱分類器
		strong.push_back(WeakClassfier(weak));
		//5.用強分類器進行分類,得到當前各個指標
		strongClassfier();
	}
}

程式碼說明
1.TIMES是迭代的次數,理論上迭代次數越多,最後對訓練集的分類結果越精確。

總結

1.弱分類器的選擇很重要。本程式碼粗糙的弱分類器最終導致弱分類器已經有強分類器的效果,無論如何修改訓練集,Adaboost的迭代都失去了意義。Accuracy總是在前兩次迭代就達到了穩定。
2.在三天的時間內,匆匆理解原理並完成的程式碼還是有很多缺陷的。這裡只是提供一個大概的思路,讓自己對整個流程有個大概的瞭解,希望各位不要嚴肅對待:)

原始碼

#include <iostream>
#include <string>
#include <cstring>
#include <cmath>
#include <sstream>
#include <cstdlib>
#include <cstdio>
#include <algorithm>
#include <vector>
using namespace std;

#define LOCAL
#define TIMES 50

//訓練集記錄的結構體
struct Record
{
	bool hasValue[64];//64個屬性是否存在
	double attr[64];//64個屬性的值
	double weight;//記錄的權重
	int type;//類別,0或1
};

//弱分類器的結構體
struct WeakClassfier
{
	int attrIdx;//哪個屬性上的弱分類器
	double cap;//閾值大小
	bool isHigher;//方向
	double weight;//分類器權重
	double errorRate;//誤差率
	WeakClassfier(int a, double c, bool h, double w)
	{
		attrIdx = a;
		cap = c;
		isHigher = h;
		weight = w;
	}
};

vector<Record> recs;//訓練集
vector<WeakClassfier> weaks;//弱分類器們
vector<WeakClassfier> strong;//一個強分類器

//功能:遍歷字串str,將裡面的','替換為空格
void replaceComma(string& str)
{
	for (int i = 0; i < str.length(); i++)
	{
		if (str[i] == ',')
		{
			str[i] = ' ';
		}
	}

}

//功能:輸入訓練集檔案,將訓練集的資料儲存在recs之中
void inputRecords()
{
	Record tempRec;
	string line;
	while (!cin.eof())
	{
		//一行作為一條完整記錄
		getline(cin, line);
		if (line == "") break;
		replaceComma(line);

		stringstream ss(line);//一行的字串作為輸入流
		for (int i = 0; i < 64; i++)
		{
			//排除值為'?'的屬性
			string tempStr;
			ss >> tempStr;
			if (tempStr != "?")
			{
				tempRec.hasValue[i] = true;
				//把string型別轉化為double型別
				stringstream tempSS(tempStr);
				tempSS >> tempRec.attr[i];
				if (tempRec.attr[i] > 1000)
				{
					int a = 1;
				}
			}
			else
				tempRec.hasValue[i] = false;
		}
		ss >> tempRec.type;//記錄的最後一個數字代表類別

		//記錄+1
		recs.push_back(tempRec);
	}
	//給所有記錄賦初始權重
	double numOfRecords = recs.size();
	for (vector<Record>::iterator it = recs.begin(); it != recs.end(); it++)
	{
		(*it).weight = 1.0 / numOfRecords;
	}
}

//功能:在屬性attrIdx(0~63)上訓練一個弱分類器
void trainWeakClassifier(int attrIdx)
{
	//在所有記錄的該屬性上找到最大值和最小值,並且統計有效值的個數
	double Max = -9999;
	double Min = 9999;
	int nums = 0;
	for (vector<Record>::iterator it = recs.begin(); it != recs.end(); it++)
	{
		if ((*it).hasValue[attrIdx] == true)
		{
			double currentValue = (*it).attr[attrIdx];
			nums++;
			if (currentValue > Max)
			{
				Max = currentValue;
			}
			else if (currentValue < Min)
			{
				Min = currentValue;
			}
		}
	}

	//確定合適的步長
	double step = (Max - Min) / 10.0;

	//尋找準確率最大的閾值和方向
	double bestCap = -9999;
	double bestAcc = 0;
	bool isHigherBetter = true;
	for (double cap = Min + step; cap < Max; cap += step)
	{
		//用於計算準確率的變數
		int correctHigher = 0;
		int correctLower = 0;
		//統計分類準確率
		for (vector<Record>::iterator it = recs.begin(); it != recs.end(); it++)
		{
			Record currentRecord = (*it);
			//Higher:  a>=cap 為 類1
			//Lower:  a<cap 為 類1
			if ((*it).attr[attrIdx] >= cap)
			{
				correctHigher += currentRecord.type == 1 ? 1 : 0;
				correctLower += currentRecord.type == 0 ? 1 : 0;
			}
			else
			{
				correctHigher += currentRecord.type == 0 ? 1 : 0;
				correctLower += currentRecord.type == 1 ? 1 : 0;
			}
		}
		//是否比當前的弱分類器準確率更高
		if (correctHigher >= correctLower)
		{
			double tempAcc = correctHigher*1.0 / nums;
			if (tempAcc > bestAcc)
			{
				bestCap = cap;
				bestAcc = tempAcc;
				isHigherBetter = true;
			}
		}
		else
		{
			double tempAcc = correctLower*1.0 / nums;
			if (tempAcc > bestAcc)
			{
				bestCap = cap;
				bestAcc = tempAcc;
				isHigherBetter = false;
			}
		}
	}
	weaks.push_back(WeakClassfier(attrIdx, bestCap, isHigherBetter, 1.0));
}

//功能:弱分類器函式H,可以將傳入的分類器和屬性值進行計算
int funH(int weakIdx, double attrVal)
{
	WeakClassfier weak = weaks[weakIdx];
	if (weak.isHigher)
	{
		if (attrVal >= weak.cap) return 1;
		else return -1;
	}
	else
	{
		if (attrVal < weak.cap) return 1;
		else return -1;
	}
}

//功能:符號判別函式
int sign(double a)
{
	if (a >= 0) return 1;
	else return 0;
}

//功能:取誤差率最小的一個分類器(編號)
int getBestWeakClassfier()
{
	//初始化
	for (vector<WeakClassfier>::iterator wit = weaks.begin(); wit != weaks.end(); wit++)
	{
		(*wit).errorRate = 0;
	}
	//遍歷所有記錄和弱分類器
	for (vector<Record>::iterator rit = recs.begin(); rit != recs.end(); rit++)
	{
		Record rec = (*rit);
		double recWeight = (*rit).weight;
		for (vector<WeakClassfier>::iterator wit = weaks.begin(); wit != weaks.end(); wit++)
		{
			int attrIdx = (*wit).attrIdx;

			if (!rec.hasValue[attrIdx])//若記錄rec的attrIdx屬性值是'?',就認定分類錯誤
			{
				(*wit).errorRate += recWeight;
			}
			else
			{
				if (sign(funH((*wit).attrIdx, rec.attr[attrIdx])) != rec.type)
				{
					(*wit).errorRate += recWeight;
				}
			}
		}
	}
	//找出誤差率errorRate最小的分類器返回
	int bestIdx = 0;
	double minErrorRate = 2;
	for (vector<WeakClassfier>::iterator wit = weaks.begin(); wit != weaks.end(); wit++)
	{
		if ((*wit).errorRate < minErrorRate)
		{
			minErrorRate = (*wit).errorRate;
			bestIdx = (*wit).attrIdx;
		}
	}
	return bestIdx;
}

//功能:用該弱分類器再次進行分類,來調整記錄的權重
void adjustRecordWeight(int attrIdx)
{
	WeakClassfier weak = weaks[attrIdx];
	for (vector<Record>::iterator it = recs.begin(); it != recs.end(); it++)
	{
		if (!(*it).hasValue[attrIdx])//若記錄rec的attrIdx屬性值是'?',就認定分類錯誤
		{
			(*it).weight /= 2 * weak.errorRate;
		}
		else
		{
			if (sign(funH(attrIdx, (*it).attr[attrIdx])) != (*it).type)//分錯
			{
				(*it).weight /= 2 * weak.errorRate;
			}
			else//分對
			{
				(*it).weight /= 2 * (1 - weak.errorRate);
			}
		}
	}
}

//功能:用強分類器進行分類,得到當前強分類器各個指標
void strongClassfier()
{
	int TP = 0;//True Positive(TP):實際為1,判定結果為1
	int FP = 0;//False Positive(FP):實際為0,判定結果為1
	int TN = 0;//True Negative(TN):實際為0,判定結果為0
	int FN = 0;//False Negative(FN):實際為1,判定結果為0
	for (vector<Record>::iterator it = recs.begin(); it != recs.end(); it++)
	{
		Record rec = (*it);
		double sum = 0;
		for (vector<WeakClassfier>::iterator sit = strong.begin(); sit != strong.end(); sit++)
		{
			WeakClassfier weak = (*sit);
			if (!rec.hasValue[weak.attrIdx])//若記錄rec的attrIdx屬性值是'?',就認定分類錯誤
			{
				sum -= weak.weight;
			}
			else
			{
				sum += funH(weak.attrIdx, rec.attr[weak.attrIdx])*weak.weight;
			}
		}
		if (sign(sum) == rec.type)//判定結果和實際情況相同
		{
			if (rec.type == 0)
			{
				TN++;
			}
			else
			{
				TP++;
			}
		}
		else//判定結果和實際情況不同
		{
			if (rec.type == 0)
			{
				FP++;
			}
			else
			{
				FN++;
			}
		}
	}
	//計算並輸出指標
	cout << "------T=" << strong.size() << "------" << endl;
	//精確率
	cout << "Precision:" << TP*1.0 / (TP + FP) << endl;
	//召回率
	cout << "Recall:" << TP*1.0 / (TP + FN) << endl;
	//準確率
	cout << "Accuracy:" << (TP + TN)*1.0 / (TP + TN + FP + FN) << endl<<endl;
}

//功能:用所有弱分類器weaks和訓練集recs,訓練出一個強分類器
void Adaboost()
{
	for (int i = 0; i < TIMES; i++)
	{
		//1.取誤差率最小的一個分類器
		WeakClassfier& weak = weaks[getBestWeakClassfier()];
		//2.調整該分類器的權重
		weak.weight = 0.5*log((1 - weak.errorRate) / weak.errorRate);
		//3.用該分類器再次進行分類,來調整記錄的權重
		adjustRecordWeight(weak.attrIdx);
		//4.強分類器併入此弱分類器
		strong.push_back(WeakClassfier(weak));
		//5.用強分類器進行分類,得到當前各個指標
		strongClassfier();
	}
}

int main()
{
#ifdef LOCAL  
	freopen("input.txt", "r", stdin);
	freopen("output.txt", "w", stdout);
#endif  
	inputRecords();
	for (int i = 0; i < 64; i++)
	{
		trainWeakClassifier(i);
	}
	Adaboost();
	return 0;
}

相關文章