Pytorch版Faster R-CNN 原始碼分析+方法流程詳解——訓練篇
Pytorch版Faster R-CNN 原始碼分析+方法流程詳解——訓練篇
繼demo篇之後,繼續分析faster rcnn的訓練過程,繼續對Pytorch版的原始碼進行分析。
1、參考文章/部落格/論文/原始碼
論文:https://arxiv.org/pdf/1506.01497.pdf
原始碼:https://github.com/jwyang/faster-rcnn.pytorch
2、環境配置
見demo篇faster rcnn檢測階段原始碼分析
3、資料集及模型
以VGG16作為Backbone網路模型,資料集採用PASCAL VOC2007。
5、faster rcnn方法訓練流程分析(訓練階段)
指令碼檔案:trainval_net.py
1、def parse_args():
引數傳遞函式,定義網路訓練所需的相關引數,方便使用命令傳遞引數。
2、資料迭代器sampler類的定義及實現
該類繼承Sampler類,在torch.utils.data.DataLoader構建資料集迭代器時使用。
sampler類重寫了__iter__與__len__函式,將訓練資料隨機打亂,根據batch size的大小返回索引迭代器。
3、def combined_roidb(): 訓練資料組織方法。
檔案:lib/roi_data_layer/roidb.py
四個返回值分別表示的含義
**imdb:**表示根據voc_2007_trainval定義的pascal voc類物件,該類繼承python的imdb類。
**roidb:**表示影像相關的資訊,dict型別,包含boxes、gt_classes、gt_overlaps、image_id、image_path、width、height等影像本身和目標檢測框標註的相關資訊。
**ratio_list:**表示根據影像的寬高比排序後的ratio_list。
**ratio_index:**表示根據影像寬高比排序後的list對應的原始影像索引image_index。
在combined_roidb()方法中,通過呼叫get_roidb()方法獲取影像相應的roidb資訊。
檔案:lib/datasets/imdb.py
呼叫append_flipped_images()方法對影像roidb中的box進行水平翻轉,注意此時僅對box進行翻轉,其中gt_overlaps和gt_classes與原始roidb相同,將dlipped翻轉標記設定為True表示翻轉後的影像。image_index * 2表示將圖片索引資訊複製一遍。
檔案:lib/roi_data_layer/roidb.py
prepare_roidb()方法對每張圖片的roidb進行資訊擴充,新增id,路徑,寬高,box類別資訊等,將所有影像的尺寸序列化儲存為pkl檔案,方便再次執行時讀取。
檔案:lib/roi_data_layer/roidb.py
filter_roidb()方法對每張影像的box數量進行檢查,在訓練階段剔除不含box目標檢測框的影像及其roidb。
檔案:lib/roi_data_layer/roidb.py
rank_roidb_ratio()方法檢查影像的寬高比,將寬高比大於2或者小於0.5的影像的裁剪標誌設定為True,並將寬高比更新為最大值或最小值,對訓練影像進行裁剪時需要。將ratio從小到大排序,返回排序索引及排序後的ratio_list。
4、構建DataLoader可讀的資料集
檔案:lib/roi_data_layer/roibatchLoader.py
檔案:lib/roi_data_layer/minibatch.py
定義roibatchLoader子類,該類繼承自Pytorch的data.Dataset類,重寫了__getitem__(self, index)函式和__len__(self)函式。
在該類的建構函式中,對排序後的影像寬高比根據batch_size的大小進行分段,保證每個batchsize中的影像具有相同大小的寬高比,對於faster rcnn訓練PASCAL VOC資料集來說,一個batchsize只有一幅影像。
在__getitem__函式中,根據索引取出對應的roidb,並根據取出的roidb呼叫get_minibatch()方法構建blobs,此處的blobs為dict型別,包含data、gt_boxes、im_info、img_id四個屬性。
data:表示影像本身畫素資訊,[1,w,h,c]結構,RGB->BGR,flipped為True的影像進行水平翻轉。
gt_boxes:n×5結構的矩陣,n表示一幅影像中box的個數,前4列表示表示box的左上右下角座標,最後一個列表示box的類別索引。(x1, y1, x2, y2, cls)
im_info:陣列型別,1×3陣列,分別儲存了影像的寬度、高度、尺度縮放比例。
img_id:等同於roidb中的img_id屬性,表示影像的索引id。
檔案:lib/roi_data_layer/roibatchLoader.py
根據返回的blobs資訊對需要進行裁剪的影像進行裁剪,如果影像的ratio大於2或者小於0.5則需要裁剪,裁剪後box的橫座標更新可能變為負值,或超出裁剪後的寬度範圍,將超出邊界的box做加緊處理,box的橫座標將變為邊界值。
根據影像的ratio值對寬度或者高度向上取整,對影像進行相應的padding。
檢查boundig box。
將data維度進行轉換,將通道數維度提前,(3, data_height, data_width)
返回值包括padding_data, im_info, gt_boxes_padding, num_boxes
padding_data:表示邊緣取整填充後的影像畫素值。
im_info:表示裁剪填充後影像的寬高及縮放尺度。
gt_boxes_padding:表示bounding box資訊。
num_boxes:表示bounding box的數量。
相關文章
- PyTorch 模型訓練實⽤教程(程式碼訓練步驟講解)PyTorch模型
- 【LLM訓練系列】NanoGPT原始碼詳解和中文GPT訓練實踐NaNGPT原始碼
- 目標檢測入門系列手冊四:Faster R-CNN 訓練教程ASTCNN
- 使用Pytorch訓練分類器詳解(附python演練)PyTorchPython
- Faster R-CNNASTCNN
- 程式碼實踐——Faster R-CNNASTCNN
- [原始碼解析] PyTorch 分散式之彈性訓練(2)---啟動&單節點流程原始碼PyTorch分散式
- 2、PyTorch訓練YOLOv11—訓練篇(detect)—Windows系統PyTorchYOLOv1Windows
- [原始碼解析] PyTorch 分散式之彈性訓練(3)---代理原始碼PyTorch分散式
- [專案實戰]訓練retinanet(pytorch版)NaNPyTorch
- 詳解Java 容器(第③篇)——容器原始碼分析 - ListJava原始碼
- 詳解Java 容器(第④篇)——容器原始碼分析 - MapJava原始碼
- shiro認證流程原始碼分析--練氣初期原始碼
- [原始碼解析] PyTorch 分散式之彈性訓練(5)---Rendezvous 引擎原始碼PyTorch分散式
- [原始碼分析] Facebook如何訓練超大模型---(4)原始碼大模型
- [原始碼分析] Facebook如何訓練超大模型--- (5)原始碼大模型
- [原始碼分析] Facebook如何訓練超大模型---(1)原始碼大模型
- [原始碼分析] Facebook如何訓練超大模型 --- (2)原始碼大模型
- [原始碼分析] Facebook如何訓練超大模型 --- (3)原始碼大模型
- SQLMAP原始碼分析Part1:流程篇SQL原始碼
- pytorch-模型儲存與載入自己訓練的模型詳解PyTorch模型
- [原始碼解析] PyTorch 分散式之彈性訓練(1) --- 總體思路原始碼PyTorch分散式
- SpringSecurity認證流程原始碼詳解SpringGse原始碼
- 【zookeeper原始碼】啟動流程詳解原始碼
- shiro登陸流程原始碼詳解原始碼
- LinkedList詳解-原始碼分析原始碼
- ArrayList詳解-原始碼分析原始碼
- pytorch指定GPU訓練PyTorchGPU
- Pytorch分散式訓練PyTorch分散式
- 探索 YOLO v3 原始碼 - 第1篇 訓練YOLO原始碼
- Faster R-CNN演算法解析ASTCNN演算法
- 詳解Java 容器(第⑤篇)——容器原始碼分析 - 併發容器Java原始碼
- [原始碼解析] PyTorch 分散式之彈性訓練(6)---監控/容錯原始碼PyTorch分散式
- [原始碼解析] PyTorch 分散式之彈性訓練(7)---節點變化原始碼PyTorch分散式
- ArrayMap詳解及原始碼分析原始碼
- LeakCanary詳解與原始碼分析原始碼
- EventBus詳解及原始碼分析原始碼
- MapReduce 詳解與原始碼分析原始碼