最近在處理一個多分類的歸類問題。多分類問題是個很常見的問題。但由於大家問題描述的角度不同,糾結點不一樣,用詞不同,導致有時候理解上出現了what?的尷尬情況。但透過現象看本質,抓住機器學習本質就是找公式,解方程,解不出來就儘量優化這個核心思想,就不會lost。
什麼是歸類。
在影像歸類中,
手寫數字識別,預測目標是0,1,2...9中的某一個數字
自動駕駛中的目標檢測,預測目標是汽車,行人,建築物
在自然語言處理中,
情感分類,預測目標文字所要表達的情感是高興,悲傷,憤怒
傳統的邏輯迴歸分類問題,是二項分類,是非問題,就是預測目標是不是屬於某個分類。比如說,是新冠病毒攜帶者嗎?是合理的嗎?是壞的嗎?當把傳統的二項分類邏輯迴歸演算法衍生到多分類的時候,該如何處理呢。因為邏輯迴歸演算法比較簡單,一個線性方程再加一個啟用函式,解決問題能力有限,因此一個模型直接用於多分類,效果不理想。(圖片來自網路)
所以,處理多分類問題的時候,通常就是分解為多個二項模型,然後一一判斷是不是屬於某個分類。
對於手寫數字識別問題,就是有10個二項分類模型,分別預測是不是0,是不是1。。。
這樣的做法,就是很可能,一個目標可能會落入多個分類中。對於實際問題,確實也有屬於多分類的情況。比如說,自然語言處理中的命名實體識別NER,某個片語可能同時歸類於”組織”和”時間發生地”。可以說多個二項分類模型的組合是default的解決了目標可能處於多個分類的這個問題。
但是如何處理只能是其中一個分類的問題呢。我們知道,邏輯迴歸的輸出其實是代表概率,為了避免又是1又是0的問題,最後以選取概率最大的分類,作為最後的目標分類,就可以解決這個問題。
當我們用多層神經網路來取代羅輯迴歸演算法的時候,由於網路的複雜,可以fit的情況越來越多,也就是可以用更加複雜的方程來匹配樣本。(邏輯迴歸分類只有一層,怎麼整都沒法匹配複雜樣本的) 。這個時候,多個分類問題並不需要分成多個網路來分別處理。
這時候需要保證的是輸出層的節點數必須與分類的個數是一致的。從概念上,可以從更通用的角度來理解,這個模型網路就是一個n維(n個feature)的向量到一個m維(m個分類)適量的transformer。(圖片來自網路)
而以上提到的兩種歸類問題(多個分類可以共存與否)被轉換為兩個互相對立的概念: multi-class(不可共存,排他) 和multi-label(可以共存)
其實我很不喜歡把它們對立起來,明明就是同樣的網路模型嘛,唯一的區別就是啟用函式。multi-class 的啟用函式用softmax,這樣保證最後只有一個class被選中。multi-label的啟用函式用sigmoid, 不管別人怎樣,我根據我自己的標準來歸類。啟用函式不一樣,最後的loss function的cross entropy的公式自然也不一樣,但本質上就是交叉熵,最大似然。
所以再用pre-trained的模型時候,至少兩件事情必須考慮,根據類別數量確定輸出層的節點數,根據是多分類問題還是多標籤問題確定啟用函式。
當然有衍生的複雜問題。例如對目標進行不同角度的分類,比如說一首音樂對於風格和流行度同時進行分類,一個排他一個不排他。回到概念的本質,其實也一樣,定位輸出,找出優化方程(loss function)。
閱讀作者更多原創文章,關注微信公眾號: