多工學習模型之DBMTL介紹與實現

阿里雲大資料AI技術發表於2022-03-10

本文介紹的是阿里在2019年發表的多工學習演算法。該模型顯示地建模目標間的貝葉斯網路因果關係,整合建模了特徵和多個目標之間的複雜因果關係網路,省去了一般MTL模型中較強的獨立假設。由於不對目標分佈做任何特定假設,使得它能夠比較自然地推廣到任意形式的目標上。

1.多工學習背景

目前工業中使用的推薦演算法已不只侷限在單目標(ctr)任務上,還需要關注後續的轉化鏈路,如是否評論、收藏、加購、購買、觀看時長等目標。


多工學習模型之DBMTL介紹與實現

常見的多目標最佳化模型是從每個最佳化目標單獨的模型網路出發,透過讓這些網路在底層共享引數,實現各目標相關模型的適當程度的獨立性和相關性。這類的模型框架可以用上圖的結構來概括。不論底層如何共享引數,這些網路在最後幾層都要伸出一些獨立分支來預測各個目標的最終值。此類網路的機率模型可以用下述公式描述:

多工學習模型之DBMTL介紹與實現

其中l,m 為目標,x為樣本特徵,H為模型。這裡做了各目標獨立的假設。

2.DBMTL介紹

DBMTL(Deep Bayesian Multi-Target Learning)的一個出發點就是解決上述問題。事實上套用簡單的貝葉斯公式,機率模型可以寫成:

多工學習模型之DBMTL介紹與實現

如下圖所示,DBMTL與傳統MTL結構(認為各目標獨立)最主要差別在於構建了target node之間的貝葉斯網路,顯式建模了目標間可能存在的因果關係。因為在實際業務中,使用者的很多行為往往存在明顯的序列先後依賴關係,例如在資訊流場景,使用者要先點進圖文詳情頁,才會進行後續的瀏覽/評論/轉發/收藏 等操作。DBMTL在模型結構中體現了這些關係,因此,往往能學到更好的結果。

多工學習模型之DBMTL介紹與實現

下圖是DBMTL模型的具體實現。網路包含輸入層、共享embedding層、共享層,區別層和貝葉斯層。


多工學習模型之DBMTL介紹與實現

  • 共享embedding層是一個共享的lookup table,為各個target訓練所共享。
  • 共享層和分離層是一般的multilayer perceptron (MLP),分別建模各目標的共享/區別表示。
  • Bayesian層是DBMTL中最重要的部分。它實現瞭如下的機率模型:


多工學習模型之DBMTL介紹與實現


其對應的log-likelihood損失函式為:

多工學習模型之DBMTL介紹與實現


實際應用中,對不同目標調權仍有著較大的現實作用。當對目標賦予不同權重時,相當於把損失函式重新表達為:

多工學習模型之DBMTL介紹與實現

在網路的貝葉斯層中,函式f1, f2, f3 被實現為全連線的MLP,以學習目標間的隱含因果關係。他們把函式輸入變數的embedding級聯作為輸入,並輸入一個表示函式輸出變數的embedding。每一個目標的embedding最後再經過一層MLP以輸出最終目標的機率。

3.程式碼實現

基於 EasyRec推薦演算法框架,我們實現了DBMTL演算法,具體實現可移步至github: EasyRec-DBMTL

EasyRec介紹:EasyRec是阿里雲端計算平臺機器學習PAI團隊開源的大規模分散式推薦演算法框架,EasyRec 正如其名字一樣,簡單易用,整合了諸多優秀前沿的推薦系統論文思想,並且在實際工業落地中取得優良效果的特徵工程方法,整合訓練、評估、部署,與阿里雲產品無縫銜接,可以藉助 EasyRec 在短時間內搭建起一套前沿的推薦系統。作為阿里雲的拳頭產品,現已穩定服務於數百個企業客戶。

模型前饋網路

def build_predict_graph(self):
    """Forward function.
    Returns:
      self._prediction_dict: Prediction result of two tasks.
    """
    # 此處從共享embedding層後的tensor(self._features)開始,省略其生成邏輯
    
    # shared layer
    if self._model_config.HasField('bottom_dnn'):
        bottom_dnn = dnn.DNN(
            self._model_config.bottom_dnn,
            self._l2_reg,
            name='bottom_dnn',
            is_training=self._is_training)
        bottom_fea = bottom_dnn(self._features)
    else:
        bottom_fea = self._features
    # MMOE block
    if self._model_config.HasField('expert_dnn'):
        mmoe_layer = mmoe.MMOE(
            self._model_config.expert_dnn,
            l2_reg=self._l2_reg,
            num_task=self._task_num,
            num_expert=self._model_config.num_expert)
        task_input_list = mmoe_layer(bottom_fea)
    else:
        task_input_list = [bottom_fea] * self._task_num
    tower_features = {}
    # specific layer
    for i, task_tower_cfg in enumerate(self._model_config.task_towers):
        tower_name = task_tower_cfg.tower_name
        if task_tower_cfg.HasField('dnn'):
            tower_dnn = dnn.DNN(
                task_tower_cfg.dnn,
                self._l2_reg,
                name=tower_name + '/dnn',
                is_training=self._is_training)
            tower_fea = tower_dnn(task_input_list[i])
            tower_features[tower_name] = tower_fea
        else:
            tower_features[tower_name] = task_input_list[i]
    tower_outputs = {}
    relation_features = {}
    # bayesian network
    for task_tower_cfg in self._model_config.task_towers:
        tower_name = task_tower_cfg.tower_name
        relation_dnn = dnn.DNN(
            task_tower_cfg.relation_dnn,
            self._l2_reg,
            name=tower_name + '/relation_dnn',
            is_training=self._is_training)
        tower_inputs = [tower_features[tower_name]]
        for relation_tower_name in task_tower_cfg.relation_tower_names:
            tower_inputs.append(relation_features[relation_tower_name])
        relation_input = tf.concat(
            tower_inputs, axis=-1, name=tower_name + '/relation_input')
        relation_fea = relation_dnn(relation_input)
        relation_features[tower_name] = relation_fea
        output_logits = tf.layers.dense(
            relation_fea,
            task_tower_cfg.num_class,
            kernel_regularizer=self._l2_reg,
            name=tower_name + '/output')
        tower_outputs[tower_name] = output_logits
        self._add_to_prediction_dict(tower_outputs)

Loss計算

def build(loss_type, label, pred, loss_weight=1.0, num_class=1, **kwargs):
    if loss_type == LossType.CLASSIFICATION:
        if num_class == 1:
            return tf.losses.sigmoid_cross_entropy(
              label, logits=pred, weights=loss_weight, **kwargs)
        else:
            return tf.losses.sparse_softmax_cross_entropy(
              labels=label, logits=pred, weights=loss_weight, **kwargs)
    elif loss_type == LossType.CROSS_ENTROPY_LOSS:
        return tf.losses.log_loss(label, pred, weights=loss_weight, **kwargs)
    elif loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
        logging.info('%s is used' % LossType.Name(loss_type))
        return tf.losses.mean_squared_error(
            labels=label, predictions=pred, weights=loss_weight, **kwargs)
    elif loss_type == LossType.PAIR_WISE_LOSS:
        return pairwise_loss(pred, label)
    else:
        raise ValueError('unsupported loss type: %s' % LossType.Name(loss_type))
def _build_loss_impl(self,
                     loss_type,
                     label_name,
                     loss_weight=1.0,
                     num_class=1,
                     suffix=''):
    loss_dict = {}
    if loss_type == LossType.CLASSIFICATION:
        loss_name = 'cross_entropy_loss' + suffix
        pred = self._prediction_dict['logits' + suffix]
    elif loss_type in [LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS]:
        loss_name = 'l2_loss' + suffix
        pred = self._prediction_dict['y' + suffix]
    else:
        raise ValueError('invalid loss type: %s' % LossType.Name(loss_type))
        loss_dict[loss_name] = build(loss_type,
                                     self._labels[label_name], 
                                     pred,
                                     loss_weight, num_class)
    return loss_dict
def build_loss_graph(self):
    """Build loss graph for multi task model."""
    for task_tower_cfg in self._task_towers:
        tower_name = task_tower_cfg.tower_name
        loss_weight = task_tower_cfg.weight * self._sample_weight
        if hasattr(task_tower_cfg, 'task_space_indicator_label') and \
        task_tower_cfg.HasField('task_space_indicator_label'):
            in_task_space = tf.to_float(
                self._labels[task_tower_cfg.task_space_indicator_label] > 0)
            loss_weight = loss_weight * (
                task_tower_cfg.in_task_space_weight * in_task_space +
                task_tower_cfg.out_task_space_weight * (1 - in_task_space))
            # EasyRec框架會自動對self._loss_dict中的loss進行加和。
            self._loss_dict.update(
                self._build_loss_impl(
                    task_tower_cfg.loss_type,
                    label_name=self._label_name_dict[tower_name],
                    loss_weight=loss_weight,
                    num_class=task_tower_cfg.num_class,
                    suffix='_%s' % tower_name))
    return self._loss_dict

4.應用

由於其卓越的演算法效果,DBMTL在 PAI上被大量使用。

以某直播推薦業務為例,該場景有is_click, is_view, view_costtime, is_on_mic, on_mic_duration多個目標,其中is_click, is_view, is_on_mic為二分類任務,view_costtime, on_mic_duration為預測時長的迴歸任務。使用者行為的依賴關係為:

  • is_click=> is_view
  • is_click+is_view=> view_costtime
  • is_click=> is_on_mic
  • is_click+is_on_mic => on_mic_duration

因此配置如下:

dbmtl {
  bottom_dnn {
  hidden_units: [512, 256]
}
task_towers {
  tower_name: "is_click"
  label_name: "is_click"
  loss_type: CLASSIFICATION
  metrics_set: {
  auc {}
}
dnn {
  hidden_units: [128, 96, 64]
}
relation_dnn {
  hidden_units: [32]
}
weight: 1.0
}
task_towers {
  tower_name: "is_view"
  label_name: "is_view"
  loss_type: CLASSIFICATION
  metrics_set: {
  auc {}
}
dnn {
  hidden_units: [128, 96, 64]
}
relation_tower_names: ["is_click"]
relation_dnn {
  hidden_units: [32]
}
weight: 1.0
}
task_towers {
  tower_name: "view_costtime"
  label_name: "view_costtime"
  loss_type: L2_LOSS
  metrics_set: {
  mean_squared_error {}
}
dnn {
  hidden_units: [128, 96, 64]
}
relation_tower_names: ["is_click", "is_view"]
relation_dnn {
  hidden_units: [32]
}
weight: 1.0
}    
task_towers {
  tower_name: "is_on_mic"
  label_name: "is_on_mic"
  loss_type: CLASSIFICATION
  metrics_set: {
  auc {}
}
dnn {
  hidden_units: [128, 96, 64]
}
relation_tower_names: ["is_click"]
relation_dnn {
  hidden_units: [32]
}
weight: 1.0
}
task_towers {
  tower_name: "on_mic_duration"
  label_name: "on_mic_duration"
  loss_type: L2_LOSS
  metrics_set: {
  mean_squared_error {}
}
dnn {
  hidden_units: [128, 96, 64]
}
relation_tower_names: ["is_click", "is_on_mic"]
relation_dnn {
  hidden_units: [32]
}
weight: 1.0
}
l2_regularization: 1e-6
}
embedding_regularization: 5e-6
}

值得一提的是,DBMTL模型上線後,相比GBDT+FM(圍觀單目標)線上圍觀率提升18%,上麥率提升14%。

5.參考文獻

EasyRec-DBMTL模型介紹

EasyRec-DBMTL模型原始碼

注:本文圖片及公式均引用自論文: DBMTL論文


來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/70004426/viewspace-2869168/,如需轉載,請註明出處,否則將追究法律責任。

相關文章