深度學習高頻手撕程式碼

forrestr 發表於 2024-06-05

NMS（非極大值抑制，Non-Maximum Suppression）

def cal_iou(bbox1, bbox2):
    """Return the intersection-over-union (IoU) of two boxes.

    Boxes are ``[x1, y1, x2, y2]`` in inclusive pixel coordinates: a box
    covering columns ``x1..x2`` is ``x2 - x1 + 1`` pixels wide. The same
    ``+ 1`` convention is applied to the intersection and to both areas, so
    two identical boxes score exactly 1.0.

    Args:
        bbox1: first box as ``[x1, y1, x2, y2]``.
        bbox2: second box as ``[x1, y1, x2, y2]``.

    Returns:
        IoU as a float in ``[0, 1]``; 0.0 when the boxes do not overlap
        (including when either box covers no pixels, i.e. ``x2 < x1`` or
        ``y2 < y1``).
    """
    # Overlap extent along each axis: the smaller right/bottom edge minus the
    # larger left/top edge, +1 because both edges are inclusive.
    inter_w = min(bbox1[2], bbox2[2]) - max(bbox1[0], bbox2[0]) + 1
    inter_h = min(bbox1[3], bbox2[3]) - max(bbox1[1], bbox2[1]) + 1
    if inter_w <= 0 or inter_h <= 0:
        return 0.0
    inter_area = inter_w * inter_h
    area1 = (bbox1[2] - bbox1[0] + 1) * (bbox1[3] - bbox1[1] + 1)
    area2 = (bbox2[2] - bbox2[0] + 1) * (bbox2[3] - bbox2[1] + 1)
    # Reaching this point means inter_w, inter_h >= 1. Each box is at least as
    # wide and tall as the overlap, so both areas are >= 1 and the union
    # (>= area1) cannot be zero.
    return inter_area / (area1 + area2 - inter_area)


def nms(bboxes, scores, iou_thre=0.3):
    """Greedy non-maximum suppression.

    Boxes are visited from highest to lowest score. Ties keep their input
    order, because Python's sort is stable. Each box that has not itself
    been suppressed is kept. It then suppresses every later, still-active
    box whose IoU with it (computed by ``cal_iou``) is strictly greater
    than ``iou_thre``.

    Args:
        bboxes: sequence of boxes in the format accepted by ``cal_iou``.
        scores: one score per box. As with ``zip``, trailing entries of the
            longer of ``bboxes``/``scores`` are ignored.
        iou_thre: IoU above which a lower-scored box is suppressed.

    Returns:
        List of kept boxes (the same objects passed in), ordered by
        descending score.
    """
    candidates = sorted(zip(scores, bboxes), key=lambda item: item[0], reverse=True)
    # Track suppression in a separate flag list rather than overwriting the
    # score with a sentinel, so any real score value (e.g. -1) is handled.
    suppressed = [False] * len(candidates)
    kept = []
    for i, (_, box) in enumerate(candidates):
        # A box that was already suppressed must not suppress others;
        # otherwise a chain A->B->C would wrongly remove C.
        if suppressed[i]:
            continue
        kept.append(box)
        for j in range(i + 1, len(candidates)):
            if not suppressed[j] and cal_iou(box, candidates[j][1]) > iou_thre:
                suppressed[j] = True
    return kept

# Demo: the second box lies almost entirely inside the first, while the third
# only touches the first at a corner.
test_scores = [0.98, 0.3, 0.7]
test_bboxes = [
    [0, 0, 100, 100],
    [10, 10, 100, 100],
    [100, 100, 200, 200],
]
res = nms(test_bboxes, test_scores)
# Expected: [[0, 0, 100, 100], [100, 100, 200, 200]]
print(res)

相關文章