yolov5 head原始碼深度解析
在上次的文章中我們解析了backbone網路的構建原始碼,在這篇中我們針對model.py剩餘的部分進行debug解析。
今天我們繼續對model.py裡的Detect類進行解析,這部分對應yolov5的檢查頭部分。
detect類在model.py裡,這部分程式碼如下:
class Detect(nn.Module):
stride = None # strides computed during build
export = False # onnx export
def __init__(self, nc=80, anchors=(), ch=()): # detection layer
super(Detect, self).__init__()
self.nc = nc # number of classes
self.no = nc + 5 # number of outputs per anchor 85 for coco
self.nl = len(anchors) # number of detection layers 3
self.na = len(anchors[0]) // 2 # number of anchors 3
self.grid = [torch.zeros(1)] * self.nl # init grid
a = torch.tensor(anchors).float().view(self.nl, -1, 2)
self.register_buffer('anchors', a) # shape(nl,na,2)
self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv 128=>255/256=>255/512=>255
def forward(self, x):
# x = x.copy() # for profiling
z = [] # inference output
self.training |= self.export
for i in range(self.nl):
x[i] = self.m[i](x[i]) # conv
bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
if not self.training: # inference
if self.grid[i].shape[2:4] != x[i].shape[2:4]:
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
y = x[i].sigmoid()
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
z.append(y.view(bs, -1, self.no))
return x if self.training else (torch.cat(z, 1), x)
@staticmethod
def _make_grid(nx=20, ny=20):
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
我們首先來看這個類的__init__()函式:
def __init__(self, nc=80, anchors=(), ch=()): # detection layer
super(Detect, self).__init__()
self.nc = nc # number of classes
self.no = nc + 5 # number of outputs per anchor 85 for coco
self.nl = len(anchors) # number of detection layers 3
self.na = len(anchors[0]) // 2 # number of anchors 3
self.grid = [torch.zeros(1)] * self.nl # init grid
a = torch.tensor(anchors).float().view(self.nl, -1, 2)
self.register_buffer('anchors', a) # shape(nl,na,2)
self.register_buffer('anchor_grid', a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv 128=>255/256=>255/512=>255
yolov5的檢測頭仍為FPN接面構,所以self.m為3個輸出卷積。這三個輸出卷積模組的channel變化分別為128=>255|256=>255|512=>255。
self.no為每個anchor位置的輸出channel維度,每個位置都預測80個類(coco)+ 4個位置座標xywh + 1個confidence score。所以輸出channel為85。每個尺度下有3個anchor位置,所以輸出85*3=255個channel。
下面我們再來看下head部分的forward()函式:
def forward(self, x):
# x = x.copy() # for profiling
z = [] # inference output
self.training |= self.export
for i in range(self.nl):
x[i] = self.m[i](x[i]) # conv
bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
if not self.training: # inference
if self.grid[i].shape[2:4] != x[i].shape[2:4]:
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
y = x[i].sigmoid()
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
z.append(y.view(bs, -1, self.no))
return x if self.training else (torch.cat(z, 1), x)
x是一個列表的形式,分別對應著3個head的輸入。它們的shape分別為:
[B, 128, 32, 32]
[B, 256, 16, 16]
[B, 512, 8, 8]
三個輸入先後被送入了3個卷積,得到輸出結果。
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
這裡將x進行變換從:
x[0]:(bs,255,32,32) => x(bs,3,32,32,85)
x[1]:(bs,255,32,32) => x(bs,3,16,16,85)
x[2]:(bs,255,32,32) => x(bs,3,8,8,85)
if not self.training: # inference
if self.grid[i].shape[2:4] != x[i].shape[2:4]:
self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
def _make_grid(nx=20, ny=20):
yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
這裡的_make_grid()函式是準備好格點。所有的預測的單位長度都是基於grid層面的而不是原圖。注意每一層的grid的尺寸都是不一樣的,和每一層輸出的尺寸w,h是一樣的。
y = x[i].sigmoid()
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
z.append(y.view(bs, -1, self.no))
這裡是inference的核心程式碼,我們要好好剖析一下,相比於yolov3,yolov5有一些變化:
y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i]
這裡可以明顯發現box center的x,y的預測被乘以2並減去了0.5,所以這裡的值域從yolov3裡的(0,1)注意是開區間,變成了(-0.5, 1.5)。
這樣改的原因目前還未知,從表面理解是可以跨半個格點預測了,這樣應該能提高一些召回。當然還有一個好處就是也解決了yolov3中因為sigmoid開區間而導致中心無法到達邊界處的問題。這裡是我分析的觀點,如果讀者有其他的思路歡迎留言點撥。
y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]
這裡是預測boundingbox的wh。先回顧下yolov3裡的預測:
pred_boxes[..., 0] = x.data + self.grid_x
pred_boxes[..., 1] = y.data + self.grid_y
pred_boxes[..., 2] = torch.exp(w.data) * self.anchor_w
pred_boxes[..., 3] = torch.exp(h.data) * self.anchor_h
是一個基於框的w,h的e指數函式。而在yolov5中這裡變成了:(2*w_pred/h_pred) ^2。
值域從原來的(1,e)變成了(0,4)。這裡我的理解是這個預測的框範圍變得更大了,不僅可以預測到4倍以內的大物體,而且可以預測到比anchor小的boundingbox。和上面一樣,這裡是我分析的觀點,如果讀者有其他的思路歡迎留言點撥。
到這裡我們就分析完了Detect類裡面的所有程式碼。下面我回到Model類裡面,最後分析它的前向傳播過程,這裡有兩個函式forward()和forward_once()兩個函式:
def forward_once(self, x, profile=False):
y, dt = [], [] # outputs
for m in self.model:
if m.f != -1: # if not from previous layer
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
if profile:
o = thop.profile(m, inputs=(x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPS
t = time_synchronized()
for _ in range(10):
_ = m(x)
dt.append((time_synchronized() - t) * 100)
print('%10.1f%10.0f%10.1fms %-40s' % (o, m.np, dt[-1], m.type))
x = m(x) # run
y.append(x if m.i in self.save else None) # save output
if profile:
print('%.1fms total' % sum(dt))
return x
self.foward_once()就是前向執行一次model裡的所有module,得到結果。profile引數開啟會記錄每個模組的平均執行時長和flops用於分析模型的瓶頸,提高模型的執行速度和降低視訊記憶體佔用。
def forward(self, x, augment=False, profile=False):
if augment: 大連婦科醫院哪個好
img_size = x.shape[-2:] # height, width
s = [1, 0.83, 0.67] # scales
f = [None, 3, None] # flips (2-ud, 3-lr)
y = [] # outputs
for si, fi in zip(s, f):
xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
yi = self.forward_once(xi)[0] # forward
# cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
yi[..., :4] /= si # de-scale
if fi == 2:
yi[..., 1] = img_size[0] - 1 - yi[..., 1] # de-flip ud
elif fi == 3:
yi[..., 0] = img_size[1] - 1 - yi[..., 0] # de-flip lr
y.append(yi)
return torch.cat(y, 1), None # augmented inference, train
else:
return self.forward_once(x, profile) # single-scale inference, train
self.forward()函式里面augment可以理解為控制TTA,如果開啟會對圖片進行scale和flip。預設是關閉的。
def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
# scales img(bs,3,y,x) by ratio constrained to gs-multiple
if ratio == 1.0:
return img
else:
h, w = img.shape[2:]
s = (int(h * ratio), int(w * ratio)) # new size
img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
if not same_shape: # pad/crop img
h, w = [math.ceil(x * ratio / gs) * gs for x in (h, w)]
return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
scale_img的原始碼如上,就是透過普通的雙線性插值實現,根據ratio來控制圖片的縮放比例,最後透過pad 0補齊到原圖的尺寸。
來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69945560/viewspace-2761956/,如需轉載,請註明出處,否則將追究法律責任。
相關文章
- SnapHelper原始碼深度解析原始碼
- Vuex 原始碼深度解析Vue原始碼
- OkHttp原始碼深度解析HTTP原始碼
- React Hooks原始碼深度解析ReactHook原始碼
- Netty原始碼深度解析(九)-編碼Netty原始碼
- Spring原始碼深度解析(郝佳)-學習-原始碼解析-Spring MVCSpring原始碼MVC
- RecyclerView用法和原始碼深度解析View原始碼
- KubeSphere 後端原始碼深度解析後端原始碼
- 深度解析 create-react-app 原始碼ReactAPP原始碼
- spring原始碼深度解析— IOC 之 bean 建立Spring原始碼Bean
- spring原始碼深度解析— IOC 之 自定義標籤解析Spring原始碼
- Spring原始碼深度解析(郝佳)-學習-原始碼解析-基於註解注入(二)Spring原始碼
- spring原始碼深度解析— IOC 之 屬性填充Spring原始碼
- 原始碼深度解析 Handler 機制及應用原始碼
- spring原始碼深度解析— IOC 之 預設標籤解析(上)Spring原始碼
- spring原始碼深度解析— IOC 之 預設標籤解析(下)Spring原始碼
- JVM CPU Profiler技術原理及原始碼深度解析JVM原始碼
- spring原始碼深度解析— IOC 之 容器的基本實現Spring原始碼
- spring原始碼深度解析— IOC 之 bean 的初始化Spring原始碼Bean
- 深度解析Spring Cloud Ribbon的實現原始碼及原理SpringCloud原始碼
- spring原始碼深度解析— IOC 之 迴圈依賴處理Spring原始碼
- spring原始碼深度解析— IOC 之 開啟 bean 的載入Spring原始碼Bean
- Spring5原始碼深度解析(一)之理解Configuration註解Spring原始碼
- Java Timer原始碼解析(定時器原始碼解析)Java原始碼定時器
- 【原始碼解析】- ArrayList原始碼解析,絕對詳細原始碼
- 深入原始碼,深度解析Java 執行緒池的實現原理原始碼Java執行緒
- [原始碼解析] 深度學習分散式訓練框架 horovod (8) --- on spark原始碼深度學習分散式框架Spark
- [原始碼解析] 深度學習分散式訓練框架 horovod (7) --- DistributedOptimizer原始碼深度學習分散式框架
- ReactNative原始碼解析-初識原始碼React原始碼
- Toast原始碼深度分析AST原始碼
- 深度解剖dubbo原始碼原始碼
- Koa 原始碼解析原始碼
- Koa原始碼解析原始碼
- RxPermission原始碼解析原始碼
- Express原始碼解析Express原始碼
- redux原始碼解析Redux原始碼
- CopyOnWriteArrayList原始碼解析原始碼
- LeakCanary原始碼解析原始碼