多工學習模型之DBMTL介紹與實現
本文介紹的是阿里在2019年發表的多工學習演算法。該模型顯示地建模目標間的貝葉斯網路因果關係,整合建模了特徵和多個目標之間的複雜因果關係網路,省去了一般MTL模型中較強的獨立假設。由於不對目標分佈做任何特定假設,使得它能夠比較自然地推廣到任意形式的目標上。
1.多工學習背景
目前工業中使用的推薦演算法已不只侷限在單目標(ctr)任務上,還需要關注後續的轉化鏈路,如是否評論、收藏、加購、購買、觀看時長等目標。
常見的多目標最佳化模型是從每個最佳化目標單獨的模型網路出發,透過讓這些網路在底層共享引數,實現各目標相關模型的適當程度的獨立性和相關性。這類的模型框架可以用上圖的結構來概括。不論底層如何共享引數,這些網路在最後幾層都要伸出一些獨立分支來預測各個目標的最終值。此類網路的機率模型可以用下述公式描述:
其中l,m 為目標,x為樣本特徵,H為模型。這裡做了各目標獨立的假設。
2.DBMTL介紹
DBMTL(Deep Bayesian Multi-Target Learning)的一個出發點就是解決上述問題。事實上套用簡單的貝葉斯公式,機率模型可以寫成:
如下圖所示,DBMTL與傳統MTL結構(認為各目標獨立)最主要差別在於構建了target node之間的貝葉斯網路,顯式建模了目標間可能存在的因果關係。因為在實際業務中,使用者的很多行為往往存在明顯的序列先後依賴關係,例如在資訊流場景,使用者要先點進圖文詳情頁,才會進行後續的瀏覽/評論/轉發/收藏 等操作。DBMTL在模型結構中體現了這些關係,因此,往往能學到更好的結果。
下圖是DBMTL模型的具體實現。網路包含輸入層、共享embedding層、共享層,區別層和貝葉斯層。
- 共享embedding層是一個共享的lookup table,為各個target訓練所共享。
- 共享層和分離層是一般的multilayer perceptron (MLP),分別建模各目標的共享/區別表示。
- Bayesian層是DBMTL中最重要的部分。它實現瞭如下的機率模型:
其對應的log-likelihood損失函式為:
實際應用中,對不同目標調權仍有著較大的現實作用。當對目標賦予不同權重時,相當於把損失函式重新表達為:
在網路的貝葉斯層中,函式f1, f2, f3 被實現為全連線的MLP,以學習目標間的隱含因果關係。他們把函式輸入變數的embedding級聯作為輸入,並輸入一個表示函式輸出變數的embedding。每一個目標的embedding最後再經過一層MLP以輸出最終目標的機率。
3.程式碼實現
基於 EasyRec推薦演算法框架,我們實現了DBMTL演算法,具體實現可移步至github: EasyRec-DBMTL。
EasyRec介紹:EasyRec是阿里雲端計算平臺機器學習PAI團隊開源的大規模分散式推薦演算法框架,EasyRec 正如其名字一樣,簡單易用,整合了諸多優秀前沿的推薦系統論文思想,並且在實際工業落地中取得優良效果的特徵工程方法,整合訓練、評估、部署,與阿里雲產品無縫銜接,可以藉助 EasyRec 在短時間內搭建起一套前沿的推薦系統。作為阿里雲的拳頭產品,現已穩定服務於數百個企業客戶。
模型前饋網路
def build_predict_graph(self):
"""Forward function.
Returns:
self._prediction_dict: Prediction result of two tasks.
"""
# 此處從共享embedding層後的tensor(self._features)開始,省略其生成邏輯
# shared layer
if self._model_config.HasField('bottom_dnn'):
bottom_dnn = dnn.DNN(
self._model_config.bottom_dnn,
self._l2_reg,
name='bottom_dnn',
is_training=self._is_training)
bottom_fea = bottom_dnn(self._features)
else:
bottom_fea = self._features
# MMOE block
if self._model_config.HasField('expert_dnn'):
mmoe_layer = mmoe.MMOE(
self._model_config.expert_dnn,
l2_reg=self._l2_reg,
num_task=self._task_num,
num_expert=self._model_config.num_expert)
task_input_list = mmoe_layer(bottom_fea)
else:
task_input_list = [bottom_fea] * self._task_num
tower_features = {}
# specific layer
for i, task_tower_cfg in enumerate(self._model_config.task_towers):
tower_name = task_tower_cfg.tower_name
if task_tower_cfg.HasField('dnn'):
tower_dnn = dnn.DNN(
task_tower_cfg.dnn,
self._l2_reg,
name=tower_name + '/dnn',
is_training=self._is_training)
tower_fea = tower_dnn(task_input_list[i])
tower_features[tower_name] = tower_fea
else:
tower_features[tower_name] = task_input_list[i]
tower_outputs = {}
relation_features = {}
# bayesian network
for task_tower_cfg in self._model_config.task_towers:
tower_name = task_tower_cfg.tower_name
relation_dnn = dnn.DNN(
task_tower_cfg.relation_dnn,
self._l2_reg,
name=tower_name + '/relation_dnn',
is_training=self._is_training)
tower_inputs = [tower_features[tower_name]]
for relation_tower_name in task_tower_cfg.relation_tower_names:
tower_inputs.append(relation_features[relation_tower_name])
relation_input = tf.concat(
tower_inputs, axis=-1, name=tower_name + '/relation_input')
relation_fea = relation_dnn(relation_input)
relation_features[tower_name] = relation_fea
output_logits = tf.layers.dense(
relation_fea,
task_tower_cfg.num_class,
kernel_regularizer=self._l2_reg,
name=tower_name + '/output')
tower_outputs[tower_name] = output_logits
self._add_to_prediction_dict(tower_outputs)
Loss計算
def build(loss_type, label, pred, loss_weight=1.0, num_class=1, **kwargs):
if loss_type == LossType.CLASSIFICATION:
if num_class == 1:
return tf.losses.sigmoid_cross_entropy(
label, logits=pred, weights=loss_weight, **kwargs)
else:
return tf.losses.sparse_softmax_cross_entropy(
labels=label, logits=pred, weights=loss_weight, **kwargs)
elif loss_type == LossType.CROSS_ENTROPY_LOSS:
return tf.losses.log_loss(label, pred, weights=loss_weight, **kwargs)
elif loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
logging.info('%s is used' % LossType.Name(loss_type))
return tf.losses.mean_squared_error(
labels=label, predictions=pred, weights=loss_weight, **kwargs)
elif loss_type == LossType.PAIR_WISE_LOSS:
return pairwise_loss(pred, label)
else:
raise ValueError('unsupported loss type: %s' % LossType.Name(loss_type))
def _build_loss_impl(self,
loss_type,
label_name,
loss_weight=1.0,
num_class=1,
suffix=''):
loss_dict = {}
if loss_type == LossType.CLASSIFICATION:
loss_name = 'cross_entropy_loss' + suffix
pred = self._prediction_dict['logits' + suffix]
elif loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
loss_name = 'l2_loss' + suffix
pred = self._prediction_dict['y' + suffix]
else:
raise ValueError('invalid loss type: %s' % LossType.Name(loss_type))
loss_dict[loss_name] = build(loss_type,
self._labels[label_name],
pred,
loss_weight, num_class)
return loss_dict
def build_loss_graph(self):
"""Build loss graph for multi task model."""
for task_tower_cfg in self._task_towers:
tower_name = task_tower_cfg.tower_name
loss_weight = task_tower_cfg.weight * self._sample_weight
if hasattr(task_tower_cfg, 'task_space_indicator_label') and \
task_tower_cfg.HasField('task_space_indicator_label'):
in_task_space = tf.to_float(
self._labels[task_tower_cfg.task_space_indicator_label] > 0)
loss_weight = loss_weight * (
task_tower_cfg.in_task_space_weight * in_task_space +
task_tower_cfg.out_task_space_weight * (1 - in_task_space))
# EasyRec框架會自動對self._loss_dict中的loss進行加和。
self._loss_dict.update(
self._build_loss_impl(
task_tower_cfg.loss_type,
label_name=self._label_name_dict[tower_name],
loss_weight=loss_weight,
num_class=task_tower_cfg.num_class,
suffix='_%s' % tower_name))
return self._loss_dict
4.應用
由於其卓越的演算法效果,DBMTL在 PAI上被大量使用。
以某直播推薦業務為例,該場景有is_click, is_view, view_costtime, is_on_mic, on_mic_duration多個目標,其中is_click, is_view, is_on_mic為二分類任務,view_costtime, on_mic_duration為預測時長的迴歸任務。使用者行為的依賴關係為:
- is_click=> is_view
- is_click+is_view=> view_costtime
- is_click=> is_on_mic
- is_click+is_on_mic => on_mic_duration
因此配置如下:
dbmtl {
bottom_dnn {
hidden_units: [512, 256]
}
task_towers {
tower_name: "is_click"
label_name: "is_click"
loss_type: CLASSIFICATION
metrics_set: {
auc {}
}
dnn {
hidden_units: [128, 96, 64]
}
relation_dnn {
hidden_units: [32]
}
weight: 1.0
}
task_towers {
tower_name: "is_view"
label_name: "is_view"
loss_type: CLASSIFICATION
metrics_set: {
auc {}
}
dnn {
hidden_units: [128, 96, 64]
}
relation_tower_names: ["is_click"]
relation_dnn {
hidden_units: [32]
}
weight: 1.0
}
task_towers {
tower_name: "view_costtime"
label_name: "view_costtime"
loss_type: L2_LOSS
metrics_set: {
mean_squared_error {}
}
dnn {
hidden_units: [128, 96, 64]
}
relation_tower_names: ["is_click", "is_view"]
relation_dnn {
hidden_units: [32]
}
weight: 1.0
}
task_towers {
tower_name: "is_on_mic"
label_name: "is_on_mic"
loss_type: CLASSIFICATION
metrics_set: {
auc {}
}
dnn {
hidden_units: [128, 96, 64]
}
relation_tower_names: ["is_click"]
relation_dnn {
hidden_units: [32]
}
weight: 1.0
}
task_towers {
tower_name: "on_mic_duration"
label_name: "on_mic_duration"
loss_type: L2_LOSS
metrics_set: {
mean_squared_error {}
}
dnn {
hidden_units: [128, 96, 64]
}
relation_tower_names: ["is_click", "is_on_mic"]
relation_dnn {
hidden_units: [32]
}
weight: 1.0
}
l2_regularization: 1e-6
}
embedding_regularization: 5e-6
}
值得一提的是,DBMTL模型上線後,相比GBDT+FM(圍觀單目標)線上圍觀率提升18%,上麥率提升14%。
5.參考文獻
注:本文圖片及公式均引用自論文: DBMTL論文
來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/70004426/viewspace-2869168/,如需轉載,請註明出處,否則將追究法律責任。
相關文章
- 多工學習模型之ESMM介紹與實現模型
- 深度學習之遷移學習介紹與使用深度學習遷移學習
- Android學習之 Scroller的介紹與使用Android
- Sunny.Xia的深度學習(四)MMOE多工學習模型實戰演練深度學習模型
- 深度學習之TensorFlow的介紹與安裝深度學習
- 推薦模型NeuralCF:原理介紹與TensorFlow2.0實現模型
- 1.Django介紹與學習Django
- mysql學習之-三種安裝方式與版本介紹MySql
- Dubbo原始碼學習之-SPI介紹原始碼
- 推薦模型DeepCrossing: 原理介紹與TensorFlow2.0實現模型ROS
- 當AI實現多工學習,它究竟能做什麼?AI
- 聯邦學習:多工思想與聚類聯邦學習聯邦學習聚類
- 深度學習之tensorflow2實戰:多輸出模型深度學習模型
- 表示學習介紹
- 架構學習-多工架構
- docker 學習筆記之實戰 lnmp 環境搭建系列 (1) —— docker 介紹與安裝Docker筆記LNMP
- 並行多工學習論文閱讀(一):多工學習速覽並行
- PHP 多程式與訊號中斷實現多工常駐記憶體管理【Master/Worker 模型】PHP記憶體AST模型
- ClickHouse學習系列之七【系統命令介紹】
- Redis學習 RDB和AOF兩種持久化介紹以及實現Redis持久化
- 遊戲引擎學習筆記:介紹、架構、設計及實現遊戲引擎筆記架構
- 深度學習與CV教程(8) | 常見深度學習框架介紹深度學習框架
- dapr學習:dapr介紹
- Presto學習-presto介紹REST
- 學習內容介紹
- python網路-多工實現之協程Python
- UI自動化學習筆記- PO模型介紹和使用UI筆記模型
- Java 多執行緒學習筆記(四)yield 介紹Java執行緒筆記
- JavaScript高階程式設計學習(一)之介紹JavaScript程式設計
- python greenlet背景介紹與實現機制Python
- OutputStreamWriter介紹&程式碼實現和InputStreamReader介紹&程式碼實現
- Redis(設計與實現):---釋出與訂閱介紹Redis
- Graphql學習(一)-GraphQL介紹
- 整合學習入門介紹
- 元學習簡單介紹
- 學習python前言介紹Python
- PHP+jQuery+Ajax實現多圖片上傳介紹PHPjQuery
- C#實現多語言介面程式的方法介紹C#