Segment-anything學習到微調系列3_SAM微調decoder

xzyun2011發表於2024-07-29

前言

本系列文章是博主在工作中使用SAM模型時的學習筆記,包含三部分:

  1. SAM初步理解,簡單介紹模型框架,不涉及細節和程式碼
  2. SAM細節理解,對各模組結合程式碼進一步分析
  3. SAM微調例項,原始程式碼涉及隱私,此部分使用公開的VOC2007資料集,Point和Box作為提示進行mask decoder微調講解

本篇是第3部分,基於voc2007資料集對SAM decoder進行微調。程式碼已上傳至github,如果對你有幫助請點個Star,感謝。

此前講過,以ViT_B為基礎的SAM權重是375M,其中prompt encoder只有32.8k,mask decoder是16.3M(4.35%),剩餘則是image encoder,image encoder是非常大的,一般不對它進行微調,預訓練的已經夠好了,除非是類似醫療影像這種非常規資料,預訓練資料中沒有,效果會比較差,才會對image encoder也進行微調,所以此處只針對decoder進行微調。

微調效果

基於point prompt

這部分是隻針對point作為提示的微調,藉助了ISAT_with_segment_anything這個用SAM做自動標註的工具來進行一個效果比對,可以看出來微調前,需要點選多次多個點才能分割得較好,微調後點選一下就能分割出對應類別

微調前

微調後

基於box prompt

這部分加入了box作為提示的微調

微調前

微調後

程式碼部分

資料讀取

使用的是VOC2007分割資料集,總共632張圖片(412train_val,210test),一共20個類別,加上背景類一共21,標籤是png格式,畫素值代表物體類別,同時所有物體mask的外輪廓值是255,訓練時會忽略,原始資料集如下目錄構造(github上的程式碼中data_example只是示例,只有幾張圖),訓練使用的是SegmentationObject中的標籤:

## VOCdevkit/VOC2007
├── Annotations
├── ImageSets
│   ├── Layout
│   ├── Main
│   └── Segmentation
├── JPEGImages
├── SegmentationClass
└── SegmentationObject

CustomDataset的程式碼按如上目錄結構讀取對應資料,根據ImageSets/Segmentation目錄下的txt_name指定訓練的檔名字,然後讀取對應圖片和標籤,有以下幾點注意:

  • 分割標籤使用PIL讀取,畫素值就是對應類別,255是外輪廓會忽略;如果使用opencv讀取圖片,需要根據RGB值去platte表中看對應類別
  • image和gt都是按numpy array塞進batch中,後面丟給sam會轉為tensor;voc2007中每張圖片大小是不一致的,目前就按batch=1處理
  • gt的channel是1,後面需要轉為one-hot的形式
class CustomDataset(Dataset):
    def __init__(self, VOCdevkit_path, txt_name="train.txt", transform=None):
        self.VOCdevkit_path = VOCdevkit_path
        with open(os.path.join(VOCdevkit_path, f"VOC2007/ImageSets/Segmentation/{txt_name}"), "r") as f:
            file_names = f.readlines()
        self.file_names = [name.strip() for name in file_names]
        self.image_dir = os.path.join(self.VOCdevkit_path, "VOC2007/JPEGImages")
        self.image_files = [f"{self.image_dir}/{name}.jpg" for name in self.file_names]
        self.gt_dir = os.path.join(self.VOCdevkit_path, "VOC2007/SegmentationObject")
        self.gt_files = [f"{self.gt_dir}/{name}.png" for name in self.file_names]

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        image_path = self.image_files[idx]
        image_name = image_path.split("/")[-1]
        gt_path = self.gt_files[idx]

        image = cv2.imread(image_path)
        image = image[..., ::-1] ## RGB to BGR
        image = np.ascontiguousarray(image)
        gt = Image.open(gt_path)
        gt = np.array(gt, dtype='uint8')
        gt = np.ascontiguousarray(gt)

        return image, gt, image_name

    @staticmethod
    def custom_collate(batch):
        """ DataLoader中collate_fn,
         影像和gt都用numpy格式,後面會重新轉tensor
        """
        images = []
        seg_labels = []
        images_name = []
        for image, gt, image_name in batch:
            images.append(image)
            seg_labels.append(gt)
            images_name.append(image_name)
        images = np.array(images)
        seg_labels = np.array(seg_labels)
        return images, seg_labels, images_name

影像預處理

取得影像後,直接使用SamPredictor中的預處理方式,會將圖片按最長邊resized到1024x1024,然後計算image_embedding,這部分很耗時,所以每張圖只計算一次,會將結果快取起來需要的時候直接呼叫。使用"with torch.no_grad()"保證image encoder部分不需要梯度更新,凍結對應權重

    model_transform = ResizeLongestSide(sam.image_encoder.img_size)
    for epoch in range(num_epochs):
        epoch_loss = 0
        for idx, (images, gts, image_names) in enumerate(tqdm(dataloader)):
            valid_classes = []  ## voc 0,255 are ignored
            for i in range(images.shape[0]):
                image = images[i] # h,w,c np.uint8 rgb
                original_size = image.shape[:2] ## h,w
                input_size = model_transform.get_preprocess_shape(image.shape[0], image.shape[1],
                                                                  sam.image_encoder.img_size)  ##h,w
                gt = gts[i].copy() #h,w labels [0,1,2,..., classes-1]
                gt_classes = np.unique(gt)  ##masks classes: [0, 1, 2, 3, 4, 7]
                image_name = image_names[i]

                predictions = []
                ## freeze image encoder
                with torch.no_grad():
                    # gt_channel = gt[:, :, cls]
                    predictor.set_image(image, "RGB")
                    image_embedding = predictor.get_image_embedding()

Prompt生成

從mask中隨機選取一定數量的前景點和背景點,此處預設1個前景點和1個背景點,數量多的話一般保持2:1的比例較好。

mask_value就是對應的類別id,去mask中找出畫素值等於類別id的點座標,然後隨機選取點就行。此處還會根據mask算外接矩形(實際上直接讀取圖片對應的xml標籤檔案也行),用於後續基於box prompt的finetune。

def get_random_prompts(mask, mask_value, foreground_nums=1, background_nums=1):
    # Find the indices (coordinates) of the foreground pixels
    foreground_indices = np.argwhere(mask == mask_value)
    ymin, xmin= foreground_indices.min(axis=0)
    ymax, xmax = foreground_indices.max(axis=0)
    bbox = np.array([xmin, ymin, xmax, ymax])
    if foreground_indices.shape[0] < foreground_nums:
        foreground_nums = foreground_indices.shape[0]
        background_nums = int(0.5 * foreground_indices.shape[0])
    background_indices = np.argwhere(mask != mask_value)

    ## random select
    foreground_points = foreground_indices[
        np.random.choice(foreground_indices.shape[0], foreground_nums, replace=False)]
    background_points = background_indices[
        np.random.choice(background_indices.shape[0], background_nums, replace=False)]

    ## 座標點是(y,x),輸入給網路應該是(x,y),需要翻一下順序
    foreground_points = foreground_points[:, ::-1]
    background_points = background_points[:, ::-1]

    return (foreground_points, background_points), bbox

得到的prompt是一些點的座標,座標的x,y是基於原圖的,但進入SAM的圖片會resized到1024x1024,所以點座標也需要resize,對應如下程式碼

    all_points = np.concatenate((foreground_points, background_points), axis=0)
    all_points = np.array(all_points)
    point_labels = np.array([1] * foreground_points.shape[0] + [0] * background_points.shape[0], dtype=int)
    ## image resized to 1024, points also
    all_points = model_transform.apply_coords(all_points, original_size)

    all_points = torch.as_tensor(all_points, dtype=torch.float, device=device)
    point_labels = torch.as_tensor(point_labels, dtype=torch.float, device=device)
    all_points, point_labels = all_points[None, :, :], point_labels[None, :]
    points = (all_points, point_labels)

    if not box_prompt:
        box_torch=None
    else:
        ## preprocess bbox
        box = model_transform.apply_boxes(bbox, original_size)
        box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
        box_torch = box_torch[None, :]

微調程式碼中可以指定基於哪種prompt進行微調,如果是point和box同時都開,會按一定機率捨棄point或box以取得更好的泛化性(不然推理時只有point或只有box作為prompt效果可能不太好)。最後經過prompt_encoder得到sparse_embeddings, dense_embeddings。

    ## if both, random drop one for better generalization ability
    if point_box and np.random.random()<0.5:
        if np.random.random()<0.25:
            points = None
        elif np.random.random()>0.75:
            box_torch = None
    ## freeze prompt encoder
    with torch.no_grad():
        sparse_embeddings, dense_embeddings = sam.prompt_encoder(
            points = points,
            boxes = box_torch,
            # masks=mask_predictions,
            masks=None,
        )

Mask預測

mask decoder這部分不需要凍結,直接呼叫mask_decoder推理就行,這裡進行了兩次mask預測,第一次先預測3個層級的mask然後選出得分最高的一個,將這個mask作為一個mask prompt,並與point prompt、box_prompt一起丟進prompt_encoder得到新的sparse_embeddings, dense_embeddings,再進行第二次mask預測,這次只預測一個mask。相當於先得到粗糙的mask,然後再精修。最後經過後處理nms等得到和原圖大小一樣的預測mask,一個物體對應一張mask,將多個mask疊起來就得到這張圖所有的預測結果predictions。

    ## predicted masks, three level
    mask_predictions, scores = sam.mask_decoder(
        image_embeddings=image_embedding.to(device),
        image_pe=sam.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=True,
    )
    # Choose the model's best mask
    mask_input = mask_predictions[:, torch.argmax(scores),...].unsqueeze(1)
    with torch.no_grad():
        sparse_embeddings, dense_embeddings = sam.prompt_encoder(
            points=points,
            boxes=box_torch,
            masks=mask_input,
        )
        ## predict a better mask, only one mask
        mask_predictions, scores = sam.mask_decoder(
            image_embeddings=image_embedding.to(device),
            image_pe=sam.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        best_mask = sam.postprocess_masks(mask_predictions, input_size, original_size)
        predictions.append(best_mask)

Loss計算

程式碼中loss用的是BCELoss加DiceLoss,需要gt和pred的shape一致,都為BxCxHxW的形式,pred是經過sigmoid後的值。

因此需要將gt轉為one-hot的形式,即將(batch_size, 1, h, w)轉為(batch_size, c, h, w),c是gt_classes中有的類別個數,即圖片中有多少個例項類別。

def mask2one_hot(label, gt_classes):
    """
    label: 標籤影像 # (batch_size, 1, h, w)
    num_classes: 分類類別數
    """
    current_label = label.squeeze(1) # (batch_size, 1, h, w) ---> (batch_size, h, w)
    batch_size, h, w = current_label.shape[0], current_label.shape[1], current_label.shape[2]
    one_hots = []
    for cls in gt_classes:
        if isinstance(cls, torch.Tensor):
            cls = cls.item()
        tmplate = torch.zeros(batch_size, h, w)  # (batch_size, h, w)
        tmplate[current_label == cls] = 1
        tmplate = tmplate.view(batch_size, 1, h, w)  # (batch_size, h, w) --> (batch_size, 1, h, w)
        one_hots.append(tmplate)
    onehot = torch.cat(one_hots, dim=1)
    return onehot

另外BCE接受的pred值是logit形式,所以需要將predictions用sigmoid處理,後續loss計算對應如下程式碼

    gts = torch.from_numpy(gts).unsqueeze(1) ## BxHxW ---> Bx1xHxW
    gts_onehot = mask2one_hot(gts, valid_classes)
    gts_onehot = gts_onehot.to(device)

    predictions = torch.sigmoid(predictions)
    # #loss = seg_loss(predictions, gts_onehot)
    loss = BCEseg(predictions, gts_onehot)
    loss_dice = soft_dice_loss(predictions, gts_onehot, smooth = 1e-5, activation='none')
    loss = loss + loss_dice

權重儲存

optimizer預設是AdamW,scheduler是CosineAnnealingLR,這些可以自己修改。最後儲存的權重只儲存當前loss最小的,而且只儲存decoder部分的權重,可以按需修改

if epoch_loss < best_loss:
    best_loss = epoch_loss
    mask_decoder_weighs = sam.mask_decoder.state_dict()
    mask_decoder_weighs = {f"mask_decoder.{k}": v for k,v in mask_decoder_weighs.items() }
    torch.save(mask_decoder_weighs, os.path.join(save_dir, f'sam_decoder_fintune_{str(epoch+1)}_pointbox_monai.pth'))
    print("Saving weights, epoch: ", epoch+1)

全系列完,感謝閱讀...

相關文章