SAM視覺大模型的finetune

阳光天气發表於2024-05-23

隨著Meta釋出的Segment Anything Model(SAM),計算機視覺迎來了ChatGPT時刻。SAM經過超過110億個分割掩碼的訓練,是預測性人工智慧用例而非生成性人工智慧的基礎模型。雖然它在廣泛的影像模式和問題空間上表現出了令人難以置信的靈活性,但它的釋出沒有“微調”功能。
本教程將概述使用掩碼解碼器微調SAM的一些關鍵步驟,特別是描述SAM的哪些函式用於預/後處理資料,使其處於良好的微調狀態。
What is the Segment Anything Model (SAM)?
分段任意模型(SAM)是Meta AI開發的一個分段模型。它被認為是計算機視覺的第一個基礎模型。SAM是在包含數百萬張影像和數十億個口罩的龐大資料庫上進行訓練的,這使得它非常強大。顧名思義,SAM能夠為各種影像生成準確的分割掩模。SAM的設計使其能夠將人工提示考慮在內,使其對“迴圈中的人工”註釋特別強大。這些提示可以是多模式的:它們可以是要分割的區域上的點、要分割的物件周圍的邊界框,或者關於應該分割的內容的文字提示。
該模型分為三個部分:影像編碼器、提示編碼器和掩碼解碼器。

影像編碼器為被分割的影像生成嵌入,而提示編碼器為提示生成嵌入。影像編碼器是該模型的一個特別大的元件。這與基於嵌入預測分割掩碼的輕量級掩碼解碼器形成對比。Meta AI已經將在Segment Anything 10 Billion Mask(SA-1B)資料集上訓練的模型的權重和偏差作為模型檢查點。
在解釋者部落格文章中瞭解更多關於Segment Anything如何工作的資訊:https://encord.com/blog/segment-anything-model-explained/
What is Model Fine-Tuning?
公開提供的最先進的模型具有自定義架構,通常提供預先訓練的模型權重。如果這些體系結構是在沒有權重的情況下提供的,那麼使用者將需要從頭開始訓練模型,他們將需要使用大量資料集來獲得最先進的效能。
模型微調是採用預先訓練好的模型(架構+權重)並向其顯示特定用例的資料的過程。這通常是模型以前從未見過的資料,或者在其原始訓練資料集中代表性不足的資料。
微調模型和從頭開始之間的區別在於權重和偏差的起始值。如果我們從頭開始訓練,這些將根據一些策略隨機初始化。在這樣的啟動配置中,模型將對手頭的任務“一無所知”,並表現不佳。透過使用預先存在的權重和偏差作為起點,我們可以“微調”權重和偏差,以便我們的模型在自定義資料集上更好地工作。例如,學會識別貓的資訊(邊緣檢測、計數爪子)將有助於識別狗。
Why Would I Fine-Tune a Model?
微調模型的目的是在預先訓練的模型以前沒有看到的資料上獲得更高的效能。例如,在從手機攝像頭收集的大量資料上訓練的影像分割模型將主要從水平角度看到影像。
如果我們試圖將這個模型用於從垂直角度拍攝的衛星影像,它可能不會表現得那麼好。如果我們試圖分割屋頂,該模型可能不會產生最佳結果。預訓練是有用的,因為模型通常已經學會了如何分割物件,所以我們希望利用這個起點來構建一個可以準確分割屋頂的模型。此外,我們的自定義資料集可能沒有數百萬個示例,因此我們希望進行微調,而不是從頭開始訓練模型。
微調是可取的,這樣我們就可以在特定的用例中獲得更好的效能,而不必承擔從頭開始訓練模型的計算成本。
How to Fine-Tune Segment Anything Model [With Code]
背景與架構
我們在介紹部分概述了SAM體系結構。影像編碼器具有具有許多引數的複雜結構。為了微調模型,我們有必要關注掩碼解碼器,它重量輕,因此更容易、更快、更高效地進行微調。
為了微調SAM,我們需要提取其架構的底層部分(影像和提示編碼器、掩碼解碼器)。我們無法使用SamPredictor.predict(連結):
我們只想微調掩碼解碼器
這個函式呼叫SamPredictor.predict_tarch,它有@torch.no_grad()裝飾器(連結),它阻止我們計算梯度
因此,我們需要檢查SamPredictor.prpredict函式,並在我們想要微調的部分(掩碼解碼器)啟用梯度計算的情況下呼叫適當的函式。這樣做也是瞭解更多SAM如何工作的好方法。
Creating a Custom Dataset
我們需要三件事來微調我們的模型:
要在其上繪製分割的影像
分割地面實況掩碼
提示輸入到模型中
我們選擇了印章驗證資料集(連結),因為它有SAM在其訓練中可能沒有看到的資料(即,檔案上的印章)。我們可以透過使用預先訓練的權重執行推理來驗證它在該資料集上的表現良好,但並不完美。ground truth masks也非常精確,這將使我們能夠計算出準確的損失。最後,這個資料集包含分割掩碼周圍的邊界框,我們可以將其用作SAM的提示。下面顯示了一個示例影像。這些邊界框與人工註釋器在生成分段時要經過的工作流程非常一致。

Input Data Preprocessing
我們需要對從numpy陣列到pytorch張量的掃描進行預處理。要做到這一點,我們可以遵循SamPredictor.set_image(連結)和預處理影像的SamPredictor.set_arch_image(連結)內部發生的情況。首先,我們可以使用utils.transform。ResizeLongestSide可調整影像的大小,因為這是預測器(連結)內部使用的轉換器。然後,我們可以將影像轉換為pytorch張量,並使用SAM預處理方法(連結)完成預處理。
Training Setup
我們下載vit_b模型的模型檢查點,並將其載入到:
sam_model = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')
我們可以使用預設值設定Adam最佳化器,並指定要調整的引數是掩碼解碼器的引數:
optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters())
同時,我們可以設定我們的損失函式,例如均方誤差
loss_fn = torch.nn.MSELoss()
Training Loop
在主訓練迴圈中,我們將迭代我們的資料項,生成掩碼,並將它們與我們的ground truth掩碼進行比較,以便我們可以基於損失函式最佳化模型引數。
在這個例子中,我們使用GPU進行訓練,因為它比使用CPU快得多。在適當的張量上使用.to(裝置)是很重要的,以確保CPU上沒有某些張量,GPU上沒有其他張量。
我們希望透過將編碼器封裝在torch.no.grad()上下文管理器中來嵌入影像,因為否則我們將出現記憶體問題,同時我們不希望微調影像編碼器。
with torch.no_grad(): image_embedding = sam_model.image_encoder(input_image)
我們還可以在no.grad上下文管理器中生成提示嵌入。我們使用邊界框座標,轉換為pytorch張量。
with torch.no_grad(): sparse_embeddings, dense_embeddings = sam_model.prompt_encoder( points=None, boxes=box_torch, masks=None, )
最後,我們可以生成遮罩。請注意,這裡我們處於單掩碼生成模式(與正常輸出的3個掩碼形成對比)。
low_res_masks, iou_predictions = sam_model.mask_decoder( image_embeddings=image_embedding, image_pe=sam_model.prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=False, )
這裡的最後一步是將遮罩升級回原始影像大小,因為它們的解析度較低。我們可以使用Sam.postprocess_masks來實現這一點。我們還希望從預測的掩碼中生成二進位制掩碼,以便將其與我們的基本事實進行比較。為了不破壞反向傳播,使用torch泛函是很重要的。

點選檢視程式碼
upscaled_masks = sam_model.postprocess_masks(low_res_masks, input_size, original_image_size).to(device)

from torch.nn.functional import threshold, normalize

binary_mask = normalize(threshold(upscaled_masks, 0.0, 0)).to(device)
我們可以計算損失並執行最佳化步驟:
點選檢視程式碼
loss = loss_fn(binary_mask, gt_binary_mask)
optimizer.zero_grad()
loss.backward()
optimizer.step()
透過在多個時期和批次上重複這一過程,我們可以微調SAM解碼器。 **Saving Checkpoints and Starting a Model from it** 一旦我們完成了訓練並對效能提升感到滿意,我們就可以使用以下方法儲存調整模型的狀態dict: `torch.save(model.state_dict(), PATH)` 然後,當我們想對與我們用來微調模型的資料相似的資料執行推理時,我們可以載入這個狀態dict。 未完待續

SAM論文連結:https://arxiv.org/pdf/2304.02643

相關文章