[專案實戰]訓練retinanet(pytorch版)

wuzeyuan發表於2019-05-08

採用github上star比較高的一個開源實現https://github.com/yhenon/pytorch-retinanet

在anaconda中新建了一個環境,因為一開始並沒有新建環境,在原有的環境裡,遇到了pytorch,numpy等版本問題,尤其是一開始用0.1.2版的pytorch,有一個TH路徑不對,導致編譯錯誤,這是我最討厭的錯誤,遇到編譯錯誤往往一臉懵逼,如果NMS部分不用編譯,直接用python實現就好了,當然那樣速度可能會慢很多.

先記錄下我的各個包的版本

cffi                      1.12.2

cudatoolkit               9.0

cudnn                     7.3.1

Cython                    0.29.7

matplotlib                3.0.3

numpy                     1.15.4

pytorch                   0.4.0

torchvision               0.2.1

當然了,其他版本也可以,但是這個版本一定是可行的.

然後準備訓練coco,首先需要下載coco,這裡採用wget下載,幾個壓縮包的地址連結https://blog.csdn.net/daniaokuye/article/details/78699138#commentsedit

採用wget下載意外地很慢,於是採用迅雷,意外地很快

網路開始訓練,沒有采用預訓練權重,既沒有用coco訓練好的,也沒有用resnet的預訓練權重(下載起來太慢了)

python train.py --dataset coco --coco_path ../coco --depth 50

訓練截圖

採用2個圖片作為一個batch訓練,GPU佔用

batchsize為2,訓練一個epoch大約6個小時,按照程式碼中預設的100個epoch,恐怕得600個小時,一個月了

幸好原始碼中提供了訓練好的coco權重,可以為我們所用,那就先看一下訓練好的效果,呼叫視覺化程式碼

python visualize.py --dataset coco --coco_path ../coco --model ./coco_resnet_50_map_0_335_state_dict.pt

效果

 

檢測效果還是ok的~

這個版本的實現程式碼量在2000行左右,很適宜閱讀,尤其是與Mask R-CNN(matterport版,大概6000行)相比

視覺化程式碼稍加改造,就可以作為一個目標檢測器使用了,棒!

相關文章