DL4J實戰之三:經典卷積例項(LeNet-5)

程式設計師欣宸發表於2021-10-14

歡迎訪問我的GitHub

https://github.com/zq2599/blog_demos

內容:所有原創文章分類彙總及配套原始碼,涉及Java、Docker、Kubernetes、DevOPS等;

本篇概覽

  • 作為《DL4J》實戰的第三篇,目標是在DL4J框架下建立經典的LeNet-5卷積神經網路模型,對MNIST資料集進行訓練和測試,本篇由以下內容構成:
  1. LeNet-5簡介
  2. MNIST簡介
  3. 資料集簡介
  4. 關於版本和環境
  5. 編碼
  6. 驗證

LeNet-5簡介

  • 是Yann LeCun於1998年設計的卷積神經網路,用於手寫數字識別,例如當年美國很多銀行用其識別支票上的手寫數字,LeNet-5是早期卷積神經網路最有代表性的實驗系統之一
  • LeNet-5網路結構如下圖所示,一共七層:C1 -> S2 -> C3 -> S4 -> C5 -> F6 -> OUTPUT

在這裡插入圖片描述

在這裡插入圖片描述

  • 按照上圖簡單分析一下,用於指導接下來的開發:
  1. 每張圖片都是28*28的單通道,矩陣應該是[1, 28,28]
  2. C1是卷積層,所用卷積核尺寸5*5,滑動步長1,卷積核數目20,所以尺寸變化是:28-5+1=24(想象為寬度為5的視窗在寬度為28的視窗內滑動,能滑多少次),輸出矩陣是[20,24,24]
  3. S2是池化層,核尺寸2*2,步長2,型別是MAX,池化操作後尺寸減半,變成了[20,12,12]
  4. C3是卷積層,所用卷積核尺寸5*5,滑動步長1,卷積核數目50,所以尺寸變化是:12-5+1=8,輸出矩陣[50,8,8]
  5. S4是池化層,核尺寸2*2,步長2,型別是MAX,池化操作後尺寸減半,變成了[50,4,4]
  6. C5是全連線層(FC),神經元數目500,接relu啟用函式
  7. 最後是全連線層Output,共10個節點,代表數字0到9,啟用函式是softmax

MNIST簡介

  • MNIST是經典的計算機視覺資料集,來源是National Institute of Standards and Technology (NIST,美國國家標準與技術研究所),包含各種手寫數字圖片,其中訓練集60,000張,測試集 10,000張,
  • MNIST來源於250 個不同人的手寫,其中 50% 是高中學生, 50% 來自人口普查局 (the Census Bureau) 的工作人員.,測試集(test set) 也是同樣比例的手寫數字資料
  • MNIST官網:http://yann.lecun.com/exdb/mnist/

資料集簡介

  • 從MNIST官網下載的原始資料並非圖片檔案,需要按官方給出的格式說明做解析處理才能轉為一張張圖片,這些事情顯然不是本篇的主題,因此我們們可以直接使用DL4J為我們準備好的資料集(下載地址稍後給出),該資料集中是一張張獨立的圖片,這些圖片所在目錄的名字就是該圖片具體的數字,如下圖,目錄0裡面全是數字0的圖片:

在這裡插入圖片描述

  • 上述資料集的下載地址有兩個:
  1. 可以在CSDN下載(0積分):https://download.csdn.net/download/boling_cavalry/19846603
  2. github:https://raw.githubusercontent.com/zq2599/blog_download_files/master/files/mnist_png.tar.gz
  • 下載之後解壓開,是個名為mnist_png的資料夾,稍後的實戰中我們們會用到它

關於DL4J版本

  • 《DL4J實戰》系列的原始碼採用了maven的父子工程結構,DL4J的版本在父工程dlfj-tutorials中定義為1.0.0-beta7
  • 本篇的程式碼雖然還是dlfj-tutorials的子工程,但是DL4J版本卻使用了更低的1.0.0-beta6,之所以這麼做,是因為下一篇文章,我們們會把本篇的訓練和測試工作交給GPU來完成,而對應的CUDA庫只有1.0.0-beta6
  • 扯了這麼多,可以開始編碼了

原始碼下載

名稱 連結 備註
專案主頁 https://github.com/zq2599/blog_demos 該專案在GitHub上的主頁
git倉庫地址(https) https://github.com/zq2599/blog_demos.git 該專案原始碼的倉庫地址,https協議
git倉庫地址(ssh) git@github.com:zq2599/blog_demos.git 該專案原始碼的倉庫地址,ssh協議
  • 這個git專案中有多個資料夾,《DL4J實戰》系列的原始碼在dl4j-tutorials資料夾下,如下圖紅框所示:

在這裡插入圖片描述

  • dl4j-tutorials資料夾下有多個子工程,本次實戰程式碼在simple-convolution目錄下,如下圖紅框:

在這裡插入圖片描述

編碼

  • 在父工程 dl4j-tutorials下新建名為 simple-convolution的子工程,其pom.xml如下,可見這裡的dl4j版本被指定為1.0.0-beta6
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <artifactId>dlfj-tutorials</artifactId>
        <groupId>com.bolingcavalry</groupId>
        <version>1.0-SNAPSHOT</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>

    <artifactId>simple-convolution</artifactId>

    <properties>
        <dl4j-master.version>1.0.0-beta6</dl4j-master.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
        </dependency>

        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>

        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
    </dependencies>
</project>
  • 接下來按照前面的分析實現程式碼,已經新增了詳細註釋,就不再贅述了:
package com.bolingcavalry.convolution;

import lombok.extern.slf4j.Slf4j;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import java.io.File;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

@Slf4j
public class LeNetMNISTReLu {

    // 存放檔案的地址,請酌情修改
//    private static final String BASE_PATH = System.getProperty("java.io.tmpdir") + "/mnist";

    private static final String BASE_PATH = "E:\\temp\\202106\\26";

    public static void main(String[] args) throws Exception {
        // 圖片畫素高
        int height = 28;
        // 圖片畫素寬
        int width = 28;
        // 因為是黑白影像,所以顏色通道只有一個
        int channels = 1;
        // 分類結果,0-9,共十種數字
        int outputNum = 10;
        // 批大小
        int batchSize = 54;
        // 迴圈次數
        int nEpochs = 1;
        // 初始化偽隨機數的種子
        int seed = 1234;

        // 隨機數工具
        Random randNumGen = new Random(seed);
        
        log.info("檢查資料集資料夾是否存在:{}", BASE_PATH + "/mnist_png");

        if (!new File(BASE_PATH + "/mnist_png").exists()) {
            log.info("資料集檔案不存在,請下載壓縮包並解壓到:{}", BASE_PATH);
            return;
        }

        // 標籤生成器,將指定檔案的父目錄作為標籤
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        // 歸一化配置(畫素值從0-255變為0-1)
        DataNormalization imageScaler = new ImagePreProcessingScaler();

        // 不論訓練集還是測試集,初始化操作都是相同套路:
        // 1. 讀取圖片,資料格式為NCHW
        // 2. 根據批大小建立的迭代器
        // 3. 將歸一化器作為前處理器

        log.info("訓練集的向量化操作...");
        // 初始化訓練集
        File trainData = new File(BASE_PATH + "/mnist_png/training");
        FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
        ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);
        trainRR.initialize(trainSplit);
        DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum);
        // 擬合資料(實現類中實際上什麼也沒做)
        imageScaler.fit(trainIter);
        trainIter.setPreProcessor(imageScaler);

        log.info("測試集的向量化操作...");
        // 初始化測試集,與前面的訓練集操作類似
        File testData = new File(BASE_PATH + "/mnist_png/testing");
        FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
        ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker);
        testRR.initialize(testSplit);
        DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum);
        testIter.setPreProcessor(imageScaler); // same normalization for better results

        log.info("配置神經網路");

        // 在訓練中,將學習率配置為隨著迭代階梯性下降
        Map<Integer, Double> learningRateSchedule = new HashMap<>();
        learningRateSchedule.put(0, 0.06);
        learningRateSchedule.put(200, 0.05);
        learningRateSchedule.put(600, 0.028);
        learningRateSchedule.put(800, 0.0060);
        learningRateSchedule.put(1000, 0.001);

        // 超引數
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            // L2正則化係數
            .l2(0.0005)
            // 梯度下降的學習率設定
            .updater(new Nesterovs(new MapSchedule(ScheduleType.ITERATION, learningRateSchedule)))
            // 權重初始化
            .weightInit(WeightInit.XAVIER)
            // 準備分層
            .list()
            // 卷積層
            .layer(new ConvolutionLayer.Builder(5, 5)
                .nIn(channels)
                .stride(1, 1)
                .nOut(20)
                .activation(Activation.IDENTITY)
                .build())
            // 下采樣,即池化
            .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
            // 卷積層
            .layer(new ConvolutionLayer.Builder(5, 5)
                .stride(1, 1) // nIn need not specified in later layers
                .nOut(50)
                .activation(Activation.IDENTITY)
                .build())
            // 下采樣,即池化
            .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
            // 稠密層,即全連線
            .layer(new DenseLayer.Builder().activation(Activation.RELU)
                .nOut(500)
                .build())
            // 輸出
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(outputNum)
                .activation(Activation.SOFTMAX)
                .build())
            .setInputType(InputType.convolutionalFlat(height, width, channels)) // InputType.convolutional for normal image
            .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // 每十個迭代列印一次損失函式值
        net.setListeners(new ScoreIterationListener(10));

        log.info("神經網路共[{}]個引數", net.numParams());

        long startTime = System.currentTimeMillis();
        // 迴圈操作
        for (int i = 0; i < nEpochs; i++) {
            log.info("第[{}]個迴圈", i);
            net.fit(trainIter);
            Evaluation eval = net.evaluate(testIter);
            log.info(eval.stats());
            trainIter.reset();
            testIter.reset();
        }
        log.info("完成訓練和測試,耗時[{}]毫秒", System.currentTimeMillis()-startTime);

        // 儲存模型
        File ministModelPath = new File(BASE_PATH + "/minist-model.zip");
        ModelSerializer.writeModel(net, ministModelPath, true);
        log.info("最新的MINIST模型儲存在[{}]", ministModelPath.getPath());
    }
}
  • 執行上述程式碼,日誌輸出如下,訓練和測試都順利完成,準確率達到0.9886:
21:19:15.355 [main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 1110 is 0.18300625613640034
21:19:15.365 [main] DEBUG org.nd4j.linalg.dataset.AsyncDataSetIterator - Manually destroying ADSI workspace
21:19:16.632 [main] DEBUG org.nd4j.linalg.dataset.AsyncDataSetIterator - Manually destroying ADSI workspace
21:19:16.642 [main] INFO com.bolingcavalry.convolution.LeNetMNISTReLu - 

========================Evaluation Metrics========================
 # of classes:    10
 Accuracy:        0.9886
 Precision:       0.9885
 Recall:          0.9886
 F1 Score:        0.9885
Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)


=========================Confusion Matrix=========================
    0    1    2    3    4    5    6    7    8    9
---------------------------------------------------
  972    0    0    0    0    0    2    2    2    2 | 0 = 0
    0 1126    0    3    0    2    1    1    2    0 | 1 = 1
    1    1 1019    2    0    0    0    6    3    0 | 2 = 2
    0    0    1 1002    0    5    0    1    1    0 | 3 = 3
    0    0    2    0  971    0    3    2    1    3 | 4 = 4
    0    0    0    3    0  886    2    1    0    0 | 5 = 5
    6    2    0    1    1    5  942    0    1    0 | 6 = 6
    0    1    6    0    0    0    0 1015    1    5 | 7 = 7
    1    0    1    1    0    2    0    2  962    5 | 8 = 8
    1    2    1    3    5    3    0    2    1  991 | 9 = 9

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
21:19:16.643 [main] INFO com.bolingcavalry.convolution.LeNetMNISTReLu - 完成訓練和測試,耗時[27467]毫秒
21:19:17.019 [main] INFO com.bolingcavalry.convolution.LeNetMNISTReLu - 最新的MINIST模型儲存在[E:\temp\202106\26\minist-model.zip]

Process finished with exit code 0

關於準確率

  • 前面的測試結果顯示準確率為0.9886,這是1.0.0-beta6版本DL4J的訓練結果,如果換成1.0.0-beta7,準確率可以達到0.99以上,您可以嘗試一下;

  • 至此,DL4J框架下的經典卷積實戰就完成了,截止目前,我們們的訓練和測試工作都是CPU完成的,工作中CPU使用率的上升十分明顯,下一篇文章,我們們把今天的工作交給GPU執行試試,看能否藉助CUDA加速訓練和測試工作;

你不孤單,欣宸原創一路相伴

  1. Java系列
  2. Spring系列
  3. Docker系列
  4. kubernetes系列
  5. 資料庫+中介軟體系列
  6. DevOps系列

歡迎關注公眾號:程式設計師欣宸

微信搜尋「程式設計師欣宸」,我是欣宸,期待與您一同暢遊Java世界...
https://github.com/zq2599/blog_demos

相關文章