百度飛槳(PaddlePaddle)-數字識別

VipSoft發表於2023-05-09

手寫數字識別任務 用於對 0 ~ 9 的十類數字進行分類,即輸入手寫數字的圖片,可識別出這個圖片中的數字。

使用 pip 工具安裝 matplotlib 和 numpy

python -m pip install matplotlib numpy -i https://mirror.baidu.com/pypi/simple
python -m pip install paddlepaddle==2.4.2 -i https://pypi.tuna.tsinghua.edu.cn/simple

D:\OpenSource\PaddlePaddle>python -m pip install matplotlib numpy -i https://mirror.baidu.com/pypi/simple
Looking in indexes: https://mirror.baidu.com/pypi/simple
Collecting matplotlib
  Downloading https://mirror.baidu.com/pypi/packages/92/01/2c04d328db6955d77f8f60c17068dde8aa66f153b2c599ca03c2cb0d5567/matplotlib-3.7.1-cp38-cp38-win_amd64.whl (7.6 MB)
     |████████████████████████████████| 7.6 MB ...
Requirement already satisfied: numpy in d:\program files\python38\lib\site-packages (1.24.3)
Collecting packaging>=20.0
  Downloading https://mirror.baidu.com/pypi/packages/ab/c3/57f0601a2d4fe15de7a553c00adbc901425661bf048f2a22dfc500caf121/packaging-23.1-py3-none-any.whl (48 kB)
     |████████████████████████████████| 48 kB 1.2 MB/s
Collecting cycler>=0.10
  Downloading https://mirror.baidu.com/pypi/packages/5c/f9/695d6bedebd747e5eb0fe8fad57b72fdf25411273a39791cde838d5a8f51/cycler-0.11.0-py3-none-any.whl (6.4 kB)
Requirement already satisfied: pillow>=6.2.0 in d:\program files\python38\lib\site-packages (from matplotlib) (9.5.0)
Collecting python-dateutil>=2.7
  Downloading https://mirror.baidu.com/pypi/packages/36/7a/87837f39d0296e723bb9b62bbb257d0355c7f6128853c78955f57342a56d/python_dateutil-2.8.2-py2.py3-none-any.whl (247 kB)
     |████████████████████████████████| 247 kB ...
Collecting importlib-resources>=3.2.0
  Downloading https://mirror.baidu.com/pypi/packages/38/71/c13ea695a4393639830bf96baea956538ba7a9d06fcce7cef10bfff20f72/importlib_resources-5.12.0-py3-none-any.whl (36 kB)
Collecting fonttools>=4.22.0
  Downloading https://mirror.baidu.com/pypi/packages/16/07/1c7547e27f559ec078801d522cc4d5127cdd4ef8e831c8ddcd9584668a07/fonttools-4.39.3-py3-none-any.whl (1.0 MB)
     |████████████████████████████████| 1.0 MB ...
Collecting pyparsing>=2.3.1
  Downloading https://mirror.baidu.com/pypi/packages/6c/10/a7d0fa5baea8fe7b50f448ab742f26f52b80bfca85ac2be9d35cdd9a3246/pyparsing-3.0.9-py3-none-any.whl (98 kB)
     |████████████████████████████████| 98 kB 862 kB/s
Collecting contourpy>=1.0.1
  Downloading https://mirror.baidu.com/pypi/packages/08/ce/9bfe9f028cb5a8ee97898da52f4905e0e2d9ca8203ffdcdbe80e1769b549/contourpy-1.0.7-cp38-cp38-win_amd64.whl (162 kB)
     |████████████████████████████████| 162 kB ...
Collecting kiwisolver>=1.0.1
  Downloading https://mirror.baidu.com/pypi/packages/4f/05/59b34e788bf2b45c7157c3d898d567d28bc42986c1b6772fb1af329eea0d/kiwisolver-1.4.4-cp38-cp38-win_amd64.whl (55 kB)
     |████████████████████████████████| 55 kB 784 kB/s
Collecting zipp>=3.1.0
  Downloading https://mirror.baidu.com/pypi/packages/5b/fa/c9e82bbe1af6266adf08afb563905eb87cab83fde00a0a08963510621047/zipp-3.15.0-py3-none-any.whl (6.8 kB)
Requirement already satisfied: six>=1.5 in d:\program files\python38\lib\site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)
Installing collected packages: zipp, python-dateutil, pyparsing, packaging, kiwisolver, importlib-resources, fonttools, cycler, contourpy, matplotlib
Successfully installed contourpy-1.0.7 cycler-0.11.0 fonttools-4.39.3 importlib-resources-5.12.0 kiwisolver-1.4.4 matplotlib-3.7.1 packaging-23.1 pyparsing-3.0.9 python-dateutil-2.8.2 zipp-3.15.0
WARNING: You are using pip version 21.1.1; however, version 23.1.2 is available.
You should consider upgrading via the 'D:\Program Files\Python38\python.exe -m pip install --upgrade pip' command.

D:\OpenSource\PaddlePaddle>

建立 DigitalRecognition.py

官網程式碼少了 plt.show() # 要加上這句,才會顯示圖片

import paddle
import numpy as np
from paddle.vision.transforms import Normalize

transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')
# 下載資料集並初始化 DataSet
'''
飛槳在 paddle.vision.datasets 下內建了計算機視覺(Computer Vision,CV)領域常見的資料集,
如 MNIST、Cifar10、Cifar100、FashionMNIST 和 VOC2012 等。在本任務中,
先後載入了 MNIST 訓練集(mode='train')和測試集(mode='test'),訓練集用於訓練模型,測試集用於評估模型效果。
'''
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
# 列印資料集裡圖片數量 60000 images in train_dataset, 10000 images in test_dataset
# print('{} images in train_dataset, {} images in test_dataset'.format(len(train_dataset), len(test_dataset)))

# 模型組網並初始化網路
lenet = paddle.vision.models.LeNet(num_classes=10)
model = paddle.Model(lenet)

# 模型訓練的配置準備,準備損失函式,最佳化器和評價指標
model.prepare(paddle.optimizer.Adam(parameters=model.parameters()),
              paddle.nn.CrossEntropyLoss(),
              paddle.metric.Accuracy())

# 模型訓練
model.fit(train_dataset, epochs=5, batch_size=64, verbose=1)
# 模型評估
model.evaluate(test_dataset, batch_size=64, verbose=1)

# 儲存模型
model.save('./output/mnist')
# 載入模型
model.load('output/mnist')

# 從測試集中取出一張圖片
img, label = test_dataset[0]
# 將圖片shape從1*28*28變為1*1*28*28,增加一個batch維度,以匹配模型輸入格式要求
img_batch = np.expand_dims(img.astype('float32'), axis=0)

# 執行推理並列印結果,此處predict_batch返回的是一個list,取出其中資料獲得預測結果
out = model.predict_batch(img_batch)[0]
pred_label = out.argmax()
print('true label: {}, pred label: {}'.format(label[0], pred_label))
# 視覺化圖片
from matplotlib import pyplot as plt
plt.imshow(img[0])
plt.show()  # 要加上這句,才會顯示圖片

PyCharm執行(推薦,有錯誤能顯示出來)

Python MatplotlibDeprecationWarning Matplotlib 3.6 and will be removed two minor releases later
File -> Settings -> Tools -> Python Scientific -> 取消 Show plots in tool window,
取消後,將不會看到紅字警告提示
image

CMD 執行

D:\OpenSource\PaddlePaddle>python DigitalRecognition.py
image
image

如果碰到下列錯誤,需要加上 plt.show()
Python MatplotlibDeprecationWarning Matplotlib 3.6 and will be removed two minor releases later

MatplotlibDeprecationWarning: Support for FigureCanvases without a required_interactive_framework attribute was deprecated in Matplotlib 3.6 and will be removed two minor releases later.
  plt.imshow(img[0])

資料集定義與載入

飛槳在 paddle.vision.datasets 下內建了計算機視覺(Computer Vision,CV)領域常見的資料集,如 MNIST、Cifar10、Cifar100、FashionMNIST 和 VOC2012 等。在本任務中,先後載入了 MNIST 訓練集(mode='train')和測試集(mode='test'),訓練集用於訓練模型,測試集用於評估模型效果。
飛槳除了內建了 CV 領域常見的資料集,還在 paddle.text 下內建了自然語言處理(Natural Language Processing,NLP)領域常見的資料集,並提供了自定義資料集與載入功能的 paddle.io.Dataset 和 paddle.io.DataLoader API,詳細使用方法可參考『資料集定義與載入』 章節。

另外在 paddle.vision.transforms 下提供了一些常用的影像變換操作,如對影像的翻轉、裁剪、調整亮度等處理,可實現資料增強,以增加訓練樣本的多樣性,提升模型的泛化能力。本任務在初始化 MNIST 資料集時透過 transform 欄位傳入了 Normalize 變換對影像進行歸一化,對影像進行歸一化可以加快模型訓練的收斂速度。該功能的具體使用方法可參考『資料預處理』 章節。

模型組網

https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/beginner/quick_start_cn.html#moxingzuwang

模型訓練與評估

https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/beginner/quick_start_cn.html#moxingxunlianyupinggu

模型推理

https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/beginner/quick_start_cn.html#moxingtuili

image

相關文章