一種基於均值不等式的Listwise損失函式

infgrad發表於2020-10-06

一種基於均值不等式的Listwise損失函式

1 前言

1.1 Learning to Rank 簡介

Learning to Rank (LTR) , 也被叫做排序學習, 是搜尋中的重要技術, 其目的是根據候選文件和查詢語句的相關性對候選文件進行排序, 或者選取topk文件. 比如在搜尋引擎中, 需要根據使用者問題選取最相關的搜尋結果展示到首頁. 下圖是搜尋引擎的搜尋結果
search_result.jpg

1.2 LTR演算法分類

根據損失函式可把LTR分為三種:

  1. Pointwise, 該型別演算法將LTR任務作為迴歸任務來訓練, 即嘗試訓練一個為文件和查詢語句的打分器, 然後根據打分進行排序.
  2. Pairwise, 該型別演算法的損失函式考慮了兩個候選文件, 學習目標是把相關性高的文件排在前面, triplet loss 就屬於Pairwise, 它的損失函式是

\[loss = max(0, score_{neg}-score_{pos}+margin) \]

可以看出該損失函式一次考慮兩個候選文件.
3. Listwise, 該型別演算法的損失函式會考慮多個候選文件, 這是本文的重點, 下面會詳細介紹.

1.3 本文主要內容

本文主要介紹了本人在學習研究過程中發明的一種新的Listwise損失函式, 以及該損失函式的使用效果. 如果讀者對LTR任務及其演算法還不夠熟悉, 建議先去學習LTR相關知識, 同時本人博文自然語言處理中的負樣本挖掘 (分類與排序任務中如何選擇負樣本) 也和本文關係較大, 可以先進行閱讀.

2 預備知識

2.1 數學符號定義

\(q\)代表使用者搜尋問題, 比如"如何成為宇航員", \(D\)代表候選文件集合,\(d^+\)代表和\(q\)相關的文件,\(d^-\)代表和\(q\)不相關的文件, \(d^+_i\)代表第\(i\)個和\(q\)相關的文件, LTR的目標就是根據\(q\)找到最相關的文件\(d\)

2.2 學習目標

本次學習目標是訓練一個打分器 scorer, 它可以衡量q和d的相關性, \(scorer(q, d)\)就是相關性分數,分值越大越相關. 當前主流方法下, scorer一般選用深度神經網路模型.

2.3訓練資料分類

損失函式不同, 構造訓練資料的方法也會不同:

-Pointwise, 可以構造迴歸資料集, 相關的資料設為1, 不相關設為0.
-Pairwise, 可構造triplet型別的資料集, 形如(\(q,d^+, d^-\))
-Listwise, 可構造這種型別的訓練集: (\(q,d^+_1,d^+_2..., d^+_n , d^-_1, d^-_2, ..., d^-_{n+m}\)), 一個正例還是多個正例也會影響到損失函式的構造, 本文提出的損失函式是針對多正例多負例的情況.

3 基於均值不等式的Listwise損失函式

3.1 損失函式推導過程

在上一小結我們可以知道,訓練集是如下形式 (\(q,d^+_1,d^+_2..., d^+_n , d^-_1, d^-_2, ..., d^-_{n+m}\)), 對於一個q, 有n個相關的文件和m個不相關的文件, 那麼我們一共可以獲取m+n個分值:\((score_1,score_2,...,score_n,...,score_{n+m})\), 我們希望打分器對相關文件打分趨近於正無窮, 對不相關文件打分趨近於負無窮.

對m+n個分值做一個softmax得到\(p_1,p_2,...,p_n,...,p_{n+m}\), 此時\(p_i\)可以看作是第i個候選文件與q相關的概率, 顯然我們希望\(p_1,p_2,...,p_n\)越大越好, \(p_{n+1},...,p_{m+n}\)越小越好, 即趨近於0. 因此我們暫時的優化目標是\(\sum_{i=1}^{n}{p_i} \rightarrow 1\).

但是這個優化目標是不合理的, 假設\(p_1=1\), 其他值全為0, 雖然滿足了上面的要求, 但這並不是我們想要的. 因為我們不僅希望\(\sum_{i=1}^{n}{p_i} \rightarrow 1\), 還希望相關候選文件的每一個p值都要足夠大, 即我們希望: n個候選文件都與q相關的概率是最大的, 所以我們真正的優化目標是:

\[\max(\prod_{i=1}^{n}{p_i} ) , \sum_{i=1}^{n}{p_i} = 1 \]

當前情況下, 損失函式已經可以通過程式碼實現了, 但是我們還可以做一些化簡工作, \(\prod_{i=1}^{n}{p_i}\)是存在最大值的, 根據均值不等式可得:

\[\prod_{i=1}^{n}{p_i} \leq (\frac{\sum_{i=1}^{n}{p_i}}{n})^n \]

對兩邊取對數:

\[\sum_{i=1}^{n}{log(p_i)} \leq -nlog(n) \]

這樣是不是感覺清爽多了, 然後我們把它轉換成損失函式的形式:

\[loss = -nlog(n) - \sum_{i=1}^{n}{log(p_i)} \]

所以我們的訓練目標就是\(\min{(loss)}\)

3.2 使用pytorch實現該損失函式

在獲取到最終的損失函式後, 我們還需要用程式碼來實現, 實現程式碼如下:

# A simple example for my listwise loss function
# Assuming that n=3, m=4
# In[1]
# scores
scores = torch.tensor([[3,4.3,5.3,0.5,0.25,0.25,1]])
print(scores)
print(scores.shape)
'''
tensor([[0.3000, 0.3000, 0.3000, 0.0250, 0.0250, 0.0250, 0.0250]])
torch.Size([1, 7])
'''
# In[2]
# log softmax
log_prob = torch.nn.functional.log_softmax(scores,dim=1)
print(log_prob)
'''
tensor([[-2.7073, -1.4073, -0.4073, -5.2073, -5.4573, -5.4573, -4.7073]])
'''
# In[3]
# compute loss
n = 3.
mask = torch.tensor([[1,1,1,0,0,0,0]]) # number of 1 is n
loss = -1*n*torch.log(torch.tensor([[n]])) - torch.sum(log_prob*mask,dim=1,keepdim=True)
print(loss)
loss = loss.mean()
print(loss)
'''
tensor([[1.2261]])
tensor(1.2261)
'''

該示例程式碼僅展現了batch_size為1的情況, 在batch_size大於1時, 每一條資料都有不同的m和n, 為了能一起送入模型計算分值, 需要靈活的使用mask. 本人在實際使用該損失函式時,一共使用了兩種mask, 分別mask每條資料所有候選文件和每條資料的相關文件, 供大家參考使用.

3.3 效果評估和使用經驗

由於評測資料使用的是內部資料, 程式碼和資料都無法公開, 因此只能對使用效果做簡單總結:

  1. 效果優於PointwisePairwise, 但差距不是特別大
  2. 相比Pairwise收斂速度極快, 訓練一輪基本就可以達到最佳效果

下面是個人使用經驗:

  1. 該損失函式比較佔用視訊記憶體, 實際的batch_size是batch_size*(m+n), 建議視訊記憶體在12G以上
  2. 負例數量越多,效果越好, 收斂也越快
  3. 用pytorch實現log_softmax時, 不要自己實現, 直接使用torch中的log_softmax函式, 它的效率更高些.
  4. 只有一個正例, 還可以考慮轉為分類問題,使用交叉熵做優化, 效果同樣較好

4 總結

該損失函式還是比較簡單的, 只需要簡單的數學知識就可以自行推導, 在實際使用中也取得了較好的效果, 希望也能夠幫助到大家. 如果大家有更好的做法歡迎告訴我.

文章可以轉載, 但請註明出處:

相關文章