OpenCV + sklearnSVM 實現手寫數字分割和識別

凪风sama發表於2024-06-17

這學期機器學習考核方式以大作業的形式進行考核,而且只能使用一些傳統的機器學習演算法。
綜合再三,選擇了自己比較熟悉的MNIST資料集以及OpenCV來完成手寫數字的分割和識別作為大作業。

1. 資料集準備

MNIST資料集是一個手寫數字的資料庫,包含60000張訓練圖片和10000張測試圖片,每張圖片大小為28x28畫素,每張圖片都是一個
灰度圖,畫素取值範圍在0-255之間。

這裡使用pytorch的torchvision.datasets模組來讀取MNIST資料集。

from torchvision import datasets
mnist_set = datasets.MNIST(root="./MNIST", train=True, download=True)

具體引數說明請自行搜尋。注意若donwload=True,則torchvision會透過內建連結自動下載資料集,
但是有時會失效。因此可以自己去網路上下載並解壓後排列成指定檔案樹,如下

MNIST
├── MNSIT
│   ├── raw
│   │   ├── t10k-images-idx3-ubyte.gz
│   │   ├── t10k-labels-idx1-ubyte.gz
│   │   ├── train-images-idx3-ubyte.gz
│   │   └── train-labels-idx1-ubyte.gz

然後使用如下語句去讀取資料集

img, target = minst_set[0]

其中每個img型別為PILimage,target型別為int,代表該圖片對應的數字。

但是在餵給SVM訓練時需要的是[batch_size, data]大小的numpy陣列,因此需要做一些預處理

   x_, y_ = list(zip(*([(np.array(img).reshape(28*28), target) for img, target in mnist_set])))

上面的語句實現了將MNIST資料集轉換成numpy陣列的形式,其中x_是每個成員為[1, 784]的numpy陣列,y_為對應的數字所組成的列表。

2. SVM訓練

支援向量機(support vector machine,SVM)是經典的機器學習演算法,其透過選取兩個n維支援向量(support vector)之間的n維超平面來對兩類物件進行二分類。而專注於分類的SVM又稱作Support Vector Classification,SVC。

求解SVM是一個很複雜的問題,但是萬幸的是sklearn中有封裝的很好的模組,可以很簡單的直接使用

from sklearn.svm import SVC

svc = SVC(kernel='rbf', C=1)
 
svc.fit(x_, y_)

其中fit介面接受兩個引數,第一個引數為訓練資料[batch_size, data],第二個引數為訓練標籤[batch_size,1]。
SVC的建構函式如下

SVC(C=1.0, kernel='rbf', degree=3, gamma='scale', coef0=0.0, shrinking=True, probability=False, tol=0.001, cache_size=200, class_weight=None, verbose=False, max_iter=-1, decision_function_shape='ovr', random_state=None)

比較重要的引數有kernel,C,decison_function_shape等。

  • kernel引數指定了核函式,常用的有linear,poly,rbf,sigmoid等。
  • C為懲罰係數,C越大,對誤分類的懲罰越大,模型越保守,C越小,對誤分類的懲罰越小,模型越寬鬆,也就是較大的C在訓練集上會有更高的正確率,較小的C會容許噪聲的存在,泛化能力較強。
  • decision_function_shape引數指定了決策函式的形狀,ovr表示one-vs-rest,ovo表示one-vs-one,具體的意思可以網路查閱

4. 數字分割

數字分割是指將影像中的數字部分分割出來,然後一個一個餵給SVM進行分類

這裡就是使用opencv對拍攝的影像進行輪廓提取後擬合外接矩形,藉此來提取數字部分的ROI。

這裡選擇進行Canny邊緣檢測後去進行輪廓提取,然後擬合外接矩形,因為相較於直接二值化後去提取數字部分的ROI,
邊緣檢測對數字與紙張的邊界更加敏感,即便在光照不均勻的情況下,也能較好的提取出數字的邊緣。魯棒性強。

5. 雜項與程式碼

這裡還有一些雜項,比如儲存模型,載入模型

使用pickle模組對訓練好的模型物件進行序列化儲存與載入,可以將訓練好的模型儲存到本地,以便後續使用。

最後貼出程式碼

程式碼
import os.path
import cv2
import numpy as np
from matplotlib import pyplot as plt
from torchvision import datasets
from torchvision import transforms
from sklearn import svm
from sklearn import preprocessing
from sklearnex import patch_sklearn
import pickle
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import learning_curve

'''
    @brief  載入MNIST資料集並轉換格式成二值圖
    
    @param train: 是否為訓練集
    @param data_enhance: 是否進行資料增強
    
    @return 二值圖集和標籤集
'''
def LoadMnistDataset(train=True, data_enhance=False):
    mnist_set = datasets.MNIST(root="./MNIST", train=train, download=True)
    x_, y_ = list(zip(*([(np.array(img), target) for img, target in mnist_set])))
    sets_raw = []
    sets_r20 = []
    sets_invr20 = []
    y = []
    y_r20 = []
    y_invr20 = []
    sets = []
    matrix_r20 = cv2.getRotationMatrix2D((14, 14), 25, 1.0)
    matrix_invr20 = cv2.getRotationMatrix2D((14, 14), -25, 1.0)
    select = 0
    for idx in range(len(x_)):
        # 對影像進行二值化以及資料增強
        _, img = cv2.threshold(x_[idx], 255, 255, cv2.THRESH_OTSU)
        sets_raw.append(np.array(img.data).reshape(784))
        y.append(y_[idx])
        if data_enhance:
            if select % 2 == 0:
                img_r20 = ~cv2.warpAffine(~img, matrix_r20, (28, 28), borderValue=(255, 255, 255))
                sets_r20.append(np.array(img_r20.data).reshape(784))
                y_r20.append(y_[idx])
            else:
                img_invr20 = ~cv2.warpAffine(~img, matrix_invr20, (28, 28), borderValue=(255, 255, 255))
                sets_invr20.append(np.array(img_invr20.data).reshape(784))
                y_invr20.append(y_[idx])
            select += 1

    # 資料增強
    sets = sets_raw + sets_r20 + sets_invr20
    sets = np.array(sets)
    print(sets.shape)
    if data_enhance:
        y = y + y_r20 + y_invr20
    return sets, y

'''
    @brief  儲存SVM模型
    
    @param svc_model: SVM模型 
    @param file_path: 模型儲存路徑,預設為./SVC
    
    @return None
'''
def SaveSvcModel(svc_model, file_path="./SVC"):
    with open(file_path, 'wb') as fs:
        pickle.dump(svc_model, fs)

'''
     @brief  載入SVM模型
     
     @param file_path: 模型儲存路徑,預設為./SVC
     
     @return SVM模型
'''
def LoadSvcModel(file_path="./SVC"):
    if not os.path.exists(file_path):
        assert "Model Do Not Exist"
    with open(file_path, 'rb') as fs:
        svc_model = pickle.load(fs)
    return svc_model

'''
     @brief  訓練SVM模型
     
     @param c: SVM引數C
     @param enhance: 是否進行資料增強
     
     @return acc: 在測試集上的準確率
             svc_model: SVM模型
'''
def TrainSvc(c, enhance):
    # 讀取資料集,訓練集及測試集
    images_train, targets_train = LoadMnistDataset(train=True, data_enhance=enhance)
    images_test, targets_test = LoadMnistDataset(train=False, data_enhance=enhance)

    # 訓練
    svc_model = svm.SVC(C=c,kernel='rbf', decision_function_shape='ovr')
    svc_model.fit(images_train, targets_train)

    # 在測試集上測試準確度

    res = svc_model.predict(images_test)
    correct = (res == targets_test).sum()
    accuracy = correct / len(images_test)
    print(f"測試集上的準確率為{accuracy * 100}%")
    return svc_model

'''
     @brief  預處理比較粗的字型
     
     @param image: 輸入影像
     @:param show: 是否顯示預處理後的影像
     @:param thresh: 二值化閾值
     
     @return 預處理後的影像資料
'''
def PreProcessFatFont(image, show=False):
    # 白底黑字轉黑底白字
    pre_ = ~image

    # 轉單通道灰度
    pre_ = cv2.cvtColor(pre_, cv2.COLOR_BGR2GRAY)
    # 二值化
    _, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)

    # resize後新增黑色邊框,親測可提高識別率
    pre_ = cv2.resize(pre_, (112, 112))
    _, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)

    back = np.zeros((300, 300), np.uint8)
    back[29:141, 29:141] = pre_
    pre_ = back

    if show:
        cv2.imshow("show", pre_)
        cv2.waitKey(0)

    # 做一次開運算(腐蝕 + 膨脹)
    kernel = np.ones((2, 2), np.uint8)
    pre_ = cv2.erode(pre_, kernel, iterations=1)
    kernel = np.ones((3, 3), np.uint8)
    pre_ = cv2.dilate(pre_, kernel, iterations=1)

    # 第二次resize
    pre_ = cv2.resize(pre_, (56, 56))
    _, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)

    # 做一次開運算(腐蝕 + 膨脹)
    kernel = np.ones((2, 2), np.uint8)
    pre_ = cv2.erode(pre_, kernel, iterations=1)
    kernel = np.ones((3, 3), np.uint8)
    pre_ = cv2.dilate(pre_, kernel, iterations=1)

    # resize成輸入規格
    pre_ = cv2.resize(pre_, (28, 28))
    _, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)

    # 轉換為SVM的輸入格式
    pre_ = np.array(pre_).flatten().reshape(1, -1)
    return pre_

'''
     @brief  預處理細的字型
     
     @param image: 輸入影像
     @param show: 是否顯示預處理後的影像
     @param thresh: 二值化閾值
     
     
     @return 預處理後的影像資料
'''
def PreProcessThinFont(image, show=False):
    # 白底黑字轉黑底白字
    pre_ = ~image

    # 轉灰度圖
    pre_ = cv2.cvtColor(pre_, cv2.COLOR_BGR2GRAY)

    # 增加黑色邊框
    pre_ = cv2.resize(pre_, (112, 112))
    _, pre_ = cv2.threshold(pre_,thresh=0, maxval=255, type=cv2.THRESH_OTSU)
    back = np.zeros((170, 170), dtype=np.uint8) # 這裡不指明型別會導致後續矩陣強轉為float64,無法使用大津法閾值
    back[29:141, 29:141] = pre_
    pre_ = back

    if show:
        cv2.imshow("show", pre_)
        cv2.waitKey(0)

    # 對細字型先膨脹一下
    kernel = np.ones((3, 3), np.uint8)
    pre_ = cv2.dilate(pre_, kernel, iterations=2)



    # 第二次resize
    pre_ = cv2.resize(pre_, (56, 56))

    _, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)

    # 做一次開運算(腐蝕 + 膨脹)
    kernel = np.ones((2, 2), np.uint8)
    pre_ = cv2.erode(pre_, kernel, iterations=1)
    kernel = np.ones((3, 3), np.uint8)
    pre_ = cv2.dilate(pre_, kernel, iterations=1)

    # resize成輸入規格
    pre_ = cv2.resize(pre_, (28, 28))
    _, pre_ = cv2.threshold(pre_, thresh=0, maxval=255, type=cv2.THRESH_OTSU)

    # 轉換為SVM輸入格式
    pre_ = np.array(pre_).flatten().reshape(1, -1)

    return pre_

'''
     @brief  在空白背景上顯示提取出的roi
     
     @param res_list: roi列表
     
     @return None
'''
def ShowRoi(res_list):
    back = 255 * np.ones((1000, 1500, 3), dtype=np.uint8)
    # 圖片x軸偏移量
    tlx = 0

    for roi in res_list:
        if tlx + roi.shape[1] > back.shape[1]:
            break
        # 每次在原圖上加上一個roi
        back[0:roi.shape[0], tlx:tlx + roi.shape[1], :] = roi
        tlx += roi.shape[1]

    cv2.imshow("show", back)
    cv2.waitKey(0)

'''
     @brief  尋找數字輪廓並提取roi
     
     @param src: 輸入影像
     @param thin: 是否為細字型
     @param thresh: 二值化閾值
     
     @return roi列表
'''
def FindNumbers(src, thin=True):
    # 複製
    dst = src.copy()
    paint = src.copy()
    roi = src.copy()
    dst = ~dst

    # 預處理
    paint = cv2.resize(paint, (448, 448))
    dst = cv2.resize(dst, (448, 448))

    # 記錄縮放比例,後來看這一步好像沒啥意義
    fx = src.shape[1] / 448
    fy = src.shape[0] / 448

    # 轉單通道
    dst = cv2.cvtColor(dst, cv2.COLOR_BGR2GRAY)

    # 邊緣檢測後二值化,直接二值化的話由於採光不同的原因灰度直方圖峰與峰之間可能會差距過大,導致二值圖的分割不準確
    # 而邊緣檢測對畫素突變更加敏感,因此採用Canny邊緣檢測後二值化
    cv2.Canny(dst, 200, 200, dst)

    # 對於平常筆寫的字太細,膨脹一下
    if thin:
        kernel = np.ones((5, 5), np.uint8)
        dst = cv2.dilate(dst, kernel, iterations=1)

    # 尋找外圍輪廓
    contours, _ = cv2.findContours(dst, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # 提取roi
    roi_list = []
    rect_list = []
    for contour in contours:
        rect = cv2.boundingRect(contour)
        if not ((rect[2] * rect[3] < 400 or rect[2] * rect[3] > 448 * 448 / 2.5) or (rect[3] < rect[2])):
            cv2.rectangle(paint, rect, (255, 0, 0), 1)
            x_min = rect[0] * fx
            x_max = (rect[0] + rect[2]) * fx
            y_min = rect[1] * fy
            y_max = (rect[1] + rect[3]) * fy
            roi_list.append(roi[int(y_min):int(y_max), int(x_min):int(x_max)].copy())
            rect_list.append(rect)
    return paint, roi_list, rect_list

'''
     @brief  以txt形式顯示資料
     
     @param data: 資料集
     
     @return None   
'''
def ShowDataTxt(data):
    print("----------------------------------------------------------")
    for i in range(28):
        for j in range(28):
            print(0 if data[0][i * 28 + j] == 255 else 1, end='')
        print('\n')
    print("----------------------------------------------------------")



if __name__ == "__main__":
    # 載入
    patch_sklearn()
    model_path = "./SVC_C1_enhance.pkl"

    if os.path.exists(model_path):
        print("Model Exist, Load Form Serialization")
        model = LoadSvcModel(model_path)
    else:
        print("Model Do Not Exist, Train")

        # 訓練
        model = TrainSvc(1, False)


        # 儲存
        SaveSvcModel(model, model_path)

    # 測試
    paint, nums, rects = FindNumbers(cv2.imread("test_final.jpg"))
    predict_nums = []
    for img in nums:
        data = PreProcessThinFont(img, show=False)
       # ShowDataTxt(data)
        predict_nums.append(model.predict(data)[0])
    for i in range(len(predict_nums)):
        cv2.putText(paint,str(predict_nums[i]), (rects[i][0], rects[i][1]), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
    cv2.imshow("show", paint)
    cv2.waitKey(0)

給出幾個識別後的效果:
image

相關文章