寫給程式設計師的機器學習入門 (十) - 物件識別 Faster-RCNN - 識別人臉位置與是否戴口罩

q303248153發表於2020-12-03

每次看到大資料人臉識別抓逃犯的新聞我都會感嘆技術發展的太快了,國家治安水平也越來越好了?。不過那種系統個人是沒辦法做出來的,今天我們只試著做個簡單的,怎麼根據圖片把沒有戴口罩的傢伙抓出來?。這篇會介紹實用性比較強的物件識別模型 Faster-RCNN,需要的基礎知識比較多,如果對機器學習和物件識別沒有基礎瞭解請看這個系列之前的文章

RCNN, Fast-RCNN 的弱點

我在上一篇文章介紹了物件識別使用的 RCNN, Fast-RCNN 模型,在這裡我簡單總結一下它們的缺點,Faster-RCNN 將會克服它們:

  • 選取區域使用的演算法是固定的,不參與學習
  • 選取區域的演算法本身消耗比較高 (搜尋選擇法)
  • 選取區域的演算法選出來的區域大部分都是重合的,並且只有很小一部分包含我們想要識別的物件
  • 區域範圍的精度比較低 (即使經過調整)
  • 判斷分類有時只能使用部分包含物件的區域 (例如選取區域的演算法給出左半張臉所在的區域,那麼就只能使用左半張臉判斷分類)

Faster-RCNN 概覽

Faster-RCNN 是 RCNN 和 Fast-RCNN 的進化版,最大的特徵是引入了區域生成網路 (RPN - Region Proposal Network),區域生成網路支援使用機器學習代替固定的演算法找出圖片中可能包含物件的區域,精度比固定的演算法要高很多,而且速度也變快了。

Faster-RCNN 的結構如下圖所示,分成了兩大部分,第一部分是區域生成網路,首先會把圖片劃分為多個小區域 (大小依賴於圖片大小和 CNN 網路結構,詳細會在後面說明),每個小區域都對應一個錨點 (Anchor),區域生成網路會判斷錨點所在的區域是否包含物件,與包含的物件的形狀 (例如只包含鼻子,就大約可以估計周圍的幾個區域是臉);第二部分是標籤分類網路,與上一篇文章介紹的 Fast-RCNN 基本上相同,會根據區域生成網路的輸出擷取特徵,並根據特徵判斷屬於什麼分類。

因為區域生成網路可以參與學習,我們可以定製一個只識別某幾種物件的網路,例如圖片中有人,狗,車,樹,房子的時候,固定的演算法可能會把他們全部提取出來,但區域生成網路經過訓練可以只提取人所在的區域,其他物件所在的區域都會當作背景處理,這樣區域生成網路輸出的區域數量將會少很多,而且包含物件的可能性會很高。

Faster-RCNN 另一個比較強大的特徵是會分兩步來識別區域是否包含物件與調整區域範圍,第一步在區域生成網路,第二步在標籤分類網路。舉一個通俗的例子,如果區域生成網路選取了某個包含了臉的左半部分的區域,它會判斷這個區域可能包含物件,並且要求區域範圍向右擴大一些,接下來標籤分類網路會擷取範圍擴大之後的區域,這個區域會同時包含臉的左半部分和右半部分,也就是擷取出來的特徵會包含更多的資訊,這時標籤分類網路可以使用特徵進一步判斷這張臉所屬的分類,如果範圍擴大以後發現這不是一張臉而是別的什麼東西那麼區域分類網路會輸出 "非物件" 的分類排除這個區域,如果判斷是臉那麼標籤分類網路會進一步的調整區域範圍,使得範圍更精準。而 Fast-RCNN 遇到同樣的情況只能根據臉的左半部分對應的特徵判斷分類,資訊量不足可能會導致結果不準確。這種做法使得 Faster-RCNN 的識別精度相對於之前的模型提升了很多。

接下來看看 Faster-RCNN 的實現細節吧,部分內容有一定難度?,如果覺得難以理解可以先跳過去後面再參考程式碼實現。

Faster-RCNN 的原始論文在這裡,有興趣的可以看看?。

Faster-RCNN 的實現

這篇給出的程式碼會使用 Pillow 類庫實現,代替之前的 opencv,所以部分處理相同的步驟也會給出新的程式碼例子。

縮放來源圖片

和 Fast-RCNN 一樣,Faster-RCNN 也會使用 CNN 模型針對整張圖片生成各個區域的特徵,所以我們需要縮放原圖片。(儘管 CNN 模型支援非固定大小的來源,但統一大小可以讓後續的處理更簡單,並且也可以批量處理大小不一樣的圖片。)

這篇文章會使用 Pillow 代替 opencv,縮放圖片的程式碼如下所示:

# 縮放圖片的大小
IMAGE_SIZE = (256, 192)

def calc_resize_parameters(sw, sh):
    """計算縮放圖片的引數"""
    sw_new, sh_new = sw, sh
    dw, dh = IMAGE_SIZE
    pad_w, pad_h = 0, 0
    if sw / sh < dw / dh:
        sw_new = int(dw / dh * sh)
        pad_w = (sw_new - sw) // 2 # 填充左右
    else:
        sh_new = int(dh / dw * sw)
        pad_h = (sh_new - sh) // 2 # 填充上下
    return sw_new, sh_new, pad_w, pad_h

def resize_image(img):
    """縮放圖片,比例不一致時填充"""
    sw, sh = img.size
    sw_new, sh_new, pad_w, pad_h = calc_resize_parameters(sw, sh)
    img_new = Image.new("RGB", (sw_new, sh_new))
    img_new.paste(img, (pad_w, pad_h))
    img_new = img_new.resize(IMAGE_SIZE)
    return img_new

計算區域特徵

與 Fast-RCNN 一樣,Faster-RCNN 計算區域特徵的時候也會使用除去全連線層的 CNN 模型,例如 Resnet-18 模型在原圖大小為 3,256,256 的時候 (3 代表 RGB 三通道)會輸出 512,32,32 的矩陣,通道數量變多,長寬變為原有的 1/8,也就是每個 8x8 的區域經過處理以後都對應 512 個特徵,如下圖所示:

對 CNN 模型不熟悉的可以複習這個系列的第八篇文章,詳細介紹了 Resnet-18 的結構與計算流程。

上一篇文章的 Fast-RCNN 例子改動了 Resnet 模型使得輸出的特徵矩陣長寬與原圖相同,以方便後面提取特徵 (ROI Pooling) 的處理,這篇將不需要這麼做,這篇使用的模型會輸出長寬為原有的 1/8 的特徵矩陣,但為了適應視訊記憶體比較低的機器會減少輸出的通道數量,具體請參考後面的實現程式碼。

定義錨點 (Anchor)

Faster-RCNN 的區域生成網路會基於錨點 (Anchor) 判斷某個區域是否包含物件,與物件相對於錨點的形狀。錨點對應的區域大小其實就是上面特徵矩陣中每個點對應的區域大小,如下圖所示:

上面的例子中應該有 32x32 個錨點,每個錨點對應 512,1,1 的值。

之後各個錨點對應的值會交給線性模型,判斷錨點所在的區域是否包含物件,如下圖所示 (為了簡化這張圖用了 4x4 個錨點,紅色的錨點代表包含物件):

當然的,錨點所在的區域與物件實際所在的區域範圍並不會完全一樣,錨點所在的區域可能只包含物件的左半部分,右半部分,或者中心部分,物件可能比錨點所在區域大很多,也可能比錨點所在區域小,只判斷錨點所在的區域是否包含物件並不夠準確。

為了解決這個問題,Faster-RCNN 的區域生成網路為每個錨點定義了幾個固定的形狀,形狀有兩個引數,一個是大小比例,一個是長寬比例,如下圖所示,對比上面的實際區域可以發現形狀 6 和形狀 7 的重疊率 (IOU) 是比較高的:

之後區域生成網路的線性模型可以分別判斷各個形狀是否包含物件:

再輸出各個形狀對應的範圍調整值,即可給出可能包含物件的區域。在上述的例子中,如果區域生成網路學習得當,形狀 6 和形狀 7 經過區域範圍調整以後應該會輸出很接近的區域。

需要注意的是,雖然錨點支援判斷比自己對應的區域更大的範圍是否包含物件,但判斷的依據只來源於自己對應的區域。舉例來說如果錨點對應的區域只包含鼻子,那麼它可以判斷形狀 7 可能包含物件,之後再交給標籤分類網路作進一步判斷。如果擴大以後發現其實不是人臉,而是別的什麼東西,那麼標籤分類網路將會輸出 "非物件" 標籤來排除這個區域,如前文介紹的一樣。

生成錨點的程式碼如下,每個錨點會對應 7 * 3 = 21 個形狀,span 代表 原圖片長寬 / CNN 模型輸出長寬

# 縮放圖片的大小
IMAGE_SIZE = (256, 192)
# 錨點對應區域的縮放比例列表
AnchorScales = (0.5, 1, 2, 3, 4, 5, 6)
# 錨點對應區域的長寬比例列表
AnchorAspects = ((1, 2), (1, 1), (2, 1))

def generate_anchors(span):
    """根據錨點和形狀生成錨點範圍列表"""
    w, h = IMAGE_SIZE
    anchors = []
    for x in range(0, w, span):
        for y in range(0, h, span):
            xcenter, ycenter = x + span / 2, y + span / 2
            for scale in AnchorScales:
                for ratio in AnchorAspects:
                    ww = span * scale * ratio[0]
                    hh = span * scale * ratio[1]
                    xx = xcenter - ww / 2
                    yy = ycenter - hh / 2
                    xx = max(int(xx), 0)
                    yy = max(int(yy), 0)
                    ww = min(int(ww), w - xx)
                    hh = min(int(hh), h - yy)
                    anchors.append((xx, yy, ww, hh))
    return anchors

Anchors = generate_anchors(8)

區域生成網路 (RPN)

看完上一段關於錨點的定義你應該對區域生成網路的工作方式有個大概的印象,這裡我再給出區域生成網路的具體實現架構,這個架構跟後面的程式碼例子相同。

區域生成網路的處理本身應該不需要多解釋了?,如果覺得難以理解請重新閱讀這一篇前面的部分和上一篇文章,特別是上一篇文章的以下部分:

  • 按重疊率 (IOU) 判斷每個區域是否包含物件
  • 調整區域範圍

計算區域範圍偏移的損失這裡使用了 Smooth L1 (上一篇是 MSELoss),具體的計算方法會在後面計算損失的部分介紹。

區域生成網路最終會輸出不定數量的可能包含物件的區域,接下來就是提取這些區域對應的特徵了,注意區域生成網路使用的特徵和標籤分類網路使用的特徵需要分開,很多文章或者實現介紹 Faster-RCNN 的時候都讓兩個網路使用相同的特徵,但經過我實測使用相同的特徵會在調整引數的時候發生干擾導致無法學習,與上一篇文章正負樣本的損失需要分開計算的原因一樣。部分實現的確使用了相同的特徵,但這些實現調整引數使用的 backward 是自己手寫的,可能這裡有什麼祕密吧?。

從區域提取特徵 - 仿射變換 (ROI Pooling - Affine Transformation)

上一篇介紹的 Fast-RCNN 在生成特徵的時候讓長寬與原圖片相同,所以 ROI 層提取特徵只需要使用 [] 操作符,但這一篇生成特徵的時候長寬變為了原來的 1/8,那麼需要怎樣提取特徵呢?

最簡單的方法是把座標和長寬除以 8 再使用 [] 操作符提取,然後使用 AdaptiveMaxPool 縮放到固定的大小。但這裡我要介紹一個更高階的方法,即仿射變換 (Affine Transformation),使用仿射變換可以非常高效的對圖片進行批量擷取、縮放與旋轉等操作。

仿射變換的原理是給原圖片和輸出圖片之間的畫素座標建立對應關係,一共有 6 個引數,其中 4 個引數用於給座標做矩陣乘法 (支援縮放與旋轉等變形操作),2 個引數用於做完矩陣乘法以後相加 (支援平移等操作),計算公式如下:

需要注意的是,仿射變換裡面不會直接計算座標的絕對值,而是把圖片的左上角當作 (-1, -1),右下角當作 (1, 1) 然後轉換座標到這個尺度裡面,再進行計算。

舉例來說,如果想把原圖片的中心部分放大兩倍到輸出圖片,可以使用以下引數:

0.5,   0, 0
  0, 0.5, 0

效果如下,如果你拿輸出圖片的四個角的座標結合上面的引數計算,可以得出原圖中心部分的範圍:

更多例子可以參考這篇文章,對理解仿射變換非常有幫助。

那麼從區域提取特徵的時候,應該使用怎樣的引數呢?計算引數的公式推導過程如下?:

使用 pytorch 實現如下,注意 pytorch 的仿射變換要求資料維度是 (C, H, W),而我們使用的資料維度是 (C, W, H),所以需要調換引數的位置,pooling_size 代表輸出圖片的大小,這樣仿射變換不僅可以擷取範圍還能幫我們縮放到指定的大小:

# 縮放圖片的大小
IMAGE_SIZE = (256, 192)

def roi_crop(features, rois, pooling_size):
    """根據區域擷取特徵,每次只能處理單張圖片"""
    width, height = IMAGE_SIZE
    theta = []
    results = []
    for roi in rois:
        x1, y1, w, h = roi
        x2, y2 = x1 + w, y1 + h
        theta = [[
            [
                (y2 - y1) / height,
                0,
                (y2 + y1) / height - 1
            ],
            [
                0,
                (x2 - x1) / width,
                (x2 + x1) / width - 1
            ]
        ]]
        theta_tensor = torch.tensor(theta)
        grid = nn.functional.affine_grid(
            theta_tensor,
            torch.Size((1, 1, pooling_size, pooling_size)),
            align_corners=False).to(device)
        result = nn.functional.grid_sample(
            features.unsqueeze(0), grid, align_corners=False)
        results.append(result)
    if not results:
        return None
    results = torch.cat(results, dim=0)
    return results

如果 pooling_size 為 7,那麼 results 的維度就是 範圍的數量, 7, 7

仿射變換本來是用在 STN 網路裡的,用於把旋轉變形以後的圖片還原,如果你有興趣可以參考這裡

根據特徵識別分類

接下來就是根據特徵識別分類了?,處理上與之前的 Fast-RCNN 基本上相同,除了 Faster-RCNN 在生成範圍調整引數的時候會針對每個分類分別生成,如果有 5 個分類,那麼就會有 5 * 4 = 20 個輸出,這會讓範圍調整變得更準確。

標籤分類網路的具體實現架構如下,最終會輸出包含物件的範圍與各個範圍對應的分類,整個 Faster-RCNN 的處理就到此為止了?。

有一點需要注意的是,標籤分類網路使用的分類需要額外包含一個 "非物件" 分類,例如原有分類列表為 [戴口罩人臉,不戴口罩人臉] 時,實際判斷分類列表應該為 [非人臉, 戴口罩人臉,不戴口罩人臉]。這是因為標籤分類網路的特徵擷取範圍比區域生成網路要大,範圍也更準確,標籤範圍網路可以根據更準確的特徵來排除那些區域生成網路以為是物件但實際不是物件的範圍。

計算損失

到此為止我們看到了以下的損失:

  • 區域生成網路判斷是否物件的損失
  • 區域生成網路的範圍調整引數的損失 (僅針對是物件的範圍計算)
  • 標籤分類網路判斷物件所屬分類的損失
  • 標籤分類網路的範圍調整引數的損失 (僅針對是物件,並且可能性最大的分類計算)

這些損失可以通過 + 合併,然後再通過 backward 反饋到各個網路的 CNN 模型與線性模型。需要注意的是,在批量訓練的時候因為各個圖片的輸出範圍數量不一樣,上面的損失會先根據各張圖片計算後再平均。你可能記得上一篇 Fast-RCNN 計算損失的時候需要根據正負樣本分別計算,這一篇不需要,Faster-RCNN 的區域生成網路輸出的範圍比較準確,很少會出現來源特徵相同但同時輸出 "是物件" 和 "非物件" 結果的情況。此外,如前文所提到的,區域生成網路與標籤分類網路應該使用不同的 CNN 模型生成不同的特徵,以避免通過損失調整模型引數時發生干擾。

計算範圍調整損失的時候用的是 Smooth L1 函式,這個函式我們之前沒有看到過,所以我再簡單介紹一下它的計算方法:

簡單的來說就是如果預測輸出和實際輸出之間的差距比較小的時候,反過來增加損失使得調整速度更快,因為區域範圍偏移需要讓預測輸出在數值上更接近實際輸出 (不像標籤分類可以只調整方向不管具體值),使用 Smooth L1 調整起來效果會更好。

合併結果區域

Faster-RCNN 可能會針對同一個物件輸出多個重合的範圍,但因為 Faster-RCNN 的精確度比較高,這些重合的範圍的重疊率應該也比較高,我們可以結合這些範圍得出結果範圍:

好了,對 Faster-RCNN 的介紹就到此為止了?,接下來我們看看程式碼實現吧。

使用 Faster-RCNN 識別人臉位置與是否戴口罩

這次的任務是識別圖片中人臉的位置,與判斷是否有正確佩戴口罩,一共有以下的分類:

  • 非人臉: other
  • 戴口罩: with_mask
  • 沒戴口罩: without_mask
  • 戴了口罩但姿勢不正確: mask_weared_incorrect

訓練使用的資料也是來源於 kaggle,下載需要註冊帳號但不用給錢:

https://www.kaggle.com/andrewmvd/face-mask-detection

例如下面這張圖片:

對應以下的標記 (xml 格式):

<annotation>
    <folder>images</folder>
    <filename>maksssksksss0.png</filename>
    <size>
        <width>512</width>
        <height>366</height>
        <depth>3</depth>
    </size>
    <segmented>0</segmented>
    <object>
        <name>without_mask</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <occluded>0</occluded>
        <difficult>0</difficult>
        <bndbox>
            <xmin>79</xmin>
            <ymin>105</ymin>
            <xmax>109</xmax>
            <ymax>142</ymax>
        </bndbox>
    </object>
    <object>
        <name>with_mask</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <occluded>0</occluded>
        <difficult>0</difficult>
        <bndbox>
            <xmin>185</xmin>
            <ymin>100</ymin>
            <xmax>226</xmax>
            <ymax>144</ymax>
        </bndbox>
    </object>
    <object>
        <name>without_mask</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <occluded>0</occluded>
        <difficult>0</difficult>
        <bndbox>
            <xmin>325</xmin>
            <ymin>90</ymin>
            <xmax>360</xmax>
            <ymax>141</ymax>
        </bndbox>
    </object>
</annotation>

使用 Faster-RCNN 訓練與識別的程式碼如下?:

import os
import sys
import torch
import gzip
import itertools
import random
import numpy
import math
import pandas
import json
from PIL import Image
from PIL import ImageDraw
from torch import nn
from matplotlib import pyplot
from collections import defaultdict
import xml.etree.cElementTree as ET
from collections import Counter

# 縮放圖片的大小
IMAGE_SIZE = (256, 192)
# 分析目標的圖片所在的資料夾
IMAGE_DIR = "./archive/images"
# 定義各個圖片中人臉區域與分類的 CSV 檔案
ANNOTATION_DIR = "./archive/annotations"
# 分類列表
CLASSES = [ "other", "with_mask", "without_mask", "mask_weared_incorrect" ]
CLASSES_MAPPING = { c: index for index, c in enumerate(CLASSES) }
# 判斷是否存在物件使用的區域重疊率的閾值
IOU_POSITIVE_THRESHOLD = 0.35
IOU_NEGATIVE_THRESHOLD = 0.10
# 判斷是否應該合併重疊區域的重疊率閾值
IOU_MERGE_THRESHOLD = 0.35

# 用於啟用 GPU 支援
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class BasicBlock(nn.Module):
    """ResNet 使用的基礎塊"""
    expansion = 1 # 定義這個塊的實際出通道是 channels_out 的幾倍,這裡的實現固定是一倍
    def __init__(self, channels_in, channels_out, stride):
        super().__init__()
        # 生成 3x3 的卷積層
        # 處理間隔 stride = 1 時,輸出的長寬會等於輸入的長寬,例如 (32-3+2)//1+1 == 32
        # 處理間隔 stride = 2 時,輸出的長寬會等於輸入的長寬的一半,例如 (32-3+2)//2+1 == 16
        # 此外 resnet 的 3x3 卷積層不使用偏移值 bias
        self.conv1 = nn.Sequential(
            nn.Conv2d(channels_in, channels_out, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(channels_out))
        # 再定義一個讓輸出和輸入維度相同的 3x3 卷積層
        self.conv2 = nn.Sequential(
            nn.Conv2d(channels_out, channels_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(channels_out))
        # 讓原始輸入和輸出相加的時候,需要維度一致,如果維度不一致則需要整合
        self.identity = nn.Sequential()
        if stride != 1 or channels_in != channels_out * self.expansion:
            self.identity = nn.Sequential(
                nn.Conv2d(channels_in, channels_out * self.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channels_out * self.expansion))

    def forward(self, x):
        # x => conv1 => relu => conv2 => + => relu
        # |                              ^
        # |==============================|
        tmp = self.conv1(x)
        tmp = nn.functional.relu(tmp, inplace=True)
        tmp = self.conv2(tmp)
        tmp += self.identity(x)
        y = nn.functional.relu(tmp, inplace=True)
        return y

class MyModel(nn.Module):
    """Faster-RCNN (基於 ResNet-18 的變種)"""
    Anchors = None # 錨點列表,包含 錨點數量 * 形狀數量 的範圍
    AnchorSpan = 8 # 錨點之間的距離,應該等於原有長寬 / resnet 輸出長寬
    AnchorScales = (0.5, 1, 2, 3, 4, 5, 6) # 錨點對應區域的縮放比例列表
    AnchorAspects = ((1, 2), (1, 1), (2, 1)) # 錨點對應區域的長寬比例列表
    AnchorBoxes = len(AnchorScales) * len(AnchorAspects) # 每個錨點對應的形狀數量

    def __init__(self):
        super().__init__()
        # 抽取圖片各個區域特徵的 ResNet (除去 AvgPool 和全連線層)
        # 和 Fast-RCNN 例子不同的是輸出的長寬會是原有的 1/8,後面會根據錨點與 affine_grid 擷取區域
        # 此外,為了可以讓模型跑在 4GB 視訊記憶體上,這裡減少了模型的通道數量
        # 注意:
        # RPN 使用的模型和標籤分類使用的模型需要分開,否則會出現無法學習 (RPN 總是輸出負) 的問題
        self.previous_channels_out = 4
        self.rpn_resnet = nn.Sequential(
            nn.Conv2d(3, self.previous_channels_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(self.previous_channels_out),
            nn.ReLU(inplace=True),
            self._make_layer(BasicBlock, channels_out=16, num_blocks=2, stride=1),
            self._make_layer(BasicBlock, channels_out=32, num_blocks=2, stride=2),
            self._make_layer(BasicBlock, channels_out=64, num_blocks=2, stride=2),
            self._make_layer(BasicBlock, channels_out=128, num_blocks=2, stride=2))
        self.previous_channels_out = 4
        self.cls_resnet = nn.Sequential(
            nn.Conv2d(3, self.previous_channels_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(self.previous_channels_out),
            nn.ReLU(inplace=True),
            self._make_layer(BasicBlock, channels_out=16, num_blocks=2, stride=1),
            self._make_layer(BasicBlock, channels_out=32, num_blocks=2, stride=2),
            self._make_layer(BasicBlock, channels_out=64, num_blocks=2, stride=2),
            self._make_layer(BasicBlock, channels_out=128, num_blocks=2, stride=2))
        self.features_channels = 128
        # 根據區域特徵生成各個錨點對應的物件可能性的模型
        self.rpn_labels_model = nn.Sequential(
            nn.Linear(self.features_channels, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, MyModel.AnchorBoxes*2))
        # 根據區域特徵生成各個錨點對應的區域偏移的模型
        self.rpn_offsets_model = nn.Sequential(
            nn.Linear(self.features_channels, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, MyModel.AnchorBoxes*4))
        # 選取可能出現物件的區域需要的最小可能性
        self.rpn_score_threshold = 0.9
        # 每張圖片最多選取的區域列表
        self.rpn_max_candidates = 32
        # 根據區域擷取特徵後縮放到的大小
        self.pooling_size = 7
        # 根據區域特徵判斷分類的模型
        self.cls_labels_model = nn.Sequential(
            nn.Linear(self.features_channels * (self.pooling_size ** 2), 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, len(CLASSES)))
        # 根據區域特徵再次生成區域偏移的模型,注意區域偏移會針對各個分類分別生成
        self.cls_offsets_model = nn.Sequential(
            nn.Linear(self.features_channels * (self.pooling_size ** 2), 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, len(CLASSES)*4))

    def _make_layer(self, block_type, channels_out, num_blocks, stride):
        """建立 resnet 使用的層"""
        blocks = []
        # 新增第一個塊
        blocks.append(block_type(self.previous_channels_out, channels_out, stride))
        self.previous_channels_out = channels_out * block_type.expansion
        # 新增剩餘的塊,剩餘的塊固定處理間隔為 1,不會改變長寬
        for _ in range(num_blocks-1):
            blocks.append(block_type(self.previous_channels_out, self.previous_channels_out, 1))
            self.previous_channels_out *= block_type.expansion
        return nn.Sequential(*blocks)

    @staticmethod
    def _generate_anchors(span):
        """根據錨點和形狀生成錨點範圍列表"""
        w, h = IMAGE_SIZE
        anchors = []
        for x in range(0, w, span):
            for y in range(0, h, span):
                xcenter, ycenter = x + span / 2, y + span / 2
                for scale in MyModel.AnchorScales:
                    for ratio in MyModel.AnchorAspects:
                        ww = span * scale * ratio[0]
                        hh = span * scale * ratio[1]
                        xx = xcenter - ww / 2
                        yy = ycenter - hh / 2
                        xx = max(int(xx), 0)
                        yy = max(int(yy), 0)
                        ww = min(int(ww), w - xx)
                        hh = min(int(hh), h - yy)
                        anchors.append((xx, yy, ww, hh))
        return anchors

    @staticmethod
    def _roi_crop(features, rois, pooling_size):
        """根據區域擷取特徵,每次只能處理單張圖片"""
        width, height = IMAGE_SIZE
        theta = []
        results = []
        for roi in rois:
            x1, y1, w, h = roi
            x2, y2 = x1 + w, y1 + h
            theta = [[
                [
                    (y2 - y1) / height,
                    0,
                    (y2 + y1) / height - 1
                ],
                [
                    0,
                    (x2 - x1) / width,
                    (x2 + x1) / width - 1
                ]
            ]]
            theta_tensor = torch.tensor(theta)
            grid = nn.functional.affine_grid(
                theta_tensor,
                torch.Size((1, 1, pooling_size, pooling_size)),
                align_corners=False).to(device)
            result = nn.functional.grid_sample(
                features.unsqueeze(0), grid, align_corners=False)
            results.append(result)
        if not results:
            return None
        results = torch.cat(results, dim=0)
        return results

    def forward(self, x):
        # ***** 抽取特徵部分 *****
        # 分別抽取 RPN 和標籤分類使用的特徵
        # 維度是 B,128,W/8,H/8
        rpn_features_original = self.rpn_resnet(x)
        # 維度是 B*W/8*H/8,128 (把通道放在最後,用於傳給線性模型)
        rpn_features = rpn_features_original.permute(0, 2, 3, 1).reshape(-1, self.features_channels)
        # 維度是 B,128,W/8,H/8
        cls_features = self.cls_resnet(x)

        # ***** 選取區域部分 *****
        # 根據區域特徵生成各個錨點對應的物件可能性
        # 維度是 B,W/8*H/8*AnchorBoxes,2
        rpn_labels = self.rpn_labels_model(rpn_features)
        rpn_labels = rpn_labels.reshape(
            rpn_features_original.shape[0],
            rpn_features_original.shape[2] * rpn_features_original.shape[3] * MyModel.AnchorBoxes,
            2)
        # 根據區域特徵生成各個錨點對應的區域偏移
        # 維度是 B,W/8*H/8*AnchorBoxes,4
        rpn_offsets = self.rpn_offsets_model(rpn_features)
        rpn_offsets = rpn_offsets.reshape(
            rpn_features_original.shape[0],
            rpn_features_original.shape[2] * rpn_features_original.shape[3] * MyModel.AnchorBoxes,
            4)
        # 選取可能出現物件的區域,並調整區域範圍
        with torch.no_grad():
            rpn_scores = nn.functional.softmax(rpn_labels, dim=2)[:,:,1]
            # 選取可能性最高的部分割槽域
            rpn_top_scores = torch.topk(rpn_scores, k=self.rpn_max_candidates, dim=1)
            rpn_candidates_batch = []
            for x in range(0, rpn_scores.shape[0]):
                rpn_candidates = []
                for score, index in zip(rpn_top_scores.values[x], rpn_top_scores.indices[x]):
                    # 過濾可能性低於指定閾值的區域
                    if score.item() < self.rpn_score_threshold:
                        continue
                    anchor_box = MyModel.Anchors[index.item()]
                    offset = rpn_offsets[x,index.item()].tolist()
                    # 調整區域範圍
                    candidate_box = adjust_box_by_offset(anchor_box, offset)
                    rpn_candidates.append(candidate_box)
                rpn_candidates_batch.append(rpn_candidates)

        # ***** 判斷分類部分 *****
        cls_output = []
        cls_result = []
        for index in range(0, cls_features.shape[0]):
            pooled = MyModel._roi_crop(
                cls_features[index], rpn_candidates_batch[index], self.pooling_size)
            if pooled is None:
                # 沒有找到可能包含物件的區域
                cls_output.append(None)
                cls_result.append(None)
                continue
            pooled = pooled.reshape(pooled.shape[0], -1)
            labels = self.cls_labels_model(pooled)
            offsets = self.cls_offsets_model(pooled)
            cls_output.append((labels, offsets))
            # 使用 softmax 判斷可能性最大的分類
            classes = nn.functional.softmax(labels, dim=1).max(dim=1).indices
            # 根據分類對應的偏移再次調整區域範圍
            offsets_map = offsets.reshape(offsets.shape[0] * len(CLASSES), 4)
            result = []
            for box_index in range(0, classes.shape[0]):
                predicted_label = classes[box_index].item()
                if predicted_label == 0:
                    continue # 0 代表 other, 表示非物件
                candidate_box = rpn_candidates_batch[index][box_index]
                offset = offsets_map[box_index * len(CLASSES) + predicted_label].tolist()
                predicted_box = adjust_box_by_offset(candidate_box, offset)
                # 新增分類與最終預測區域
                result.append((predicted_label, predicted_box))
            cls_result.append(result)

        # 前面的專案用於學習,最後一項是最終輸出結果
        return rpn_labels, rpn_offsets, rpn_candidates_batch, cls_output, cls_result

    @staticmethod
    def loss_function(predicted, actual):
        """Faster-RCNN 使用的多工損失計算器"""
        rpn_labels, rpn_offsets, rpn_candidates_batch, cls_output, _ = predicted
        rpn_labels_losses = []
        rpn_offsets_losses = []
        cls_labels_losses = []
        cls_offsets_losses = []
        for batch_index in range(len(actual)):
            # 計算 RPN 的損失
            (true_boxes_labels,
                actual_rpn_labels, actual_rpn_labels_mask,
                actual_rpn_offsets, actual_rpn_offsets_mask) = actual[batch_index]
            rpn_labels_losses.append(nn.functional.cross_entropy(
                rpn_labels[batch_index][actual_rpn_labels_mask],
                actual_rpn_labels.to(device)))
            rpn_offsets_losses.append(nn.functional.smooth_l1_loss(
                rpn_offsets[batch_index][actual_rpn_offsets_mask],
                actual_rpn_offsets.to(device)))
            # 計算標籤分類的損失
            if cls_output[batch_index] is None:
                continue
            cls_labels_mask = []
            cls_offsets_mask = []
            cls_actual_labels = []
            cls_actual_offsets = []
            cls_predicted_labels, cls_predicted_offsets = cls_output[batch_index]
            cls_predicted_offsets_map = cls_predicted_offsets.reshape(-1, 4)
            rpn_candidates = rpn_candidates_batch[batch_index]
            for box_index, candidate_box in enumerate(rpn_candidates):
                iou_list = [ calc_iou(candidate_box, true_box) for (_, true_box) in true_boxes_labels ]
                positive_index = next((index for index, iou in enumerate(iou_list) if iou > IOU_POSITIVE_THRESHOLD), None)
                is_negative = all(iou < IOU_NEGATIVE_THRESHOLD for iou in iou_list)
                if positive_index is not None:
                    true_label, true_box = true_boxes_labels[positive_index]
                    cls_actual_labels.append(true_label)
                    cls_labels_mask.append(box_index)
                    # 如果區域正確,則學習真實分類對應的區域偏移
                    cls_actual_offsets.append(calc_box_offset(candidate_box, true_box))
                    cls_offsets_mask.append(box_index * len(CLASSES) + true_label)
                elif is_negative:
                    cls_actual_labels.append(0) # 0 代表 other, 表示非物件
                    cls_labels_mask.append(box_index)
                # 如果候選區域與真實區域的重疊率介於兩個閾值之間,則不參與學習
            if cls_labels_mask:
                cls_labels_losses.append(nn.functional.cross_entropy(
                    cls_predicted_labels[cls_labels_mask],
                    torch.tensor(cls_actual_labels).to(device)))
            if cls_offsets_mask:
                cls_offsets_losses.append(nn.functional.smooth_l1_loss(
                    cls_predicted_offsets_map[cls_offsets_mask],
                    torch.tensor(cls_actual_offsets).to(device)))
        # 合併損失值
        # 注意 loss 不可以使用 += 合併
        loss = torch.tensor(.0, requires_grad=True)
        loss = loss + torch.mean(torch.stack(rpn_labels_losses))
        loss = loss + torch.mean(torch.stack(rpn_offsets_losses))
        if cls_labels_losses:
            loss = loss + torch.mean(torch.stack(cls_labels_losses))
        if cls_offsets_losses:
            loss = loss + torch.mean(torch.stack(cls_offsets_losses))
        return loss

    @staticmethod
    def calc_accuracy(actual, predicted):
        """Faster-RCNN 使用的正確率計算器,這裡只計算 RPN 與標籤分類的正確率,區域偏移不計算"""
        rpn_labels, rpn_offsets, rpn_candidates_batch, cls_output, cls_result = predicted
        rpn_acc = 0
        cls_acc = 0
        for batch_index in range(len(actual)):
            # 計算 RPN 的正確率,正樣本和負樣本的正確率分別計算再平均
            (true_boxes_labels,
                actual_rpn_labels, actual_rpn_labels_mask,
                actual_rpn_offsets, actual_rpn_offsets_mask) = actual[batch_index]
            a = actual_rpn_labels.to(device)
            p = torch.max(rpn_labels[batch_index][actual_rpn_labels_mask], 1).indices
            rpn_acc_positive = ((a == 0) & (p == 0)).sum().item() / ((a == 0).sum().item() + 0.00001)
            rpn_acc_negative = ((a == 1) & (p == 1)).sum().item() / ((a == 1).sum().item() + 0.00001)
            rpn_acc += (rpn_acc_positive + rpn_acc_negative) / 2
            # 計算標籤分類的正確率
            # 正確率 = 有對應預測區域並且預測分類正確的真實區域數量 / 總真實區域數量
            cls_correct = 0
            for true_label, true_box in true_boxes_labels:
                if cls_result[batch_index] is None:
                    continue
                for predicted_label, predicted_box in cls_result[batch_index]:
                    if calc_iou(predicted_box, true_box) > IOU_POSITIVE_THRESHOLD and predicted_label == true_label:
                        cls_correct += 1
                        break
            cls_acc += cls_correct / len(true_boxes_labels)
        rpn_acc /= len(actual)
        cls_acc /= len(actual)
        return rpn_acc, cls_acc

MyModel.Anchors = MyModel._generate_anchors(8)

def save_tensor(tensor, path):
    """儲存 tensor 物件到檔案"""
    torch.save(tensor, gzip.GzipFile(path, "wb"))

def load_tensor(path):
    """從檔案讀取 tensor 物件"""
    return torch.load(gzip.GzipFile(path, "rb"))

def calc_resize_parameters(sw, sh):
    """計算縮放圖片的引數"""
    sw_new, sh_new = sw, sh
    dw, dh = IMAGE_SIZE
    pad_w, pad_h = 0, 0
    if sw / sh < dw / dh:
        sw_new = int(dw / dh * sh)
        pad_w = (sw_new - sw) // 2 # 填充左右
    else:
        sh_new = int(dh / dw * sw)
        pad_h = (sh_new - sh) // 2 # 填充上下
    return sw_new, sh_new, pad_w, pad_h

def resize_image(img):
    """縮放圖片,比例不一致時填充"""
    sw, sh = img.size
    sw_new, sh_new, pad_w, pad_h = calc_resize_parameters(sw, sh)
    img_new = Image.new("RGB", (sw_new, sh_new))
    img_new.paste(img, (pad_w, pad_h))
    img_new = img_new.resize(IMAGE_SIZE)
    return img_new

def image_to_tensor(img):
    """轉換圖片物件到 tensor 物件"""
    arr = numpy.asarray(img)
    t = torch.from_numpy(arr)
    t = t.transpose(0, 2) # 轉換維度 H,W,C 到 C,W,H
    t = t / 255.0 # 正規化數值使得範圍在 0 ~ 1
    return t

def map_box_to_resized_image(box, sw, sh):
    """把原始區域轉換到縮放後的圖片對應的區域"""
    x, y, w, h = box
    sw_new, sh_new, pad_w, pad_h = calc_resize_parameters(sw, sh)
    scale = IMAGE_SIZE[0] / sw_new
    x = int((x + pad_w) * scale)
    y = int((y + pad_h) * scale)
    w = int(w * scale)
    h = int(h * scale)
    if x + w > IMAGE_SIZE[0] or y + h > IMAGE_SIZE[1] or w == 0 or h == 0:
        return 0, 0, 0, 0
    return x, y, w, h

def map_box_to_original_image(box, sw, sh):
    """把縮放後圖片對應的區域轉換到縮放前的原始區域"""
    x, y, w, h = box
    sw_new, sh_new, pad_w, pad_h = calc_resize_parameters(sw, sh)
    scale = IMAGE_SIZE[0] / sw_new
    x = int(x / scale - pad_w)
    y = int(y / scale - pad_h)
    w = int(w / scale)
    h = int(h / scale)
    if x + w > sw or y + h > sh or x < 0 or y < 0 or w == 0 or h == 0:
        return 0, 0, 0, 0
    return x, y, w, h

def calc_iou(rect1, rect2):
    """計算兩個區域重疊部分 / 合併部分的比率 (intersection over union)"""
    x1, y1, w1, h1 = rect1
    x2, y2, w2, h2 = rect2
    xi = max(x1, x2)
    yi = max(y1, y2)
    wi = min(x1+w1, x2+w2) - xi
    hi = min(y1+h1, y2+h2) - yi
    if wi > 0 and hi > 0: # 有重疊部分
        area_overlap = wi*hi
        area_all = w1*h1 + w2*h2 - area_overlap
        iou = area_overlap / area_all
    else: # 沒有重疊部分
        iou = 0
    return iou

def calc_box_offset(candidate_box, true_box):
    """計算候選區域與實際區域的偏移值"""
    # 這裡計算出來的偏移值基於比例,而不受具體位置和大小影響
    # w h 使用 log 是為了減少過大的值的影響
    x1, y1, w1, h1 = candidate_box
    x2, y2, w2, h2 = true_box
    x_offset = (x2 - x1) / w1
    y_offset = (y2 - y1) / h1
    w_offset = math.log(w2 / w1)
    h_offset = math.log(h2 / h1)
    return (x_offset, y_offset, w_offset, h_offset)

def adjust_box_by_offset(candidate_box, offset):
    """根據偏移值調整候選區域"""
    x1, y1, w1, h1 = candidate_box
    x_offset, y_offset, w_offset, h_offset = offset
    x2 = min(IMAGE_SIZE[0]-1,  max(0, w1 * x_offset + x1))
    y2 = min(IMAGE_SIZE[1]-1,  max(0, h1 * y_offset + y1))
    w2 = min(IMAGE_SIZE[0]-x2, max(1, math.exp(w_offset) * w1))
    h2 = min(IMAGE_SIZE[1]-y2, max(1, math.exp(h_offset) * h1))
    return (x2, y2, w2, h2)

def merge_box(box_a, box_b):
    """合併兩個區域"""
    x1, y1, w1, h1 = box_a
    x2, y2, w2, h2 = box_b
    x = min(x1, x2)
    y = min(y1, y2)
    w = max(x1 + w1, x2 + w2) - x
    h = max(y1 + h1, y2 + h2) - y
    return (x, y, w, h)

def prepare_save_batch(batch, image_tensors, image_boxes_labels):
    """準備訓練 - 儲存單個批次的資料"""
    # 按索引值列表生成輸入和輸出 tensor 物件的函式
    def split_dataset(indices):
        image_in = []
        boxes_labels_out = {}
        for new_image_index, original_image_index in enumerate(indices.tolist()):
            image_in.append(image_tensors[original_image_index])
            boxes_labels_out[new_image_index] = image_boxes_labels[original_image_index]
        tensor_image_in = torch.stack(image_in) # 維度: B,C,W,H
        return tensor_image_in, boxes_labels_out

    # 切分訓練集 (80%),驗證集 (10%) 和測試集 (10%)
    random_indices = torch.randperm(len(image_tensors))
    training_indices = random_indices[:int(len(random_indices)*0.8)]
    validating_indices = random_indices[int(len(random_indices)*0.8):int(len(random_indices)*0.9):]
    testing_indices = random_indices[int(len(random_indices)*0.9):]
    training_set = split_dataset(training_indices)
    validating_set = split_dataset(validating_indices)
    testing_set = split_dataset(testing_indices)

    # 儲存到硬碟
    save_tensor(training_set, f"data/training_set.{batch}.pt")
    save_tensor(validating_set, f"data/validating_set.{batch}.pt")
    save_tensor(testing_set, f"data/testing_set.{batch}.pt")
    print(f"batch {batch} saved")

def prepare():
    """準備訓練"""
    # 資料集轉換到 tensor 以後會儲存在 data 資料夾下
    if not os.path.isdir("data"):
        os.makedirs("data")

    # 載入圖片和圖片對應的區域與分類列表
    # { 圖片名: [ 區域與分類, 區域與分類, .. ] }
    box_map = defaultdict(lambda: [])
    for filename in os.listdir(IMAGE_DIR):
        xml_path = os.path.join(ANNOTATION_DIR, filename.split(".")[0] + ".xml")
        if not os.path.isfile(xml_path):
            continue
        tree = ET.ElementTree(file=xml_path)
        objects = tree.findall("object")
        for obj in objects:
            class_name = obj.find("name").text
            x1 = int(obj.find("bndbox/xmin").text)
            x2 = int(obj.find("bndbox/xmax").text)
            y1 = int(obj.find("bndbox/ymin").text)
            y2 = int(obj.find("bndbox/ymax").text)
            box_map[filename].append((x1, y1, x2-x1, y2-y1, CLASSES_MAPPING[class_name]))

    # 儲存圖片和圖片對應的分類與區域列表
    batch_size = 20
    batch = 0
    image_tensors = [] # 圖片列表
    image_boxes_labels = {} # 圖片對應的真實區域與分類列表,和候選區域與區域偏移
    for filename, original_boxes_labels in box_map.items():
        image_path = os.path.join(IMAGE_DIR, filename)
        with Image.open(image_path) as img_original: # 載入原始圖片
            sw, sh = img_original.size # 原始圖片大小
            img = resize_image(img_original) # 縮放圖片
            image_index = len(image_tensors) # 圖片在批次中的索引值
            image_tensors.append(image_to_tensor(img)) # 新增圖片到列表
            true_boxes_labels = [] # 圖片對應的真實區域與分類列表
        # 新增真實區域與分類列表
        for box_label in original_boxes_labels:
            x, y, w, h, label = box_label
            x, y, w, h = map_box_to_resized_image((x, y, w, h), sw, sh) # 縮放實際區域
            if w < 10 or h < 10:
                continue # 縮放後區域過小
            # 檢查計算是否有問題
            # child_img = img.copy().crop((x, y, x+w, y+h))
            # child_img.save(f"{filename}_{x}_{y}_{w}_{h}_{label}.png")
            true_boxes_labels.append((label, (x, y, w, h)))
        # 如果圖片中的所有區域都過小則跳過
        if not true_boxes_labels:
            image_tensors.pop()
            image_index = len(image_tensors)
            continue
        # 根據錨點列表尋找候選區域,並計算區域偏移
        actual_rpn_labels = []
        actual_rpn_labels_mask = []
        actual_rpn_offsets = []
        actual_rpn_offsets_mask = []
        positive_index_set = set()
        for index, anchor_box in enumerate(MyModel.Anchors):
            # 如果候選區域和任意一個實際區域重疊率大於閾值,則認為是正樣本
            # 如果候選區域和所有實際區域重疊率都小於閾值,則認為是負樣本
            # 重疊率介於兩個閾值之間的區域不參與學習
            iou_list = [ calc_iou(anchor_box, true_box) for (_, true_box) in true_boxes_labels ]
            positive_index = next((index for index, iou in enumerate(iou_list) if iou > IOU_POSITIVE_THRESHOLD), None)
            is_negative = all(iou < IOU_NEGATIVE_THRESHOLD for iou in iou_list)
            if positive_index is not None:
                positive_index_set.add(positive_index)
                actual_rpn_labels.append(1)
                actual_rpn_labels_mask.append(index)
                # 只有包含物件的區域參需要調整偏移
                true_box = true_boxes_labels[positive_index][1]
                actual_rpn_offsets.append(calc_box_offset(anchor_box, true_box))
                actual_rpn_offsets_mask.append(index)
            elif is_negative:
                actual_rpn_labels.append(0)
                actual_rpn_labels_mask.append(index)
        # 輸出找不到候選區域的真實區域,調整錨點生成引數時使用
        # for index in range(len(true_boxes_labels)):
        #    if index not in positive_index_set:
        #        print(true_boxes_labels[index][1])
        # print("-----")
        # 如果一個候選區域都找不到則跳過
        if not positive_index_set:
            image_tensors.pop()
            image_index = len(image_tensors)
            continue
        image_boxes_labels[image_index] = (
            true_boxes_labels,
            torch.tensor(actual_rpn_labels, dtype=torch.long),
            torch.tensor(actual_rpn_labels_mask, dtype=torch.long),
            torch.tensor(actual_rpn_offsets, dtype=torch.float),
            torch.tensor(actual_rpn_offsets_mask, dtype=torch.long))
        # 儲存批次
        if len(image_tensors) >= batch_size:
            prepare_save_batch(batch, image_tensors, image_boxes_labels)
            image_tensors.clear()
            image_boxes_labels.clear()
            batch += 1
    # 儲存剩餘的批次
    if len(image_tensors) > 10:
        prepare_save_batch(batch, image_tensors, image_boxes_labels)

def train():
    """開始訓練"""
    # 建立模型例項
    model = MyModel().to(device)

    # 建立多工損失計算器
    loss_function = MyModel.loss_function

    # 建立引數調整器
    optimizer = torch.optim.Adam(model.parameters())

    # 記錄訓練集和驗證集的正確率變化
    training_rpn_accuracy_history = []
    training_cls_accuracy_history = []
    validating_rpn_accuracy_history = []
    validating_cls_accuracy_history = []

    # 記錄最高的驗證集正確率
    validating_rpn_accuracy_highest = -1
    validating_rpn_accuracy_highest_epoch = 0
    validating_cls_accuracy_highest = -1
    validating_cls_accuracy_highest_epoch = 0

    # 讀取批次的工具函式
    def read_batches(base_path):
        for batch in itertools.count():
            path = f"{base_path}.{batch}.pt"
            if not os.path.isfile(path):
                break
            x, y = load_tensor(path)
            yield x.to(device), y

    # 計算正確率的工具函式
    calc_accuracy = MyModel.calc_accuracy

    # 開始訓練過程
    for epoch in range(1, 10000):
        print(f"epoch: {epoch}")

        # 根據訓練集訓練並修改引數
        # 切換模型到訓練模式,將會啟用自動微分,批次正規化 (BatchNorm) 與 Dropout
        model.train()
        training_rpn_accuracy_list = []
        training_cls_accuracy_list = []
        for batch_index, batch in enumerate(read_batches("data/training_set")):
            # 劃分輸入和輸出
            batch_x, batch_y = batch
            # 計算預測值
            predicted = model(batch_x)
            # 計算損失
            loss = loss_function(predicted, batch_y)
            # 從損失自動微分求導函式值
            loss.backward()
            # 使用引數調整器調整引數
            optimizer.step()
            # 清空導函式值
            optimizer.zero_grad()
            # 記錄這一個批次的正確率,torch.no_grad 代表臨時禁用自動微分功能
            with torch.no_grad():
                training_batch_rpn_accuracy, training_batch_cls_accuracy = calc_accuracy(batch_y, predicted)
            # 輸出批次正確率
            training_rpn_accuracy_list.append(training_batch_rpn_accuracy)
            training_cls_accuracy_list.append(training_batch_cls_accuracy)
            print(f"epoch: {epoch}, batch: {batch_index}: " +
                f"batch rpn accuracy: {training_batch_rpn_accuracy}, cls accuracy: {training_batch_cls_accuracy}")
        training_rpn_accuracy = sum(training_rpn_accuracy_list) / len(training_rpn_accuracy_list)
        training_cls_accuracy = sum(training_cls_accuracy_list) / len(training_cls_accuracy_list)
        training_rpn_accuracy_history.append(training_rpn_accuracy)
        training_cls_accuracy_history.append(training_cls_accuracy)
        print(f"training rpn accuracy: {training_rpn_accuracy}, cls accuracy: {training_cls_accuracy}")

        # 檢查驗證集
        # 切換模型到驗證模式,將會禁用自動微分,批次正規化 (BatchNorm) 與 Dropout
        model.eval()
        validating_rpn_accuracy_list = []
        validating_cls_accuracy_list = []
        for batch in read_batches("data/validating_set"):
            batch_x, batch_y = batch
            predicted = model(batch_x)
            validating_batch_rpn_accuracy, validating_batch_cls_accuracy = calc_accuracy(batch_y, predicted)
            validating_rpn_accuracy_list.append(validating_batch_rpn_accuracy)
            validating_cls_accuracy_list.append(validating_batch_cls_accuracy)
        validating_rpn_accuracy = sum(validating_rpn_accuracy_list) / len(validating_rpn_accuracy_list)
        validating_cls_accuracy = sum(validating_cls_accuracy_list) / len(validating_cls_accuracy_list)
        validating_rpn_accuracy_history.append(validating_rpn_accuracy)
        validating_cls_accuracy_history.append(validating_cls_accuracy)
        print(f"validating rpn accuracy: {validating_rpn_accuracy}, cls accuracy: {validating_cls_accuracy}")

        # 記錄最高的驗證集正確率與當時的模型狀態,判斷是否在 20 次訓練後仍然沒有重新整理記錄
        if validating_rpn_accuracy > validating_rpn_accuracy_highest:
            validating_rpn_accuracy_highest = validating_rpn_accuracy
            validating_rpn_accuracy_highest_epoch = epoch
            save_tensor(model.state_dict(), "model.pt")
            print("highest rpn validating accuracy updated")
        elif validating_cls_accuracy > validating_cls_accuracy_highest:
            validating_cls_accuracy_highest = validating_cls_accuracy
            validating_cls_accuracy_highest_epoch = epoch
            save_tensor(model.state_dict(), "model.pt")
            print("highest cls validating accuracy updated")
        elif (epoch - validating_rpn_accuracy_highest_epoch > 20 and
            epoch - validating_cls_accuracy_highest_epoch > 20):
            # 在 20 次訓練後仍然沒有重新整理記錄,結束訓練
            print("stop training because highest validating accuracy not updated in 20 epoches")
            break

    # 使用達到最高正確率時的模型狀態
    print(f"highest rpn validating accuracy: {validating_rpn_accuracy_highest}",
        f"from epoch {validating_rpn_accuracy_highest_epoch}")
    print(f"highest cls validating accuracy: {validating_cls_accuracy_highest}",
        f"from epoch {validating_cls_accuracy_highest_epoch}")
    model.load_state_dict(load_tensor("model.pt"))

    # 檢查測試集
    testing_rpn_accuracy_list = []
    testing_cls_accuracy_list = []
    for batch in read_batches("data/testing_set"):
        batch_x, batch_y = batch
        predicted = model(batch_x)
        testing_batch_rpn_accuracy, testing_batch_cls_accuracy = calc_accuracy(batch_y, predicted)
        testing_rpn_accuracy_list.append(testing_batch_rpn_accuracy)
        testing_cls_accuracy_list.append(testing_batch_cls_accuracy)
    testing_rpn_accuracy = sum(testing_rpn_accuracy_list) / len(testing_rpn_accuracy_list)
    testing_cls_accuracy = sum(testing_cls_accuracy_list) / len(testing_cls_accuracy_list)
    print(f"testing rpn accuracy: {testing_rpn_accuracy}, cls accuracy: {testing_cls_accuracy}")

    # 顯示訓練集和驗證集的正確率變化
    pyplot.plot(training_rpn_accuracy_history, label="training_rpn_accuracy")
    pyplot.plot(training_cls_accuracy_history, label="training_cls_accuracy")
    pyplot.plot(validating_rpn_accuracy_history, label="validating_rpn_accuracy")
    pyplot.plot(validating_cls_accuracy_history, label="validating_cls_accuracy")
    pyplot.ylim(0, 1)
    pyplot.legend()
    pyplot.show()

def eval_model():
    """使用訓練好的模型"""
    # 建立模型例項,載入訓練好的狀態,然後切換到驗證模式
    model = MyModel().to(device)
    model.load_state_dict(load_tensor("model.pt"))
    model.eval()

    # 詢問圖片路徑,並顯示所有可能是人臉的區域
    while True:
        try:
            image_path = input("Image path: ")
            if not image_path:
                continue
            # 構建輸入
            with Image.open(image_path) as img_original: # 載入原始圖片
                sw, sh = img_original.size # 原始圖片大小
                img = resize_image(img_original) # 縮放圖片
                img_output = img_original.copy() # 複製圖片,用於後面新增標記
                tensor_in = image_to_tensor(img)
            # 預測輸出
            cls_result = model(tensor_in.unsqueeze(0).to(device))[-1][0]
            # 合併重疊的結果區域, 結果是 [ [標籤列表, 合併後的區域], ... ]
            final_result = []
            for label, box in cls_result:
                for index in range(len(final_result)):
                    exists_labels, exists_box = final_result[index]
                    if calc_iou(box, exists_box) > IOU_MERGE_THRESHOLD:
                        exists_labels.append(label)
                        final_result[index] = (exists_labels, merge_box(box, exists_box))
                        break
                else:
                    final_result.append(([label], box))
            # 合併標籤 (重疊區域的標籤中數量最多的分類為最終分類)
            for index in range(len(final_result)):
                labels, box = final_result[index]
                final_label = Counter(labels).most_common(1)[0][0]
                final_result[index] = (final_label, box)
            # 標記在圖片上
            draw = ImageDraw.Draw(img_output)
            for label, box in final_result:
                x, y, w, h = map_box_to_original_image(box, sw, sh)
                draw.rectangle((x, y, x+w, y+h), outline="#FF0000")
                draw.text((x, y-10), CLASSES[label], fill="#FF0000")
                print((x, y, w, h), CLASSES[label])
            img_output.save("img_output.png")
            print("saved to img_output.png")
            print()
        except Exception as e:
            print("error:", e)

def main():
    """主函式"""
    if len(sys.argv) < 2:
        print(f"Please run: {sys.argv[0]} prepare|train|eval")
        exit()

    # 給隨機數生成器分配一個初始值,使得每次執行都可以生成相同的隨機數
    # 這是為了讓過程可重現,你也可以選擇不這樣做
    random.seed(0)
    torch.random.manual_seed(0)

    # 根據命令列引數選擇操作
    operation = sys.argv[1]
    if operation == "prepare":
        prepare()
    elif operation == "train":
        train()
    elif operation == "eval":
        eval_model()
    else:
        raise ValueError(f"Unsupported operation: {operation}")

if __name__ == "__main__":
    main()

執行以下命令開始訓練:

python3 example.py prepare
python3 example.py train

最終輸出如下:

epoch: 101, batch: 30: batch rpn accuracy: 0.9999998976070061, cls accuracy: 0.9114583333333333
epoch: 101, batch: 31: batch rpn accuracy: 0.9834558104401839, cls accuracy: 0.8140625
epoch: 101, batch: 32: batch rpn accuracy: 0.9999098026259949, cls accuracy: 0.7739583333333333
epoch: 101, batch: 33: batch rpn accuracy: 0.9998011454364403, cls accuracy: 0.8216517857142858
epoch: 101, batch: 34: batch rpn accuracy: 0.9968102716843542, cls accuracy: 0.7961309523809523
epoch: 101, batch: 35: batch rpn accuracy: 0.9992402167888915, cls accuracy: 0.9169642857142857
epoch: 101, batch: 36: batch rpn accuracy: 0.9991754689754888, cls accuracy: 0.784375
epoch: 101, batch: 37: batch rpn accuracy: 0.9998954174868623, cls accuracy: 0.808531746031746
epoch: 101, batch: 38: batch rpn accuracy: 0.999810537169184, cls accuracy: 0.8928571428571429
epoch: 101, batch: 39: batch rpn accuracy: 0.9993760622446838, cls accuracy: 0.7447916666666667
epoch: 101, batch: 40: batch rpn accuracy: 0.9990286666127914, cls accuracy: 0.8565972222222223
epoch: 101, batch: 41: batch rpn accuracy: 0.9999998978468275, cls accuracy: 0.8012820512820512
training rpn accuracy: 0.9992436053003302, cls accuracy: 0.8312847933023624
validating rpn accuracy: 0.89010891321815, cls accuracy: 0.6757137703566275
stop training because highest validating accuracy not updated in 20 epoches
highest rpn validating accuracy: 0.951476186351423 from epoch 63
highest cls validating accuracy: 0.707979883872741 from epoch 80
testing rpn accuracy: 0.9250985286757772, cls accuracy: 0.7238060880918024

cls accuracy 代表可以識別出多少包含物件的區域並且正確判斷它的分類,雖然只有 70% 左右但實際效果還是不錯的,如果有更多視訊記憶體可以增強 CNN 模型 (例如使用 Resnet-50) 與加大 IMAGE_SIZE

訓練集和驗證集的正確率變化如下:

執行以下命令,再輸入圖片路徑可以使用學習好的模型識別圖片:

python3 example.py eval

以下是部分識別結果:

效果還行吧?,順道一提每張圖片的識別時間大約在 0.05 ~ 0.06 秒之間,相對於 Fast-RCNN 快了接近 10 倍,用在視訊上大約可以支援 20fps 左右 (我機器配置比較低,4 核 CPU + GTX1650,高配機器可以更快?)。

寫在最後

這篇介紹的 Faster-RCNN 效果明顯比之前介紹的 RCNN 與 Fast-RCNN 更好,但還是有缺點的,如果物件相對於圖片很小或者很大,那麼物件與錨點的各個形狀的重疊率都會比較低,導致無法識別出來。下一篇介紹的 YOLO 模型一定程度上改善了這個問題,但接下來一頭半個月我估計都沒時間寫,想看的耐心等吧?。

相關文章