YOLOv3 anchors 的 k-means 聚類指令碼

HockerF發表於2020-11-18

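最近又回過頭來看了下經典的 YOLOv3 演算法，其中 anchors 的聚類用的是 k-means。下面是一個 Python 實現，大家可以參考討論。用法：`python <本指令碼> <datacfg> <num_of_clusters> <width> <height>`，其中 datacfg 為 darknet 的 .data 檔（需包含 `train = <圖片清單>` 一行），width/height 用於把歸一化的框寬高換算成像素（通常即網路輸入尺寸）；標註檔需為 darknet 格式（每行 `class x y w h`，座標歸一化到 (0, 1]），存放在與圖片所在目錄同級的 `labels` 目錄下。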

import os, sys 
import numpy as np 
import cv2 

# Command-line interface:
#   datacfg          data config file; must contain a "train = <list file>" line
#   num_of_clusters  number of anchors (k) to compute
#   width, height    multipliers that turn normalized label w/h into pixels
if len(sys.argv) < 5:
    # Interpolate the script name; without "% sys.argv[0]" the usage line
    # printed a literal "%s".
    print('usage:\n\t%s datacfg num_of_clusters width height' % sys.argv[0])
    sys.exit(-1)

datacfg = sys.argv[1]
try:
    num_of_clusters = int(sys.argv[2])
    width = int(sys.argv[3])
    height = int(sys.argv[4])
except ValueError:
    print('num_of_clusters, width and height must be integers')
    sys.exit(-1)

# k-means needs at least one cluster, and a non-positive scale would make
# every box size meaningless.
if num_of_clusters < 1 or width < 1 or height < 1:
    print('num_of_clusters, width and height must be positive')
    sys.exit(-1)

# Locate the training image list: the value of the first non-empty
# "train = <path>" entry in the data config file.
train_file = None
with open(datacfg) as fr:
    for cfg_line in fr:
        # Split on the first '=' only, so a path that itself contains '='
        # is kept intact (splitting on every '=' and rejoining with '' would
        # silently drop those characters).
        key, sep, value = cfg_line.partition('=')
        if key.strip() == 'train' and sep and value.strip():
            train_file = value.strip()
            break

if train_file is None:
    print('can\'t find train in data file')
    sys.exit(-1)

# Collect the pixel-space (w, h) of every valid ground-truth box.
# For each image path listed in train_file, the label file is looked up at
# <parent of the image's directory>/labels/<image name without ext>.txt.
# Each label line must have exactly 5 fields "class x y w h", with x, y, w, h
# in (0, 1]; w and h are scaled by width and height respectively.
number_of_boxes = 0
r_wh_arr = []

with open(train_file) as fr:
    for i, img_line in enumerate(fr):
        img_line = img_line.strip()
        if not img_line:
            # A blank line is not an image path; os.path.abspath('') would
            # resolve to the working directory and produce a bogus lookup.
            continue
        img_path = os.path.abspath(img_line)
        label_path = os.path.join(
            os.path.dirname(os.path.dirname(img_path)), 'labels',
            os.path.basename(os.path.splitext(img_path)[0]) + '.txt')
        if not os.path.exists(label_path):
            print('can\'t find %s' % (label_path))
            continue
        with open(label_path) as lab_fr:
            for lab_l in lab_fr:
                lt = lab_l.split()
                if not lt:
                    continue  # empty line: nothing to report
                if len(lt) != 5:
                    print('wrong label:', label_path)
                    continue
                try:
                    lb_x, lb_y, lb_w, lb_h = (float(v) for v in lt[1:])
                except ValueError:
                    # Non-numeric field: report it like any other malformed
                    # line instead of aborting the whole run.
                    print('wrong label:', label_path)
                    continue
                # Chained comparisons also reject NaN, which the
                # "x > 1 or x <= 0" form would let through.
                if not (0 < lb_x <= 1 and 0 < lb_y <= 1 and
                        0 < lb_w <= 1 and 0 < lb_h <= 1):
                    print('wrong label:', label_path)
                    continue
                number_of_boxes += 1
                r_wh_arr.append([lb_w * width, lb_h * height])
                # '\r' rewrites the same console line as a progress counter.
                print("\r loaded \t image: %d \t box: %d" % (i + 1, number_of_boxes), end='')

    print("\n all loaded. ")
    
# Cluster the collected (w, h) box sizes with OpenCV k-means and print the
# centers as anchors.
# NOTE: this uses plain Euclidean distance on (w, h); the YOLOv2 paper
# proposes a 1 - IoU distance instead, so results may differ from that.
if len(r_wh_arr) < num_of_clusters:
    # cv2.kmeans requires at least as many samples as clusters and otherwise
    # fails with an opaque assertion; report the problem explicitly.
    print('need at least %d boxes for %d clusters, got %d'
          % (num_of_clusters, num_of_clusters, len(r_wh_arr)))
    sys.exit(-1)

# Stop when the centers no longer move (epsilon = 0) or after 10000 iterations.
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10000, 0)
flags = cv2.KMEANS_PP_CENTERS  # k-means++ initial center seeding
data = np.asarray(r_wh_arr, dtype=np.float32)  # shape (N, 2): columns w, h
# attempts=10: OpenCV runs k-means 10 times and keeps the most compact result.
compactness, labels, centers = cv2.kmeans(data, num_of_clusters, None, criteria, 10, flags)

# Print anchors ordered by area as "w,h, w,h, ..." with no trailing
# separator, followed by a newline.
anchors = sorted(centers.tolist(), key=lambda wh: wh[0] * wh[1])
print('anchors = ' + ', '.join('%.4f,%.4f' % (w, h) for w, h in anchors))
執行結果範例（9 個聚類）：
 loaded          image: 792      box: 939
 all loaded. 
anchors = 21.2969,11.8091, 17.5623,17.6491, 16.9178,24.6314, 22.1751,20.8706, 32.5937,16.0246, 21.6062,30.9830, 17.6986,39.9335, 74.9313,148.7542, 150.5354,153.2271

相關文章