神經網路 | 基於MATLAB 深度學習工具實現簡單的數字分類問題(卷積神經網路)

衝動的MJ發表於2019-03-07

博主github:https://github.com/MichaelBeechan   

博主CSDN:https://blog.csdn.net/u011344545

%% Time:2019.3.7
%% Name:Michael Beechan
%% Function:
%% 這個例子展示瞭如何建立和訓練一個簡單的卷積神經網路用於深度學習分類。
%% 卷積神經網路是深度學習的重要工具,尤其適用於影像識別。
%% Load and explore image data.
%% Define the network architecture.
%% Specify training options.
%% Train the network.
%% Predict the labels of new data and calculate the classification accuracy.

1、 Load and explore image data

%% 載入資料並儲存
digitDatasetPath = fullfile(matlabroot, 'toolbox', 'nnet', 'nndemos', ...
    'nndatasets', 'DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders', true, 'LabelSource', 'foldernames');

%% 隨機顯示資料
figure;
perm = randperm(10000, 20);
for i = 1 : 20
    subplot(4, 5, i);
    imshow(imds.Files{perm(i)});
end

 %% 計算每個類別中的影像數量
labelCount = countEachLabel(imds)

labelCount =

  10×2 table

    Label    Count
    _____    _____

      0      1000 
      1      1000 
      2      1000 
      3      1000 
      4      1000 
      5      1000 
      6      1000 
      7      1000 
      8      1000 
      9      1000 
%% 指定影像大小尺寸28*28*1
img = readimage(imds, 1);
size(img)

ans =

    28    28

2、Specify Training and Validation Sets

%% 劃分資料為訓練集合驗證集,訓練集中每個類別包含750張影像,驗證集包含其餘影像的標籤
numTrainFiles = 750;
[imdsTrain, imdsValidation] = splitEachLabel(imds, numTrainFiles, 'randomize');

3、Define Network Architecture

%% 定義CNN框架
layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3, 8, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    
    maxPooling2dLayer(2, 'Stride', 2)
    
    convolution2dLayer(3, 16, 'padding', 'same')
    batchNormalizationLayer
    reluLayer
    
    maxPooling2dLayer(2, 'Stride', 2)
    
    convolution2dLayer(3, 32, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

4、Specify Training Options

%% 指定訓練Options——SGDM,學習率0.01,最大epoch=4
options = trainingOptions('sgdm', ...
    'InitialLearnRate', 0.01, ...
    'MaxEpochs', 4, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', imdsValidation, ...
    'ValidationFrequency', 30, ...
    'Verbose', false, ...
    'Plots', 'training-progress');
 

5、使用訓練集訓練網路

net = trainNetwork(imdsTrain, layers, options);

6、對驗證影像進行分類並計算精度

%% 對驗證影像進行分類並計算精度
YPred = classify(net, imdsValidation);
YValidation = imdsValidation.Labels;

accuracy = sum(YPred == YValidation) / numel(YValidation)

accuracy =  0.9964

Okay,文章就寫在這兒了,好好消化一下吧!!!!!加油!!!!

原始碼下載:https://download.csdn.net/download/u011344545/11008033

相關文章