本文作者用python程式碼示例解釋了3種處理不平衡資料集的可選方法,包括資料層面上的2種重取樣資料集方法和演算法層面上的1個整合分類器方法。
分類是機器學習最常見的問題之一,處理它的最佳方法是從分析和探索資料集開始,即從探索式資料分析(Exploratory Data Analysis, EDA)開始。除了生成儘可能多的資料見解和資訊,它還用於查詢資料集中可能存在的任何問題。在分析用於分類的資料集時,類別不平衡是常見問題之一。
什麼是資料不平衡(類別不平衡)?
資料不平衡通常反映了資料集中類別的不均勻分佈。例如,在信用卡欺詐檢測資料集中,大多數信用卡交易型別都不是欺詐,僅有很少一部分型別是欺詐交易,如此以來,非欺詐交易和欺詐交易之間的比率達到50:1。本文中,我將使用來自Kaggle的信用卡欺詐交易資料資料集,你可以從這裡下載。
這裡
https://www.kaggle.com/mlg-ulb/creditcardfraud
首先,我們先繪製類分佈圖,檢視不平衡情況。
如你所見,非欺詐交易型別資料數量遠遠超過欺詐交易型別。如果我們在不解決這個類別不平衡問題的情況下訓練了一個二分類模型,那麼這個模型完全是有偏差的,稍後我還會向你演示它影響特徵相關性的過程並解釋其中的原因。
現在,我們來介紹一些解決類別不平衡問題的技巧,你可以在這裡找到完整程式碼的notebook。
這裡
https://github.com/wmlba/innovate2019/blob/master/Credit_Card_Fraud_Detection.ipynb
一、 重取樣(過取樣和欠取樣)
這聽起來很直接。欠取樣就是一個隨機刪除一部分多數類(數量多的型別)資料的過程,這樣可以使多數類資料數量可以和少數類(數量少的型別)相匹配。一個簡單實現程式碼如下:
# Shuffle the Dataset.
shuffled_df = credit_df.sample(frac=1,random_state=4)
# Put all the fraud class in a separate dataset.
fraud_df = shuffled_df.loc[shuffled_df['Class'] == 1]
#Randomly select 492 observations from the non-fraud (majority class)
non_fraud_df=shuffled_df.loc[shuffled_df['Class']== 0].sample(n=492,random_state=42)
# Concatenate both dataframes again
normalized_df = pd.concat([fraud_df, non_fraud_df])
#plot the dataset after the undersampling
plt.figure(figsize=(8, 8))
sns.countplot('Class', data=normalized_df)
plt.title('Balanced Classes')
plt.show()
對多數類進行欠取樣
對資料集進行欠取樣之後,我重新畫出了型別分佈圖(如下),可見兩個型別的數量相等。
平衡資料集(欠取樣)
第二種重取樣技術叫過取樣,這個過程比欠取樣複雜一點。它是一個生成合成資料的過程,試圖學習少數類樣本特徵隨機地生成新的少數類樣本資料。對於典型的分類問題,有許多方法對資料集進行過取樣,最常見的技術是SMOTE(Synthetic Minority Over-sampling Technique,合成少數類過取樣技術)。簡單地說,就是在少數類資料點的特徵空間裡,根據隨機選擇的一個K最近鄰樣本隨機地合成新樣本。
來源
https://imbalanced-learn.readthedocs.io/en/stable/over_sampling.html
為了用python編碼,我呼叫了imbalanced-learn 庫(或imblearn),實現SMOTE的程式碼如下:
imbalanced-learn
https://imbalanced-learn.readthedocs.io/en/stable/index.html
from imblearn.over_sampling import SMOTE
# Resample the minority class. You can change the strategy to 'auto' if you are not sure.
sm = SMOTE(sampling_strategy='minority', random_state=7)
# Fit the model to generate the data.
oversampled_trainX,oversampled_trainY=sm.fit_sample(credit_df.drop('Class', axis=1), credit_df['Class'])
oversampled_train=pd.concat([pd.DataFrame(oversampled_trainY), pd.DataFrame(oversampled_trainX)], axis=1)
oversampled_train.columns = normalized_df.columns
還記得我說過不平衡的資料會影響特徵相關性嗎?讓我向您展示處理不平衡類問題前後的特徵相關性。
重取樣之前:
下面的程式碼用來繪製所有特徵之間的相關矩陣:
# Sample figsize in inches
fig, ax = plt.subplots(figsize=(20,10))
# Imbalanced DataFrame Correlation
corr = credit_df.corr()
sns.heatmap(corr, cmap='YlGnBu', annot_kws={'size':30}, ax=ax)
ax.set_title("Imbalanced Correlation Matrix", fontsize=14)
plt.show()
重取樣之後:
請注意,現在特徵相關性更明顯了。在解決不平衡問題之前,大多數特徵並沒有顯示出相關性,這肯定會影響模型的效能。除了會關係到整個模型的效能,特徵性相關性還會影響ML模型的效能,因此修復類別不平衡問題非常重要。
會關係到整個模型的效能
https://towardsdatascience.com/why-feature-correlation-matters-a-lot-847e8ba439c4
二、 整合方法(取樣器整合)
在機器學習中,整合方法會使用多種學習演算法和技術,以獲得比單獨使用其中一個演算法更好的效能(是的,就像一個民主投票系統)。當使用集合分類器時,bagging方法變得流行起來,它通過構建多個分類器在隨機選擇的不同資料集上進行訓練。在scikit-learn庫中,有一個名叫“BaggingClassifier”的整合分類器,然而這個分類器不能訓練不平衡資料集。當訓練不平衡資料集時,這個分類器將會偏向多數類,從而建立一個有偏差的模型。
為了解決這個問題,我們可以使用imblearn庫中的BalancedBaggingClassifier。它允許在訓練整合分類器中每個子分類器之前對每個子資料集進行重取樣。
BalancedBaggingClassifier
https://mp.weixin.qq.com/cgi-bin/appmsg?t=media/appmsg_edit&action=edit&type=10&isMul=1&isNew=1&lang=zh_CN&token=89565677#imblearn.ensemble.BalancedBaggingClassifier
因此,BalancedBaggingClassifier除了需要和Scikit Learn BaggingClassifier相同的引數以外,還需要2個引數sampling_strategy和replacement來控制隨機取樣器的執行。下面是具體的執行程式碼:
from imblearn.ensemble import BalancedBaggingClassifier
from sklearn.tree import DecisionTreeClassifier
#Create an object of the classifier.
bbc = BalancedBaggingClassifier(base_estimator=DecisionTreeClassifier(),
sampling_strategy='auto',
replacement=False,
random_state=0)
y_train = credit_df['Class']
X_train = credit_df.drop(['Class'], axis=1, inplace=False)
#Train the classifier.
bbc.fit(X_train, y_train)
preds = bbc.predict(X_train)
使用集合取樣器訓練不平衡資料集
這樣,您就可以訓練一個分類器來處理類別不平衡問題,而不必在訓練前手動進行欠取樣或過取樣。
總之,每個人都應該知道,建立在不平衡資料集上的ML模型會難以準確預測稀有點和少數點,整體效能會受到限制。因此,識別和解決這些點的不平衡對生成模型的質量和效能是至關重要的。
原文標題:
How to fix an Unbalanced Dataset
原文連結:
https://www.kdnuggets.com/2019/05/fix-unbalanced-dataset.html