GBDT+LR原理和實戰
GBDT+LR簡單實戰
1. 背景
在早期的廣告點選預測(CTR)建模中,工業界用的最頻繁的模型是邏輯迴歸LR。由於邏輯迴歸本質上是線性模型(線性擬合後經過sigmoid變換),是十分容易並行化的,處理工業界當中的億級資料速度非常快,因此受到青睞。但是另一方面,邏輯迴歸的簡單導致了其學習能力有限,需要大量的特徵工程來增強其預測能力。不過傳統特徵工程主要集中在尋找有區分度的特徵或者特徵組合,這種耗時費力的工作卻不一定能帶來穩定的提升,因此GBDT這種自動學習特徵的樹模型便被應用了過來。
2014年Facebook的文章介紹了通過GBDT+LR融合建模的思路,隨後kaggle比賽中越來越多coder嘗試了這種模型,並取得了很好的效果,於是GBDT+LR被工業界關注並應用。
2. GBDT+LR的結構
GBDT+LR總的來講分為三步:
- 原始資料訓練GBDT模型
- GBDT模型將原始資料(訓練集、測試集)轉化為0-1稀疏向量
- 用2中得到的訓練集稀疏向量訓練LR
- LR對2中得到測試集稀疏向量預測
結合Facebook文章中的圖來闡述:
圖中共有兩棵樹,x為一條輸入樣本,遍歷兩棵樹後,x樣本分別落到兩顆樹的葉子節點上,每個葉子節點對應LR一維特徵,那麼通過遍歷樹,就得到了該樣本對應的所有LR特徵。構造的新特徵向量是取值0/1的。舉例來說:上圖有兩棵樹,左樹有三個葉子節點,右樹有兩個葉子節點,最終的特徵即為五維的向量。對於輸入x,假設他落在左樹第一個節點,編碼[1,0,0],落在右樹第二個節點則編碼[0,1],所以整體的編碼為[1,0,0,0,1],這類編碼作為特徵,輸入到LR中進行分類。
那麼為什麼採用GBDT呢?
- 首先GBDT是整合模型,相較於單顆樹對於原始資料特徵的表達能力更強,也更能發現有效的特徵和特徵組合;
- 其次相較於隨機森林RF,GBDT前邊的樹學習到的特徵是原始資料中區分度較大的特徵,後邊的樹中的特徵是經過前邊的樹,殘差仍然較大的特徵,因此採用GBDT學到的特徵是先選擇區分度較大的特徵,再選擇殘差較大的特徵,思路更加清晰合理。
3. 程式碼實戰
-
資料集
# implementation of GBDT+LR # import relative library import numpy as np import lightgbm as lgb from sklearn.datasets import load_breast_cancer from sklearn.model_selection import train_test_split from sklearn.linear_model import LogisticRegression from sklearn.metrics import confusion_matrix, accuracy_score, f1_score # step1. preprocess dataset X, y = load_breast_cancer(return_X_y=True) X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, random_state=1) train_data = lgb.Dataset(X_train, y_train)
-
訓練GBDT模型
這裡訓練GBDT模型可以通過sklearn中GBDT訓練,也可以採用LightGBM,將boosting_type設定為GBDT來訓練。LightGBM可以設定單顆樹的葉子節點數
# step2. adapt lightgbm to train gbdt model param_dict = { 'objective': 'binary', 'task': 'train', 'boosting_type': 'gbdt', 'metric': {'binary_logloss'}, 'learning_rate': 0.01, 'max_depth': 10, # 控制單顆樹的深度 'num_trees': 50, # 數的棵樹 'num_leaves': 30, # 每棵樹的葉子節點數 'verbose': 0, 'subsample': 0.8, } gbdt_model = lgb.train(params=param_dict, train_set=train_data)
-
將原始資料轉化為0-1稀疏向量
# step3. transform train/test data to 0-1 vector like one-hot num_trees = 50 num_leaves = 30 y_pred = gbdt_model.predict(X_train, pred_leaf=True) # y_pred 每一行中的數值 代表原始資料輸入gbdt模型後落在每顆數葉子節點位置的索引 # y_pred.shape # (398, 50) # y_pred[0] # array([ 0, 5, 9, 13, 9, 14, 9, 16, 16, 14, 14, 16, 15, 9, 15, 9, 15, # 13, 15, 9, 15, 11, 15, 9, 6, 9, 17, 16, 16, 9, 15, 10, 15, 10, # 15, 15, 16, 15, 15, 11, 15, 16, 14, 16, 15, 9, 16, 15, 14, 15], # dtype=int32) # 建立0矩陣,維度數為 單顆樹葉子節點數*樹的棵樹 transform_train_data = np.zeros([len(X_train), num_leaves*num_trees], dtype=np.int64) for i in range(len(X_train)): index = np.arange(num_trees) * num_leaves + y_pred[i] # 每棵樹中值為1的節點在整體向量中的索引 transform_train_data[i][index] += 1 # test data is similar y_test_pred = gbdt_model.predict(X_test, pred_leaf=True) transform_test_data = np.zeros([len(X_test), num_leaves*num_trees], dtype=np.int64) for i in range(len(X_test)): index = np.arange(num_trees) * num_leaves + y_pred[i] transform_test_data[i][index] += 1
-
訓練LR模型
# step4. train LR model lr_model = LogisticRegression(penalty='l2', C=0.05, solver='liblinear') lr_model.fit(transform_train_data, y_train)
-
預測和評估
# step5. predict and evaluate pred_res = lr_model.predict(transform_test_data) conf_res = confusion_matrix(pred_res, y_test) # 混淆矩陣 acc = accuracy_score(pred_res, y_test) # accuracy f1_score = f1_score(y_true=y_test, y_pred=pred_res) # f1值
相關文章
- Kafka 原理和實戰Kafka
- go-micro整合RabbitMQ實戰和原理GoMQ
- Delta Lake 資料湖原理和實戰
- LLM大模型:deepspeed實戰和原理解析大模型
- Redis Sentinel-深入淺出原理和實戰Redis
- RocketMQ實戰疑問和原理解答(實時更新)MQ
- Keepalived 原理與實戰
- nohup 原理及實戰
- bitMap原理及實戰
- Maven實戰與原理分析(二):maven實戰Maven
- 【推薦系統】GBDT+LR
- 【深入淺出Spring原理及實戰】「原始碼原理實戰」從底層角度去分析研究PropertySourcesPlaceholderConfigurer的原理及實戰注入機制Spring原始碼
- JMeter實戰(二) 執行原理JMeter
- 【linux】helloword原理分析及實戰Linux
- 直播原理與web直播實戰Web
- python 爬蟲實戰的原理Python爬蟲
- OpenTelemetry 實戰:gRPC 監控的實現原理RPC
- Istio 流量治理功能原理與實戰
- Spring Boot自動配置原理、實戰Spring Boot
- RecyclerView 事件分發原理實戰分析View事件
- Spring Security 入門原理及實戰Spring
- JSON WEB TOKEN 從原理到實戰JSONWeb
- Metro拆包工作原理與實戰
- sqlx操作MySQL實戰及其ORM原理MySqlORM
- 過載保護原理與實戰
- 為什麼是InfluxDB | 寫在《InfluxDB原理和實戰》出版之際UX
- PaddlePaddle實戰 | 情感分析演算法從原理到實戰全解演算法
- 從原理到實戰,徹底搞懂NginxNginx
- Hyperledger Fabric原理詳解與實戰1
- Hyperledger Fabric原理詳解與實戰4
- Android 除錯實戰與原理詳解Android除錯
- apache common pool2原理與實戰Apache
- Flutter完整開發實戰詳解(十三、全面深入觸控和滑動原理)Flutter
- 基於gRPC的註冊發現與負載均衡的原理和實戰RPC負載
- 阿里大牛實戰歸納——Kafka架構原理阿里Kafka架構
- 從原理到實戰,詳解XXE攻擊
- 官方工具|MySQL Router 高可用原理與實戰MySql
- Selenium原理、安裝與自動打卡實戰