前些天從實驗室瞭解到天池的FashionAI全球挑戰賽,題目和資料都挺有意思,於是花了點時間稍微嘗試了下。目前比賽還在初賽階段,題目有兩個,分別是服裝屬性標籤識別和服飾關鍵點定位。
服裝屬性標籤識別是指識別出領、袖、衣、裙、褲等部位的設計屬性,對應多個多分類問題,例如以下的例子。
服飾關鍵點定位是指定位出服飾中關鍵點的位置,對應多個迴歸問題,例如以下的例子。
這和CelebA(mmlab.ie.cuhk.edu.hk/projects/Ce…)人臉資料集有點像,每張圖片都是一張人臉,對應的標註包括5個關鍵點位置和40個屬性的01標註,例如是否有眼鏡、帽子、鬍子等。
選題
我選了第二道題目,一方面感覺有人臉關鍵點檢測、人體骨骼關鍵點檢測等相關問題可供參考,去年的AI Challenger也舉辦過人體骨骼關鍵點檢測的比賽(challenger.ai/competition…),另一方面自己還沒有做過這塊內容,比較感興趣。
比賽官方提供的訓練集共包括4W多張圖片,測試集共包括將近1W張圖片,神奇的是訓練集和測試集有366張圖片是完全重合的。每張圖片都指定了對應的服飾,共5類:上衣(blouse)、外套(outwear)、連身裙(dress)、半身裙(skirt)、褲子(trousers)。
一共有24個關鍵點,但每類服飾對應的關鍵點數量並不一樣。關鍵點標註分為三類,-1表示不存在,0表示存在但不可見,1表示存在且可見,後兩種情況都會提供對應關鍵點的xy座標。
初步探索
進行了相關的預處理之後,先嚐試下最基礎的結構,即卷積、池化加全連線。由於之前完全沒做過這一塊,所以網路的一些細節都只能慢慢嘗試,包括卷積用幾層、卷積核大小設多少、使用哪個啟用函式、使用哪個損失函式等等。
進行了長時間的摸索,終於折騰出第一個全部打通的版本,提交了一版結果,成績大概30%左右。由於最後一層使用全連線層直接輸出每個關鍵點的xy座標,因此誤差比較大。
目前排行榜第一名是4.49%,佔榜很多天雷打不動,可見實力之強勁。
進一步探索
後來一想,像這類比較經典的問題,肯定已經有大量的相關研究和模型,完全靠自己憑空搭一個網路顯然不靠譜。於是進行了一些調研,找到兩個模型:Convolutional Pose Machine、Stacked Hourglass。
精力有限,重點研究了一下CPM。閱讀了對應的論文,2016年的CVPR,模型結構長這樣,簡單來說就是反覆使用多個Stage,不斷抽取每個關鍵點對應的越來越準確的響應圖。
在Github上找到了CPM的一個開源實現(github.com/timctho/con…),閱讀程式碼並進行修改後應用到比賽的資料上,在P100上訓練共花費30個小時左右。使用6個Stage的CPM,為每個關鍵點生成一張響應圖。
以下是一張dress對應的結果,第一行的三張依次是第1個、第2個、第3個Stage的響應圖合成結果,第二行的三張分別對應第6個Stage的響應圖合成結果、正確答案、正確答案和原圖的合成,看起來還不錯。
再來看個outwear,響應圖也很準。
最後再看個trousers,關鍵點比較少,也很準。
又交了一版結果,拿到了17%的成績。由於CPM輸出關鍵點的響應圖而不是直接輸出關鍵點的xy座標,同時使用多個Stage級聯以逐步獲取越來越準確的響應圖,因此可以取得更好的結果。
可能的改進
可能的改進包括多個方面:
- 使用其他更新更好的模型。由於我是這方面的外行,一方面不清楚哪個模型目前最好,另一方面再調研和實現一個模型也需要耗費大量時間;
- 調參。每個模型都涉及很多引數,除此之外還有很多和模型無關的引數,例如學習率、批大小、正則項係數等,由於我是這方面的外行,暫時不清楚對關鍵點檢測這類問題該如何選擇引數;
- 使用資料增強等技巧。由於我是這方面的外行,暫時不清楚除了資料增強之外還有什麼適合關鍵點檢測這類問題的技巧。
雖然有很多可能的改進方向,不過由於自己之前沒有做過關鍵點檢測這類問題,所以繼續折騰下去只能靠運氣各種嘗試,而且每次嘗試都需要等待很久的模型訓練時間。
相比之下,對於一些在關鍵點檢測領域有相當積累的團隊和個人,他們有著豐富的經驗和現成的程式碼,和他們競爭還是相當有難度的。看一下排行榜,第一名的4.49%至今無人能超越,前三十名也都在12%以下。
而且個人事情也比較多,時間和精力都十分有限,所以決定這個比賽不再繼續嘗試,感覺做到這一步就可以了。
總結
通過這次比賽,瞭解了關鍵點檢測這類問題的一些解決方法,並嘗試用CPM進行了一些實踐,對自己而言已經滿足了。
等比賽結束後,再關注一下冠軍團隊的解決方案,好好學習一波。