關於sklearn下class_weight引數的一點原始碼閱讀與測試
版權宣告:歡迎轉載,請註明原出處 https://blog.csdn.net/go_og/article/details/81281387
一直沒有很在意過sklearn的class_weight的這個引數的具體作用細節,只大致瞭解是是用於處理樣本不均衡。後來在簡書上閱讀svm鬆弛變數的一些推導的時候,看到樣本不均衡的帶來的問題時候,想更深層次的看一下class_weight的具體作用方式,
svm鬆弛變數的簡書連結:https://www.jianshu.com/p/8a499171baa9
該文中的樣本不均衡的描述:
“樣本偏斜是指資料集中正負類樣本數量不均,比如正類樣本有10000個,負類樣本只有100個,這就可能使得超平面被“推向”負類(因為負類數量少,分佈得不夠廣),影響結果的準確性。”
隨後翻開sklearn LR的原始碼:
我們以分類作為說明重點
在輸入引數class_weight=‘balanced’的時候:
-
# compute the class weights for the entire dataset y
-
if class_weight == "balanced":
-
class_weight = compute_class_weight(class_weight,
-
np.arange(len(self.classes_)),
-
y)
-
class_weight = dict(enumerate(class_weight))
進一步閱讀 compute_class_weight這個函式:
-
elif class_weight == 'balanced':
-
# Find the weight of each class as present in y.
-
le = LabelEncoder()
-
y_ind = le.fit_transform(y)
-
if not all(np.in1d(classes, le.classes_)):
-
raise ValueError("classes should have valid labels that are in y")
-
recip_freq = len(y) / (len(le.classes_) *
-
np.bincount(y_ind).astype(np.float64))
-
weight = recip_freq[le.transform(classes)]
compute_class_weight這個函式的作用是對於輸入的樣本,平衡類別之間的權重,下面寫段測試程式碼測試這個函式:
-
# coding:utf-8
-
from sklearn.utils.class_weight import compute_class_weight
-
class_weight = 'balanced'
-
label = [0] * 9 + [1]*1 + [2, 2]
-
print(label) # [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2]
-
classes=[0, 1, 2]
-
weight = compute_class_weight(class_weight, classes, label)
-
print(weight) #[ 0.44444444 4. 2. ]
-
print(.44444444 * 9) # 3.99999996
-
print(4 * 1) # 4
-
print(2 * 2) # 4
如上圖所示,可以看到這個函式把樣本的平衡後的權重乘積為4,每個類別均如此。
關於class_weight與sample_weight在損失函式上的具體計算方式:
-
sample_weight *= class_weight_[le.fit_transform(y_bin)] # sample_weight 與 class_weight相乘
-
# Logistic loss is the negative of the log of the logistic function.
-
out = -np.sum(sample_weight * log_logistic(yz)) + .5 * alpha * np.dot(w, w)
上述可以看出對於每個樣本,計算的損失函式乘上對應的sample_weight來計算最終的損失。這樣計算而來的損失函式不會因為樣本不平衡而被“推向”樣本量偏少的類別中。
class_weight以及sample_weight並沒有進行不平衡資料的處理,比如,上下采樣。詳細參見SMOTE EasyEnsemble等。
--------------------- 本文來自 摸摸小松鼠寶寶 的CSDN 部落格 ,全文地址請點選:https://blog.csdn.net/go_og/article/details/81281387?utm_source=copy
相關文章
- 如何模擬一個XMLHttpRequest請求用於單元測試——nise原始碼閱讀與分析XMLHTTP原始碼
- 【原始碼閱讀】Glide原始碼閱讀之with方法(一)原始碼IDE
- 測試用例驅動閱讀Express原始碼Express原始碼
- 【原始碼閱讀】AndPermission原始碼閱讀原始碼
- gin 原始碼閱讀(1) - gin 與 net/http 的關係原始碼HTTP
- Kingfisher原始碼閱讀(一)原始碼
- 關於如何看原始碼的一點思考原始碼
- 【測試】Android Studio 相關下載及引數Android
- 【原始碼閱讀】Glide原始碼閱讀之into方法(三)原始碼IDE
- Appdash原始碼閱讀——Annotations與EventAPP原始碼
- Appdash原始碼閱讀——Recorder與CollectorAPP原始碼
- 閱讀 Composer 原始碼的一個分享原始碼
- 關於JDK原始碼:我想聊聊如何更高效地閱讀JDK原始碼
- 【原始碼閱讀】Glide原始碼閱讀之load方法(二)原始碼IDE
- 逐行閱讀redux原始碼(一) createStoreRedux原始碼
- 【詳解】ThreadPoolExecutor原始碼閱讀(一)thread原始碼
- opentracing-go原始碼閱讀一Go原始碼
- Dive Into Code: VSCode 原始碼閱讀(一)VSCode原始碼
- 如何閱讀一份原始碼?原始碼
- 分享一些閱讀Java相關框架原始碼的經驗Java框架原始碼
- thinkphp5.1原始碼閱讀與學習(一、路由解析)PHP原始碼路由
- ReactorKit原始碼閱讀React原始碼
- Vollery原始碼閱讀(—)原始碼
- NGINX原始碼閱讀Nginx原始碼
- ThreadLocal原始碼閱讀thread原始碼
- 原始碼閱讀-HashMap原始碼HashMap
- Runtime 原始碼閱讀原始碼
- RunLoop 原始碼閱讀OOP原始碼
- AmplifyImpostors原始碼閱讀原始碼
- stack原始碼閱讀原始碼
- CountDownLatch原始碼閱讀CountDownLatch原始碼
- fuzz原始碼閱讀原始碼
- HashMap 原始碼閱讀HashMap原始碼
- delta原始碼閱讀原始碼
- AQS原始碼閱讀AQS原始碼
- Mux 原始碼閱讀UX原始碼
- ConcurrentHashMap原始碼閱讀HashMap原始碼
- HashMap原始碼閱讀HashMap原始碼