深度學習: Non-Maximum Supression (非極大值抑制)

Yu.Mu發表於2018-01-05

NMS (Non-Maximum Supression)

NMS來 選取那些鄰域裡分數最高的視窗,同時抑制那些分數低的視窗

論文可見,在 Faster R-CNN 中,NMS演算法被放在RPN網路的末段,用於 協助 剔除低得分的box
這裡寫圖片描述

Note:

  • 所有NMS演算法都是在每個類內分別獨立進行NMS。

  • NMS演算法略顯粗暴,直接將和得分最大的box的IOU大於某個閾值的box的 得分置零 或者 丟棄box 。後續的改良版——Soft NMS,改用稍低一點的分數 ( score*(1-iou) ) 來代替原有的分數,而不是直接置零。

在 Faster R-CNN 中位於下圖中的 綠框位置:
這裡寫圖片描述

Test

我經過動手實驗,成功復現了 NMS 的處理過程。

未經過NMS之前的bbox分佈:
這裡寫圖片描述

經過NMS篩選後的保留的bbox分佈:
這裡寫圖片描述

Code

效果圖所對應的原始碼如下:

# coding=utf-8

import numpy as np
import cv2


def nms(bboxs, thresh):

    x1, y1, x2, y2, scores = list([bboxs[:, i] for i in range(len(bboxs[0]))])
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)    # 每塊 bbox 面積

    order = scores.argsort()[::-1]    # 所有 bbox 根據置信度 進行 index排序
    keep = []    # 篩選後 要留下來的 bbox
    while order.size > 0:
        i = order[0]    # 置信度最高的bbox 的 index
        keep.append(i)    # 先 留下 剩下的bbox 中 置信度最高的bbox 的index

        # 選擇大於x1,y1和小於x2,y2的區域
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        # 當前bbox 與 每個剩餘的 bbox 分別的 重疊區域
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h    # [ 12272.  11716.  13566.      0.      0.]    [ 8181.]

        # 交叉區域面積 / (bbox + 某區域面積 - 交叉區域面積)
        overlap = inter / (areas[i] + areas[order[1:]] - inter)    # [ 0.73362028  0.62279396  0.52134814  0.          0.        ]   [ 0.61832061]

        # 保留交集小於一定閾值的boundingbox
        idxs = np.where(overlap <= thresh)[0]    # [3 4]    []

        order = order[idxs + 1]    # [0 3]    []

    return keep


def draw_bbox(bboxs, pic_name):
    pic = np.zeros((850, 850), np.uint8)
    for bbox in bboxs:
        x1, y1, x2, y2 = map(int, bbox[:-1])
        pic = cv2.rectangle(pic, (x1, y1), (x2, y2), (255, 0, 0), 2)
    cv2.imwrite('./{}.jpg'.format(pic_name), pic)


if __name__ == "__main__":
    bboxs = np.array([
        [720, 690, 820, 800, 0.5],
        [204, 102, 358, 250, 0.5],
        [257, 118, 380, 250, 0.8],
        [700, 700, 800, 800, 0.4],
        [280, 135, 400, 250, 0.7],
        [255, 118, 360, 235, 0.7]])
    thresh = 0.3
    draw_bbox(bboxs, "Before_NMS")
    keep = nms(bboxs, thresh)    # [2, 0]
    draw_bbox(bboxs[keep], "After_NMS")

演算法缺陷

NMS演算法的 核心思想 是:在 假設 例項之間均為不重疊或低重疊的 前提 下,去除高重疊bbox。

該核心思想也導致了NMS不可避免地會對 高重疊的例項 產生 漏檢

相關文章