歡迎訪問我的GitHub
https://github.com/zq2599/blog_demos
內容:所有原創文章分類彙總及配套原始碼,涉及Java、Docker、Kubernetes、DevOPS等;
本篇概覽
- 本文是《DL4J》實戰的第二篇,前面做好了準備工作,接下來進入正式實戰,本篇內容是經典的入門例子:鳶尾花分類
- 下圖是一朵鳶尾花,我們可以測量到它的四個特徵:花瓣(petal)的寬和高,花萼(sepal)的 寬和高:
- 鳶尾花有三種:Setosa、Versicolor、Virginica
- 今天的實戰是用前饋神經網路Feed-Forward Neural Network (FFNN)就行鳶尾花分類的模型訓練和評估,在拿到150條鳶尾花的特徵和分類結果後,我們先訓練出模型,再評估模型的效果:
原始碼下載
- 本篇實戰中的完整原始碼可在GitHub下載到,地址和連結資訊如下表所示(https://github.com/zq2599/blo...):
名稱 | 連結 | 備註 |
---|---|---|
專案主頁 | https://github.com/zq2599/blo... | 該專案在GitHub上的主頁 |
git倉庫地址(https) | https://github.com/zq2599/blo... | 該專案原始碼的倉庫地址,https協議 |
git倉庫地址(ssh) | git@github.com:zq2599/blog_demos.git | 該專案原始碼的倉庫地址,ssh協議 |
- 這個git專案中有多個資料夾,《DL4J實戰》系列的原始碼在<font color="blue">dl4j-tutorials</font>資料夾下,如下圖紅框所示:
- <font color="blue">dl4j-tutorials</font>資料夾下有多個子工程,本次實戰程式碼在<font color="blue">dl4j-tutorials</font>目錄下,如下圖紅框:
編碼
- 在<font color="blue">dl4j-tutorials</font>工程下新建子工程<font color="red">classifier-iris</font>,其pom.xml如下:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>dlfj-tutorials</artifactId>
<groupId>com.bolingcavalry</groupId>
<version>1.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>classifier-iris</artifactId>
<properties>
<maven.compiler.source>8</maven.compiler.source>
<maven.compiler.target>8</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>com.bolingcavalry</groupId>
<artifactId>commons</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>${nd4j.backend}</artifactId>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
</dependency>
</dependencies>
</project>
- 上述pom.xml有一處需要注意的地方,就是<font color="blue">${nd4j.backend}</font>引數的值,該值在決定了後端線性代數計算是用CPU還是GPU,本篇為了簡化操作選擇了CPU(因為個人的顯示卡不同,程式碼裡無法統一),對應的配置就是<font color="red">nd4j-native</font>;
- 原始碼全部在Iris.java檔案中,並且程式碼中已新增詳細註釋,就不再贅述了:
package com.bolingcavalry.classifier;
import com.bolingcavalry.commons.utils.DownloaderUtility;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
/**
* @author will (zq2599@gmail.com)
* @version 1.0
* @description: 鳶尾花訓練
* @date 2021/6/13 17:30
*/
@SuppressWarnings("DuplicatedCode")
@Slf4j
public class Iris {
public static void main(String[] args) throws Exception {
//第一階段:準備
// 跳過的行數,因為可能是表頭
int numLinesToSkip = 0;
// 分隔符
char delimiter = ',';
// CSV讀取工具
RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);
// 下載並解壓後,得到檔案的位置
String dataPathLocal = DownloaderUtility.IRISDATA.Download();
log.info("鳶尾花資料已下載並解壓至 : {}", dataPathLocal);
// 讀取下載後的檔案
recordReader.initialize(new FileSplit(new File(dataPathLocal,"iris.txt")));
// 每一行的內容大概是這樣的:5.1,3.5,1.4,0.2,0
// 一共五個欄位,從零開始算的話,標籤在第四個欄位
int labelIndex = 4;
// 鳶尾花一共分為三類
int numClasses = 3;
// 一共150個樣本
int batchSize = 150; //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)
// 載入到資料集迭代器中
DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
DataSet allData = iterator.next();
// 洗牌(打亂順序)
allData.shuffle();
// 設定比例,150個樣本中,百分之六十五用於訓練
SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65); //Use 65% of data for training
// 訓練用的資料集
DataSet trainingData = testAndTrain.getTrain();
// 驗證用的資料集
DataSet testData = testAndTrain.getTest();
// 指定歸一化器:獨立地將每個特徵值(和可選的標籤值)歸一化為0平均值和1的標準差。
DataNormalization normalizer = new NormalizerStandardize();
// 先擬合
normalizer.fit(trainingData);
// 對訓練集做歸一化
normalizer.transform(trainingData);
// 對測試集做歸一化
normalizer.transform(testData);
// 每個鳶尾花有四個特徵
final int numInputs = 4;
// 共有三種鳶尾花
int outputNum = 3;
// 隨機數種子
long seed = 6;
//第二階段:訓練
log.info("開始配置...");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed)
.activation(Activation.TANH) // 啟用函式選用標準的tanh(雙曲正切)
.weightInit(WeightInit.XAVIER) // 權重初始化選用XAVIER:均值 0, 方差為 2.0/(fanIn + fanOut)的高斯分佈
.updater(new Sgd(0.1)) // 更新器,設定SGD學習速率排程器
.l2(1e-4) // L2正則化配置
.list() // 配置多層網路
.layer(new DenseLayer.Builder().nIn(numInputs).nOut(3) // 隱藏層
.build())
.layer(new DenseLayer.Builder().nIn(3).nOut(3) // 隱藏層
.build())
.layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) // 損失函式:負對數似然
.activation(Activation.SOFTMAX) // 輸出層指定啟用函式為:SOFTMAX
.nIn(3).nOut(outputNum).build())
.build();
// 模型配置
MultiLayerNetwork model = new MultiLayerNetwork(conf);
// 初始化
model.init();
// 每一百次迭代列印一次分數(損失函式的值)
model.setListeners(new ScoreIterationListener(100));
long startTime = System.currentTimeMillis();
log.info("開始訓練");
// 訓練
for(int i=0; i<1000; i++ ) {
model.fit(trainingData);
}
log.info("訓練完成,耗時[{}]ms", System.currentTimeMillis()-startTime);
// 第三階段:評估
// 在測試集上評估模型
Evaluation eval = new Evaluation(numClasses);
INDArray output = model.output(testData.getFeatures());
eval.eval(testData.getLabels(), output);
log.info("評估結果如下\n" + eval.stats());
}
}
- 編碼完成後,執行main方法,可見順利完成訓練並輸出了評估結果,還有混淆矩陣用於輔助分析:
- 至此,我們們的第一個實戰就完成了,通過經典例項體驗的DL4J訓練和評估的常規步驟,對重要API也有了初步認識,接下來會繼續實戰,接觸到更多的經典例項;
你不孤單,欣宸原創一路相伴
歡迎關注公眾號:程式設計師欣宸
微信搜尋「程式設計師欣宸」,我是欣宸,期待與您一同暢遊Java世界...
https://github.com/zq2599/blog_demos