PaddleOCR手寫文字識別模型訓練(摘抄所得,非原創)

yesyes1發表於2024-03-14

1. 安裝環境

# 首先git官方的PaddleOCR專案,安裝需要的依賴
git clone https://github.com/PaddlePaddle/PaddleOCR.git
cd PaddleOCR
pip install -r requirements.txt

2. 資料準備

本專案使用公開的手寫文字識別資料集,包含Chinese OCR, 中科院自動化研究所-手寫中文資料集CASIA-HWDB2.x,以及由中科院手寫資料和網上開源資料合併組合的資料集等,該專案已經掛載處理好的資料集,可直接下載使用進行訓練。

下載並解壓資料
tar -xf hw_data.tar

3. 模型訓練

3.1 下載預訓練模型

首先需要下載我們需要的PP-OCRv3識別預訓練模型,更多選擇請自行選擇其他的文字識別模型

# 使用該指令下載需要的預訓練模型
wget -P ./pretrained_models/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_train.tar
# 解壓預訓練模型檔案
tar -xf ./pretrained_models/ch_PP-OCRv3_rec_train.tar -C pretrained_models

3.2 修改配置檔案

我們使用configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml,主要修改訓練輪數和學習率參相關引數,設定預訓練模型路徑,設定資料集路徑。 另外,batch_size可根據自己機器視訊記憶體大小進行調整。 具體修改如下幾個地方:

  epoch_num: 100 # 訓練epoch數
  save_model_dir: ./output/ch_PP-OCR_v3_rec
  save_epoch_step: 10
  eval_batch_step: [0, 100] # 評估間隔,每隔100step評估一次
  pretrained_model: ./pretrained_models/ch_PP-OCRv3_rec_train/best_accuracy  # 預訓練模型路徑


  lr:
    name: Cosine # 修改學習率衰減策略為Cosine
    learning_rate: 0.0001 # 修改fine-tune的學習率
    warmup_epoch: 2 # 修改warmup輪數

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data # 訓練集圖片路徑
    ext_op_transform_idx: 1
    label_file_list:
    - ./train_data/chineseocr-data/rec_hand_line_all_label_train.txt # 訓練集標籤
    - ./train_data/handwrite/HWDB2.0Train_label.txt
    - ./train_data/handwrite/HWDB2.1Train_label.txt
    - ./train_data/handwrite/HWDB2.2Train_label.txt
    - ./train_data/handwrite/hwdb_ic13/handwriting_hwdb_train_labels.txt
    - ./train_data/handwrite/HW_Chinese/train_hw.txt
    ratio_list:
    - 0.1
    - 1.0
    - 1.0
    - 1.0
    - 0.02
    - 1.0
  loader:
    shuffle: true
    batch_size_per_card: 64
    drop_last: true
    num_workers: 4
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data # 測試集圖片路徑
    label_file_list:
    - ./train_data/chineseocr-data/rec_hand_line_all_label_val.txt # 測試集標籤
    - ./train_data/handwrite/HWDB2.0Test_label.txt
    - ./train_data/handwrite/HWDB2.1Test_label.txt
    - ./train_data/handwrite/HWDB2.2Test_label.txt
    - ./train_data/handwrite/hwdb_ic13/handwriting_hwdb_val_labels.txt
    - ./train_data/handwrite/HW_Chinese/test_hw.txt
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 64
    num_workers: 4

由於資料集大多是長文字,因此需要註釋掉下面的資料增廣策略,以便訓練出更好的模型。

- RecConAug:
    prob: 0.5
    ext_data_num: 2
    image_shape: [48, 320, 3]

3.3 開始訓練

我們使用上面修改好的配置檔案configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml,預訓練模型,資料集路徑,學習率,訓練輪數等都已經設定完畢後,可以使用下面命令開始訓練。

# 開始訓練識別模型
python tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml

4. 模型評估

在訓練之前,我們可以直接使用下面命令來評估預訓練模型的效果:

# 評估預訓練模型
python tools/eval.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.pretrained_model="./pretrained_models/ch_PP-OCRv3_rec_train/best_accuracy"
[2022/07/14 10:46:22] ppocr INFO: load pretrain successful from ./pretrained_models/ch_PP-OCRv3_rec_train/best_accuracy
eval model:: 100%|████████████████████████████| 687/687 [03:29<00:00,  3.27it/s]
[2022/07/14 10:49:52] ppocr INFO: metric eval ***************
[2022/07/14 10:49:52] ppocr INFO: acc:0.03724954461811258
[2022/07/14 10:49:52] ppocr INFO: norm_edit_dis:0.4859541065843199
[2022/07/14 10:49:52] ppocr INFO: Teacher_acc:0.0371584699368947
[2022/07/14 10:49:52] ppocr INFO: Teacher_norm_edit_dis:0.48718814890536477
[2022/07/14 10:49:52] ppocr INFO: fps:947.8562684823883

可以看出,直接載入預訓練模型進行評估,效果較差,因為預訓練模型並不是基於手寫文字進行單獨訓練的,所以我們需要基於預訓練模型進行finetune。
訓練完成後,可以進行測試評估,評估命令如下:

# 評估finetune效果
python tools/eval.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.pretrained_model="./output/ch_PP-OCR_v3_rec/best_accuracy"

評估結果如下,可以看出識別準確率為54.3%。

[2022/07/14 10:54:06] ppocr INFO: metric eval ***************
[2022/07/14 10:54:06] ppocr INFO: acc:0.5430100180913
[2022/07/14 10:54:06] ppocr INFO: norm_edit_dis:0.9203322593158589
[2022/07/14 10:54:06] ppocr INFO: Teacher_acc:0.5401183969626324
[2022/07/14 10:54:06] ppocr INFO: Teacher_norm_edit_dis:0.919827504507755
[2022/07/14 10:54:06] ppocr INFO: fps:928.948733797251

如需獲取已訓練模型,請加入PaddleX官方交流頻道,獲取20G OCR學習大禮包(內含《動手學OCR》電子書、課程回放影片、前沿論文等重磅資料)

  • PaddleX官方交流頻道:https://aistudio.baidu.com/community/channel/610

將下載或訓練完成的模型放置在對應目錄下即可完成模型推理

5. 模型匯出推理

訓練完成後,可以將訓練模型轉換成inference模型。inference 模型會額外儲存模型的結構資訊,在預測部署、加速推理上效能優越,靈活方便,適合於實際系統整合。

5.1 模型匯出

匯出命令如下:

# 轉化為推理模型
python tools/export_model.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml -o Global.pretrained_model="./output/ch_PP-OCR_v3_rec/best_accuracy" Global.save_inference_dir="./inference/rec_ppocrv3/"

5.2 模型推理

匯出模型後,可以使用如下命令進行推理預測:

# 推理預測
python tools/infer/predict_rec.py --image_dir="train_data/handwrite/HWDB2.0Test_images/104-P16_4.jpg" --rec_model_dir="./inference/rec_ppocrv3/Student"
[2022/07/14 10:55:56] ppocr INFO: In PP-OCRv3, rec_image_shape parameter defaults to '3, 48, 320', if you are using recognition model with PP-OCRv2 or an older version, please set --rec_image_shape='3,32,320
[2022/07/14 10:55:58] ppocr INFO: Predicts of train_data/handwrite/HWDB2.0Test_images/104-P16_4.jpg:('品結構,差異化的多品牌滲透使歐萊雅確立了其在中國化妝', 0.9904912114143372)
# 視覺化文字識別圖片
from PIL import Image  
import matplotlib.pyplot as plt
import numpy as np
import os


img_path = 'train_data/handwrite/HWDB2.0Test_images/104-P16_4.jpg'

def vis(img_path):
    plt.figure()
    image = Image.open(img_path)  
    plt.imshow(image)
    plt.show()
    # image = image.resize([208, 208])  


vis(img_path)

相關文章