TensorFlow.NET機器學習入門【5】採用神經網路實現手寫數字識別(MNIST)

seabluescn發表於2021-12-28

 從這篇文章開始,終於要乾點正兒八經的工作了,前面都是準備工作。這次我們要解決機器學習的經典問題,MNIST手寫數字識別。

首先介紹一下資料集。請首先解壓:TF_Net\Asset\mnist_png.tar.gz檔案

 資料夾內包括兩個資料夾:training和validation,其中training資料夾下包括60000個訓練圖片validation下包括10000個評估圖片,圖片為28*28畫素,分別放在0~9十個資料夾中。

程式總體流程和上一篇文章介紹的BMI分析程式基本一致,畢竟都是多元分類,有幾點不一樣。

1、BMI程式的特徵資料(輸入)為一維陣列,包含兩個數字,MNIST的特徵資料為28*28的二位陣列;

2、BMI程式的輸出為3個,MNIST的輸出為10個;

 

網路模型構建如下:

        private readonly int img_rows = 28;
        private readonly int img_cols = 28;
        private readonly int num_classes = 10;  // total classes
        /// <summary>
        /// 構建網路模型
        /// </summary>     
        private Model BuildModel()
        {
            // 網路引數          
            int n_hidden_1 = 128;    // 1st layer number of neurons.     
            int n_hidden_2 = 128;    // 2nd layer number of neurons.                                
            float scale = 1.0f / 255;

            var model = keras.Sequential(new List<ILayer>
            {
                keras.layers.InputLayer((img_rows,img_cols)),
                keras.layers.Flatten(),
                keras.layers.Rescaling(scale),
                keras.layers.Dense(n_hidden_1, activation:keras.activations.Relu),
                keras.layers.Dense(n_hidden_2, activation:keras.activations.Relu),
                keras.layers.Dense(num_classes, activation:keras.activations.Softmax)
            });

            return model;
        }

這個網路裡用到了兩個新方法,需要解釋一下:

1、Flatten方法:這裡表示拉平,把28*28的二維陣列拉平為含784個資料的一維陣列,因為二維陣列無法進行運算;

2、Rescaling 方法:就是對每個資料乘以一個係數,因為我們從圖片獲取的資料為每一個位點的灰度值,其取值範圍為0~255,所以乘以一個係數將資料縮小到1以內,以免後面運算時溢位。

 

其它基本和上一篇文章介紹的差不多,全部程式碼如下:

TensorFlow.NET機器學習入門【5】採用神經網路實現手寫數字識別(MNIST)
    /// <summary>
    /// 通過神經網路來實現多元分類
    /// </summary>
    public class NN_MultipleClassification_BMI
    {
        private readonly Random random = new Random(1);

        // 網路引數
        int num_features = 2; // data features       
        int num_classes = 3;  // total output .

        public void Run()
        {
            var model = BuildModel();
            model.summary();          

            Console.WriteLine("Press any key to continue...");
            Console.ReadKey();

            (NDArray train_x, NDArray train_y) = PrepareData(1000);
            model.compile(optimizer: keras.optimizers.Adam(0.001f),
              loss: keras.losses.SparseCategoricalCrossentropy(),
              metrics: new[] { "accuracy" });
            model.fit(train_x, train_y, batch_size: 128, epochs: 300);

            test(model);
        }

        /// <summary>
        /// 構建網路模型
        /// </summary>     
        private Model BuildModel()
        {
            // 網路引數          
            int n_hidden_1 = 64; // 1st layer number of neurons.     
            int n_hidden_2 = 64; // 2nd layer number of neurons.           

            var model = keras.Sequential(new List<ILayer>
            {
                keras.layers.InputLayer(num_features),
                keras.layers.Dense(n_hidden_1, activation:keras.activations.Relu),
                keras.layers.Dense(n_hidden_2, activation:keras.activations.Relu),
                keras.layers.Dense(num_classes, activation:keras.activations.Softmax)
            });

            return model;
        }

        /// <summary>
        /// 載入訓練資料
        /// </summary>
        /// <param name="total_size"></param>    
        private (NDArray, NDArray) PrepareData(int total_size)
        {
            float[,] arrx = new float[total_size, num_features];
            int[] arry = new int[total_size];

            for (int i = 0; i < total_size; i++)
            {
                float weight = (float)random.Next(30, 100) / 100;
                float height = (float)random.Next(140, 190) / 100;
                float bmi = (weight * 100) / (height * height);

                arrx[i, 0] = weight;
                arrx[i, 1] = height;

                switch (bmi)
                {
                    case var x when x < 18.0f:
                        arry[i] = 0;
                        break;

                    case var x when x >= 18.0f && x <= 28.0f:
                        arry[i] = 1;
                        break;

                    case var x when x > 28.0f:
                        arry[i] = 2;
                        break;
                }
            }

            return (np.array(arrx), np.array(arry));
        }

        /// <summary>
        /// 消費模型
        /// </summary>      
        private void test(Model model)
        {
            int test_size = 20;
            for (int i = 0; i < test_size; i++)
            {
                float weight = (float)random.Next(40, 90) / 100;
                float height = (float)random.Next(145, 185) / 100;
                float bmi = (weight * 100) / (height * height);

                var test_x = np.array(new float[1, 2] { { weight, height } });
                var pred_y = model.Apply(test_x);

                Console.WriteLine($"{i}:weight={(float)weight} \theight={height} \tBMI={bmi:0.0} \tPred:{pred_y[0].numpy()}");
            }
        }
    }
View Code

另有兩點說明:

1、由於對圖片的讀取比較耗時,所以我採用了一個方法,就是把讀取到的資料序列化到一個二進位制檔案中,下次直接從二進位制檔案反序列化即可,大大加快處理速度。

2、我沒有采用validation圖片進行評估,只是簡單選了20個樣本測試了一下。

 

【相關資源】

 原始碼:Git: https://gitee.com/seabluescn/tf_not.git

專案名稱:NN_MultipleClassification_MNIST

目錄:檢視TensorFlow.NET機器學習入門系列目錄

相關文章