TensorFlow.NET機器學習入門【7】採用卷積神經網路(CNN)處理Fashion-MNIST

seabluescn發表於2021-12-30

本文將介紹如何採用卷積神經網路(CNN)來處理Fashion-MNIST資料集。

程式流程如下:

1、準備樣本資料

2、構建卷積神經網路模型

3、網路學習(訓練)

4、消費、測試

 除了網路模型的構建,其它步驟都和前面介紹的普通神經網路的處理完全一致,本文就不重複介紹了,重點講一下模型的構建。

 

先看程式碼:

        /// <summary>
        /// 構建網路模型
        /// </summary>     
        private Model BuildModel()
        {
            // 網路引數                                      
            float scale = 1.0f / 255;

            var model = keras.Sequential(new List<ILayer>
            {
                keras.layers.Rescaling(scale, input_shape: (img_rows, img_cols, channel)),

                keras.layers.Conv2D(32, 5, padding: "same", activation: keras.activations.Relu),
                keras.layers.MaxPooling2D(),

                keras.layers.Conv2D(64, 3, padding: "same", activation: keras.activations.Relu),
                keras.layers.MaxPooling2D(),

                keras.layers.Flatten(),
                keras.layers.Dense(128, activation: keras.activations.Relu),
                keras.layers.Dense(num_classes,activation:keras.activations.Softmax)
            });

            return model;
        }

keras.layers.Conv2D方法建立一個卷積層

keras.layers.MaxPooling2D方法建立一個池化層
 

卷積層的含義:

     

如上圖所示,原始資料尺寸為5*5,卷積核大小為3*3,當卷積核滑過原始圖片時,卷積核和圖片對應的資料進行運算(先乘後加),並形成新的資料。

示例的卷積核為[[1,0,1],[0,1,0],[1,0,1]],和左上角資料卷積後結果為4,填寫到對應位置。對整改圖片全部滑動一遍,即形成最終結果。

  

 採用卷積神經網路,相對於前面介紹的普通神經網路有什麼優勢呢?

1、首先,影像本身是一個二維資料,普通網路首先要把資料拉平,這一點就不合理,而卷積網路通過卷積核處理資料,保留了原始資料的基本特徵;

2、其次,採用卷積網路大大減小了引數的數量。假設原始圖片解析度為100*100,拉平後長度為10000,後面跟一個全連線層,輸出為128,此時引數量為(10000+1)*128,超過128萬。這才一個全連線層。如果採用CNN,引數數量取決於卷積核的大小和數量。假設卷積核大小為5*5,數量為32,此時引數數量為:(5*5+1)*32=832。【計算方法下面會詳細介紹】

  

 池化層的含義:

 池化就是壓縮,就是圖片資料太大了,通過池化把解析度減小一些。

 池化有均值池化和最大值池化方法,這個很好理解,就是一推資料中取平均值或最大值。MaxPooling2D明顯是最大池化法。

 

我們再看一下這個程式碼:

 keras.layers.Conv2D(32, 5, padding: "same", activation: keras.activations.Relu),

 32表示卷積核數量為32,卷積核大小為5*5,padding: "same"表示對影像進行邊緣補零,不然卷積後的影像尺寸會變小,補零後影像尺寸不變。

整體模型摘要資訊如下:

  下面逐行解釋一下:

1、首先輸入層的資料Shape為:(28,28,1),28表示圖片畫素,1表示灰度圖片,如果是彩色圖片,應該為(28,28,3)

2、Rescaling對資料進行處理,統一乘以一個係數,這裡沒有需要訓練的引數

3、引入一個卷積層,卷積核數量為32,卷積核大小為5*5(圖上看不出來),此時引數數量為:(5*5+1)*32=832,這裡卷積核尺寸為5*5,所以有25個引數,這很好理解,+1是因為作為卷積計算後還要加一個偏置b,所以每個卷積核共26個引數。由於有32個卷積核,要對同一個影像採用不同的卷積核做32次計算,所以這一層輸出資料為(28,28,32)

4、池化層將資料從(28,28,32)壓縮到(14,14,32)

5、再引入一個卷積層,卷積核數量為64,卷積核大小為3*3(圖上看不出來),這次計算和第一次不太一樣:由於上一層資料共有32片,對每一片資料採用的卷積核是不一樣的,所以這裡實際一共有32*9=288個卷積核。首先用32個卷積核和上述32片資料分別進行卷積形成32片資料,然後將32片資料疊加求和,最後再加一個偏置形成一片新資料,重複進行64次,形成64片新資料。此時引數數量為:(288+1)*64=18496

  【注意:這裡的演算法其實是和第一層卷積演算法完全一樣的,只是第一層輸入為灰度圖片,資料只有一片,如果輸入為彩色圖片,就一致了。】

6、池化層將資料從(14,14,64)壓縮到(7,7,64)

7、將資料拉平,拉平後的資料長度為:7*7*64=3136

8、引入全連線層,輸出神經元數量為128,此時引數數量為:(3136+1)*128=401536

9、最後為全連線層輸出,輸出神經元數量為10,引數數量為:(128+1)*10=1290

 

現在,由於引數數量已經很多了,訓練需要的時間也比較長了,所以需要把訓練完成後的引數儲存下來,下次可以重新載入儲存的引數接著訓練,不用從頭再來。

儲存的模型也可以釋出到生產系統用於實際的消費。 

全部程式碼如下:

TensorFlow.NET機器學習入門【7】採用卷積神經網路(CNN)處理Fashion-MNIST
    /// <summary>
    /// 採用卷積神經網路處理Fashion-MNIST資料集
    /// </summary>
    public class CNN_Fashion_MNIST
    {
        private readonly string TrainImagePath = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\train";
        private readonly string TestImagePath = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\test";
        private readonly string train_date_path = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\cnn_train_data.bin";
        private readonly string train_label_path = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\cnn_train_label.bin";
        private readonly string ModelFile = @"D:\Study\Blogs\TF_Net\Model\cnn_fashion_mnist.h5";

        private readonly int img_rows = 28;
        private readonly int img_cols = 28;
        private readonly int channel = 1;
        private readonly int num_classes = 10;  // total classes

        public void Run()
        {
            var model = BuildModel();
            model.summary();
            model.load_weights(ModelFile);

            Console.WriteLine("press any key");
            Console.ReadKey();

            model.compile(optimizer: keras.optimizers.Adam(0.0001f),
                loss: keras.losses.SparseCategoricalCrossentropy(),
                metrics: new[] { "accuracy" });

            (NDArray train_x, NDArray train_y) = LoadTrainingData();
            model.fit(train_x, train_y, batch_size: 512, epochs: 1);
            model.save_weights(ModelFile);

            test(model);
        }

        /// <summary>
        /// 構建網路模型
        /// </summary>     
        private Model BuildModel()
        {
            // 網路引數                                      
            float scale = 1.0f / 255;

            var model = keras.Sequential(new List<ILayer>
            {
                keras.layers.Rescaling(scale, input_shape: (img_rows, img_cols, channel)),

                keras.layers.Conv2D(32, 5, padding: "same", activation: keras.activations.Relu),
                keras.layers.MaxPooling2D(),

                keras.layers.Conv2D(64, 3, padding: "same", activation: keras.activations.Relu),
                keras.layers.MaxPooling2D(),

                keras.layers.Flatten(),
                keras.layers.Dense(128, activation: keras.activations.Relu),
                keras.layers.Dense(num_classes,activation:keras.activations.Softmax)
            });

            return model;
        }

        /// <summary>
        /// 載入訓練資料
        /// </summary>
        /// <param name="total_size"></param>    
        private (NDArray, NDArray) LoadTrainingData()
        {
            try
            {
                Console.WriteLine("Load data");
                IFormatter serializer = new BinaryFormatter();
                FileStream loadFile = new FileStream(train_date_path, FileMode.Open, FileAccess.Read);
                float[,,,] arrx = serializer.Deserialize(loadFile) as float[,,,];

                loadFile = new FileStream(train_label_path, FileMode.Open, FileAccess.Read);
                int[] arry = serializer.Deserialize(loadFile) as int[];
                Console.WriteLine("Load data success");
                return (np.array(arrx), np.array(arry));
            }
            catch (Exception ex)
            {
                Console.WriteLine($"Load data Exception:{ex.Message}");
                return LoadRawData();
            }
        }

        private (NDArray, NDArray) LoadRawData()
        {
            Console.WriteLine("LoadRawData");

            int total_size = 60000;
            float[,,,] arrx = new float[total_size, img_rows, img_cols, channel];
            int[] arry = new int[total_size];

            int count = 0;

            DirectoryInfo RootDir = new DirectoryInfo(TrainImagePath);
            foreach (var Dir in RootDir.GetDirectories())
            {
                foreach (var file in Dir.GetFiles("*.png"))
                {
                    Bitmap bmp = (Bitmap)Image.FromFile(file.FullName);
                    if (bmp.Width != img_cols || bmp.Height != img_rows)
                    {
                        continue;
                    }

                    for (int row = 0; row < img_rows; row++)
                        for (int col = 0; col < img_cols; col++)
                        {
                            var pixel = bmp.GetPixel(col, row);
                            int val = (pixel.R + pixel.G + pixel.B) / 3;

                            arrx[count, row, col, 0] = val;
                            arry[count] = int.Parse(Dir.Name);
                        }

                    count++;
                }

                Console.WriteLine($"Load image data count={count}");
            }

            Console.WriteLine("LoadRawData finished");
            //Save Data
            Console.WriteLine("Save data");
            IFormatter serializer = new BinaryFormatter();

            //開始序列化
            FileStream saveFile = new FileStream(train_date_path, FileMode.Create, FileAccess.Write);
            serializer.Serialize(saveFile, arrx);
            saveFile.Close();

            saveFile = new FileStream(train_label_path, FileMode.Create, FileAccess.Write);
            serializer.Serialize(saveFile, arry);
            saveFile.Close();
            Console.WriteLine("Save data finished");

            return (np.array(arrx), np.array(arry));
        }

        /// <summary>
        /// 消費模型
        /// </summary>      
        private void test(Model model)
        {
            Random rand = new Random(1);

            DirectoryInfo TestDir = new DirectoryInfo(TestImagePath);
            foreach (var ChildDir in TestDir.GetDirectories())
            {
                Console.WriteLine($"Folder:【{ChildDir.Name}】");
                var Files = ChildDir.GetFiles("*.png");
                for (int i = 0; i < 10; i++)
                {
                    int index = rand.Next(1000);
                    var image = Files[index];

                    var x = LoadImage(image.FullName);
                    var pred_y = model.Apply(x);
                    var result = argmax(pred_y[0].numpy());

                    Console.WriteLine($"FileName:{image.Name}\tPred:{result}");
                }
            }
        }

        private NDArray LoadImage(string filename)
        {
            float[,,,] arrx = new float[1, img_rows, img_cols, channel];
            Bitmap bmp = (Bitmap)Image.FromFile(filename);

            for (int row = 0; row < img_rows; row++)
                for (int col = 0; col < img_cols; col++)
                {
                    var pixel = bmp.GetPixel(col, row);
                    int val = (pixel.R + pixel.G + pixel.B) / 3;
                    arrx[0, row, col, 0] = val;
                }

            return np.array(arrx);
        }

        private int argmax(NDArray array)
        {
            var arr = array.reshape(-1);

            float max = 0;
            for (int i = 0; i < 10; i++)
            {
                if (arr[i] > max)
                {
                    max = arr[i];
                }
            }

            for (int i = 0; i < 10; i++)
            {
                if (arr[i] == max)
                {
                    return i;
                }
            }

            return 0;
        }
    }
View Code

 通過採用CNN的方法,我們可以把Fashion-MNIST識別率提高到大約94%左右,而且還有提高的空間。但是網路的優化是一件非常困難的事情,特別是識別率已經很高的時候,想提高1個百分點都是很不容易的。

 以下是一個優化過的網路,我查閱了不少資料,也參考了很多程式碼,才構建了這個網路,它的識別率約為96%,再怎麼調整也提高不上去了。

        /// <summary>
        /// 構建網路模型
        /// </summary>     
        private Model BuildModel()
        {
            // 網路引數                                      
            float scale = 1.0f / 255;
            var model = keras.Sequential(new List<ILayer>
            {
                keras.layers.Rescaling(scale, input_shape: (img_rows, img_cols, channel)),

                keras.layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu),
                keras.layers.MaxPooling2D(),

                keras.layers.Conv2D(64, 3, padding: "same", activation: keras.activations.Relu),
                keras.layers.MaxPooling2D(),

                keras.layers.Dropout(0.3f),
                keras.layers.BatchNormalization(),

                keras.layers.Conv2D(128, 3, padding: "same", activation: keras.activations.Relu),
                keras.layers.Conv2D(128, 3, padding: "same", activation: keras.activations.Relu),
                keras.layers.MaxPooling2D(),

                keras.layers.Dropout(0.4f),
                keras.layers.Flatten(),               
                keras.layers.Dense(512, activation: keras.activations.Relu),
                keras.layers.Dropout(0.25f),
                keras.layers.Dense(num_classes,activation:keras.activations.Softmax)
            });

            return model;
        }

 

【參考資料】

卷積神經網路CNN總結 - Madcola - 部落格園 (cnblogs.com)

卷積神經網路(CNN)模型結構 - 劉建平Pinard - 部落格園 (cnblogs.com)

 

【相關資源】

原始碼:Git: https://gitee.com/seabluescn/tf_not.git

專案名稱:CNN_Fashion_MNIST,CNN_Fashion_MNIST_Plus

目錄:檢視TensorFlow.NET機器學習入門系列目錄

相關文章