meta在2023.4.5又發了image sematic segmentation的文章,名字就叫Segment Anything;學術圈有個潛規則:title越簡單,事情越大,比如7年前的那篇 attention is all you need,直接提升了nlp的層次!這次的Segment Anything同樣也很簡單,這次又有哪些break through innovation?
1、(1)論文剛開始,給出了模型的互動方式:點、框、mask描邊、text都能作為prompt,然後和image一起輸入,經過model的處理後,輸出就是valid mask了!怎麼樣,是不是很符合人的使用習慣?
另一個靚點:所謂的data engine,先人工標註少量的高質量資料集,用來訓練"粗糙"的SAM;然後用粗糙的SAM做語義分割,期間配個人工檢查標記漏掉的mask,用以完善資料集,再繼續迭代訓練SAM;如此往復,不停迭代,直到得到高質量的資料集和準確率高的model,整個過程的思路是不是和loss+back proportion很相似啊!
2、 SAM網路結構如下:
- 原始的image,透過encoder後轉成image embedding;為了和transformer架構相容,這裡推薦使用vit提取image embedding
- 常見的promt:
- point、box:這兩種prompt都和位置相關,所以可以用positional encoder編碼,論文原話:We represent points and boxes by positional encodings [95] summed with learned embeddings for each prompt type
- text是文字,最常見的就是BERT編碼了,論文用的CLIP,論文原話:free-form text with an off-the-shelf text encoder from CLIP [82]
- 最麻煩的就是mask prompt了:這個是使用者手動的描邊,可能不精準,只是個大概的範圍,需要SAM進一步精確描邊。這裡的mask promt本質也是個圖片,所以這裡用conv提取特徵,轉成embedding,論文原話:Dense prompts (i.e., masks) are embedded using convolutions and summed element-wise with the image embedding
- 四種promt經過encoder編碼後,進入mask decoder,就生成了valid mask圖片;
- 為了避免歧義,model預設輸出三個mask segmentation,比如上圖的剪刀有3個:剪刀全身、剪刀兩個耳朵,剪刀的一個耳朵,每個mask segmentation都有一個socre,計算方法為IOU,score越高,mask segmentation越準確
所以整個網路結構最核心的就是mask encoder了,這個又是怎麼處理image和promt的embedding的了?
3、(1)mask encoder的網路架構如下:
論文原始描述:Figure 14: Details of the lightweight mask decoder. A two-layer decoder updates both the image embedding and prompt tokens via cross-attention. Then the image embedding is upscaled, from which the updated output tokens are used to dynamically predict masks. (Not illustrated for figure clarity: At every attention layer, positional encodings are added to the image embedding, and the entire original prompt token (including position encoding) is re-added to the token queries and keys.)
- image embedding:原始的image透過vit後轉成256 * 64 * 64這麼大,64 * 64 可以看成是H' * W',256這種channel可以看成是embedding
- promt token: 就是point、box、text、mask等編碼後的embedding,每個token也都轉成256維的向量,和image embedding的dimension保持一致
- output token:從名字看,就知道是輸出結果的token了。這部分的token embedding是要動態更新的,論文原話:Then the image embedding is upscaled, from which the updated output tokens are used to dynamically predict masks;instead of predicting a single mask, we use a small number of output tokens and predict multiple masks simultaneously. By default we predict three masks;
輸出準備完畢,就是最關鍵的mask decoder環節了,直接涉及到輸出的mask是否準確;從下網上看:
- self attn:這個容易理解,主要是output token和promt token內部做attn,核心是檢視output和prompt是否接近,來確定output是否正確;
- token to image attn:把image的特徵和tokens的特徵做融合,這裡是token主動和image做cross attn
- MLP:類似transformer block的FFN
- image to token attn:這裡居然是image主動和token做cross attn,這也是SAM的創新點之一;
- To ensure the decoder has access to critical geometric information, the positional encodings are added to the image embedding whenever they participate in an attention layer:模型在mask decode過程中,能保留每個pixel(或特徵位置)的position,避免decode過程中pixel的原始position資訊丟失,從而幫助mask decoder更精確地生成spatial資訊相關的mask,我個人覺得這個思路和renet接近
- Additionally, the entire original prompt tokens (including their positional encodings) are re-added to the updated tokens whenever they participate in an attention layer:無論在attn層中更新了多少次promt,mask decoder都會把最初的promt token 連同其position embedding一起重新注入。這使得模型在decode時既保留了prompt的postion和型別資訊,也允許它在此基礎上增加或調整新的資訊
以上就是組接近transformer block的架構了,比較神奇的是這個block居然只有2個,所以SAM模型的引數也很小,才641M個,bin檔案不到2.4GB(https://huggingface.co/facebook/sam-vit-huge 有詳情)!這麼小的引數,為啥效果這麼好?我個人認為核心還是資料質量好,也就是data engine的思路好!經過2個block,image和tokens的資訊互相融合後,
- 2x conv trans:這裡做upsample ,核心目的是還原原始的image,才能在原始尺寸的image上做mask segmentation
- token to image attn:兩部分的資訊繼續融合
- MLP:只有3層,核心是做channel維度適配;we pass the updated output token embedding to a small 3-layer MLP that outputs a vector matching the channel dimension of the upscaled image embedding
- masks:dot product per mask * output token per mask(經過MLP):這步驟和maskformer中的最後一個drop神似:classification loss * binary mask loss最終得到 K*H*W,也就是每個pixel的K個object分類的機率分佈;Finally, we predict a mask with a spatially point-wise product between the upscaled image embedding and the MLP’s output;這裡的到的是每個pixel屬於N個query/mask的機率分佈(或者說單個N包含哪些pixel,不包含哪些pixel;比如3*H*W,就是第3個mask包含哪些pixel,不包含哪些pixel),也就是N*H*W!
- 最後一個IOU:mask出來的object和標註ibject的交集,越大說明mask越準確
(2) loss:輸出的masks對不對了? 怎麼評價輸出的masks好壞了?這裡涉及到loss的選擇了! masks輸出的是N*H*W,這裡有個問題:比如識別image的行人。這個資料集中有大量的行人影像,只有少數不是行人的影像,所以類別之間的極端不平衡:行人類別(正樣本)遠多於非行人類別(負樣本),所以SAM這裡採用的是focal loss,而不是cross entropy,這裡的公式如下:
這裡最重要的引數就是gamma了!以
- gamma = 0 時,focal loss退化為標準的交叉熵損失。模型可能會很快學會正確分類那些易分類的行人影像,但對於少數非行人影像,由於它們在訓練集中的比例很低,模型可能無法給予足夠的關注,導致對這些難分類樣本的識別效能較差
- gamma > 0 時,focal loss會減少對那些已經分類正確的、易分類的行人樣本的損失貢獻,而增加對分類錯誤的、難分類的非行人樣本的損失貢獻。這樣,模型被迫更多地關注那些難分類的樣本,從而提高了對少數類別的識別能力
舉例:比如gamma=2,對於一個已經被模型以高機率正確分類的行人影像(Pt接近1),那麼(1-Pt)^2接近0,這會顯著減少該樣本的損失貢獻。相反,對於一個被錯誤分類的非行人影像,此時Pt接近0,但是(1-Pt)^2會比較大,這會增加該樣本的損失貢獻,促使模型調整引數以更好地分類這類樣本
上述loss是針對masks的,核心是某個pixel的類別對不對,類似於maskformer的binary mask loss!然而mask segmentation是由大量的pixel組成的,單個pixel分類正確了還不足以說明整個mask是正確的,咋辦咧?還有另一個指標來評判maks整體的準確率有多高,就是IOU score(類似於maskformer的mask classification),圖示如下:
核心就是predicted mask和ground truth mask之間的交集除以並集!score越高,說明predicted mask越接近ground truth mask!用hugging face的https://huggingface.co/facebook/sam-vit-huge訓練好的model嘗試,對車胎標記結果如下:
明顯是mask2的效果最好,因為score最高的嘛!
(3)從SAM的網路結構上看,創新點有:
- image to token attn:之前都是token主動和image做cross attn,這裡增加了image主動和token做cross attn
- 2x conv. trans做upsample,把image還原成原始的尺寸,利於mask segmentation的生成!
其他部分和maskformer基本一樣,沒啥本質區別了!
4、SAM效果好的另一個核心原因: data engine! 讓人工標註大量資料的成本是很高的,怎麼低成本地得到大量的優質標註資料了?論文原話:Our data engine has three stages: assisted-manual, semi-automatic, and fully automatic.
- In the first stage, SAM assists annotators in annotating masks, similar to a classic interactive segmentation setup.
- In the second stage, SAM can automatically generate masks for a subset of objects by prompting it with likely object locations and annotators focus on annotating the remaining objects, helping increase mask diversity.
- In the final stage, we prompt SAM with a regular grid of foreground points, yielding on average ∼100 high-quality masks per image.
概括一下這個所謂data engine迭代的原理其實很簡單:
- 先用少量人工標註的高質量資料訓練SAM。因為資料量少,所以此時SAM的質量很粗糙,不咋地
- 利用粗糙的SAM分割資料,但此時分割的結果質量肯定也不咋地,還是需要人工介入修正分割錯誤的資料,來提升準確率;注意,這步是修正,成本比人工從頭開始標註低,這是關鍵點;這一步的目的是提升accuracy!
- 用修正好的增量資料繼續訓練模型,此時SAM的質量肯定比第一步的好很多。繼續用SAM分割新image,再次用人工修正,找到漏掉的mask,這一步的目的是提升recall!
- 繼續用上一步修正好的資料訓練SAM,此時找到score較高的mask,繼續進一步訓練
- 重複上述的3、4兩個步驟,就能得到越來越多的高質量標註資料啦!
用論文原話:Our final dataset, SA-1B, includes more than 1B masks from 11M licensed and privacy-preserving images:11M張image中得到了1B個高質量的mask;這麼多的mask,如果全讓人工標註,成本豈不是要上天了?
參考:
1、https://www.bilibili.com/video/BV1K94y177Ka/?spm_id_from=333.788.recommend_more_video.0&vd_source=241a5bcb1c13e6828e519dd1f78f35b2
2、https://github.com/facebookresearch/segment-anything https://huggingface.co/facebook/sam-vit-huge
- SAM 2 code: https://github.com/facebookresearch/segment-anything-2
- SAM 2 demo: https://sam2.metademolab.com/
- SAM 2 paper: https://arxiv.org/abs/2408.00714
3、https://www.bilibili.com/video/BV1aL41127VG/?spm_id_from=333.337.search-card.all.click&vd_source=241a5bcb1c13e6828e519dd1f78f35b2
4、https://arxiv.org/abs/2304.02643 Segment Anything
5、https://www.bilibili.com/video/BV13Mm5YDEAW?spm_id_from=333.788.videopod.episodes&vd_source=241a5bcb1c13e6828e519dd1f78f35b2