YOLOv3 的 TensorFlow 實現,GitHub 完整原始碼解析

红色石头發表於2019-01-31

來自華盛頓大學的 Joseph Redmon 和 Ali Farhadi 提出的YOLOv3 通過在 YOLO 中加入設計細節的變化,這個新模型在取得相當準確率的情況下實現了檢測速度的很大提升,一般它比 R-CNN 快 1000 倍、比 Fast R-CNN 快 100 倍。

這裡附上 YOLOv3 的論文地址:

https://pjreddie.com/media/files/papers/YOLOv3.pdf

本文的專案作者是 wizyoung,原 GitHub 專案地址為:

https://github.com/wizyoung/YOLOv3_TensorFlow

1. 介紹

本文將介紹 YOLO3 的完整 TensorFlow 實現。可在自己的資料集上進行完整的訓練和驗證操作,pipeline 完整。其特點包括:

  • 高效的 tf.data 管道
  • 權重轉換

  • GPU 提速,無限制

  • 完整的訓練管道

  • 使用 kMeans 演算法來選擇 anchor boxes

  • 多 GPU 同步訓練

2. 需求

  • tensorflow >= 1.8.0(不排除低版本也能工作)

  • opencv-python

3. 權重轉換

預訓練的 darknet 權重檔案可從下方連結下載:

https://pjreddie.com/media/files/yolov3.weights

把下載好後的檔案放在 ./data/darknet_weights/ 目錄下,執行下面的命令:

python convert_weight.py

然後,轉換後的 TensorFlow checkpoint 檔案將被儲存在 ./data/darknet_weights/ 目錄下。

4. 執行 demos

在 ./data/demo_data/ 目錄裡有一些影象和視訊的 demos 可以執行。

單個影象測試 demo:

python test_single_image.py ./data/demo_data/messi.jpg

視訊測試 demo:

python video_test.py ./data/demo_data/video.mp4

結果展示:

5. 執行速度

圖片尺寸為 416×416,論文實現與我的模型執行速度比較如下:

為什麼會這麼快呢?我們看一下論文中 ImageNet 分類情況:

6. 模型結構

為了更好地理解模型體系結構,可以參考下圖:

7. 訓練

首先是資料準備,分為三步。

1)annotation file

在 ./data/my_data/ 目錄下生成 train.txt/val.txt/test.txt 檔案。txt 檔案中一行表示一張圖片,形式為:圖片絕對路徑 + box_1 + box_2 + … + box_n。Box 的形式為:label_index + x_min + y_min + x_max + y_max,原始座標為圖片左上角。

例如:

xxx/xxx/1.jpg 0 453 369 473 391 1 588 245 608 268
xxx/xxx/2.jpg 1 466 403 485 422 2 793 300 809 320

注意:每個 txt 檔案最後一行為空白行。

2)class_names file

在 ./data/my_data/ 目錄下生成 data.names 檔案,每一行代表一個類別名稱。例如:

bird
person
bike

3)prior anchor file

使用 kMeans 演算法來選擇 anchor boxes:

python get_kmeans.py

然後,你將得到 9 個 anchors 和評價 IOU,把 anchors 儲存在 txt 檔案中。

準備完資料之後就可以開始訓練了。

使用 train.py 檔案,函式引數如下:

$ python train.py -h
usage: train.py [-h] [--train_file TRAIN_FILE] [--val_file VAL_FILE]
               [--restore_path RESTORE_PATH] 
               [--save_dir SAVE_DIR]
               [--log_dir LOG_DIR] 
               [--progress_log_path PROGRESS_LOG_PATH]
               [--anchor_path ANCHOR_PATH]
               [--class_name_path CLASS_NAME_PATH] [--batch_size BATCH_SIZE]
               [--img_size [IMG_SIZE [IMG_SIZE ...]]]
               [--total_epoches TOTAL_EPOCHES]
               [--train_evaluation_freq TRAIN_EVALUATION_FREQ]
               [--val_evaluation_freq VAL_EVALUATION_FREQ]
               [--save_freq SAVE_FREQ] [--num_threads NUM_THREADS]
               [--prefetech_buffer PREFETECH_BUFFER]
               [--optimizer_name OPTIMIZER_NAME]
               [--save_optimizer SAVE_OPTIMIZER]
               [--learning_rate_init LEARNING_RATE_INIT] [--lr_type LR_TYPE]
               [--lr_decay_freq LR_DECAY_FREQ]
               [--lr_decay_factor LR_DECAY_FACTOR]
               [--lr_lower_bound LR_LOWER_BOUND]
               [--restore_part [RESTORE_PART [RESTORE_PART ...]]]
               [--update_part [UPDATE_PART [UPDATE_PART ...]]]
               [--update_part [UPDATE_PART [UPDATE_PART ...]]]
               [--use_warm_up USE_WARM_UP] [--warm_up_lr WARM_UP_LR]
               [--warm_up_epoch WARM_UP_EPOCH]

8. 評價

使用 eval.py 來評估驗證集和測試集,函式引數如下:

$ python eval.py -h
usage: eval.py [-h] [--eval_file EVAL_FILE] [--restore_path RESTORE_PATH]
              [--anchor_path ANCHOR_PATH] 
              [--class_name_path CLASS_NAME_PATH]
              [--batch_size BATCH_SIZE]
              [--img_size [IMG_SIZE [IMG_SIZE ...]]]
              [--num_threads NUM_THREADS]
              [--prefetech_buffer PREFETECH_BUFFER]

函式返回 loss、召回率 recall、精準率 precision,如下所示:

recall: 0.927, precision: 0.945
total_loss: 0.210, loss_xy: 0.010, loss_wh: 0.025, loss_conf: 0.125, loss_class: 0.050

9. 其它技巧

訓練的時候可以嘗試使用下面這些技巧:

  • Data augmentation:使用 ./utils/data_utils.py 中的 data_augmentation 方法來增加資料。
  • 像 Gluon CV 一樣混合和 label 平滑。

  • 正則化技巧,例如 L2 正則化。

  • 多尺度訓練:你可以像原稿中的作者那樣定期改變輸入影象的尺度(即不同的輸入解析度)。

完整程式碼請見 GitHub:

https://github.com/wizyoung/YOLOv3_TensorFlow

參考文獻:

https://github.com/YunYang1994/tensorflow-yolov3

https://github.com/qqwweee/keras-yolo3

https://github.com/eriklindernoren/PyTorch-YOLOv3

https://github.com/pjreddie/darknet


相關文章