"如果一個演算法在MNIST上不work,那麼它就根本沒法用;而如果它在MNIST上work,它在其他資料上也可能不work"。
—— 馬克吐溫
上一篇文章我們實現了一個MNIST手寫數字識別的程式,通過一個簡單的兩層神經網路,就輕鬆獲得了98%的識別成功率。這個成功率不代表你的網路是有效的,因為MNIST實在是太簡單了,我們需要更復雜的資料集來檢驗網路的有效性!這就有了Fashion-MNIST資料集,它採用10種服裝的圖片來取代數字0~9,除此之外,其圖片大小、數量均和MNIST一致。
上篇文章的程式碼幾乎不用改動,只要改個獲取原始圖片檔案的資料夾名稱即可。
程式執行結果識別成功率大約為82%左右。
我們可以對網路進行調整,看能否提高識別率,具體可用的方法:
1、增加網路層
2、增加神經元個數
3、改用其它啟用函式
試驗結果表明,不管如何調整,識別率始終上不去多少。可見該網路方案已經碰到了瓶頸,如果要大幅度提高識別率必須要採取新的方案了。
下篇文章我們將介紹卷積神經網路(CNN)的應用,通過CNN來處理影像資料將是一個更好、更科學的解決方案。
由於本文程式碼和上一篇文章的程式碼高度一致,這裡就不再詳細說明了。全部程式碼如下:
/// <summary> /// 採用神經網路處理Fashion-MNIST資料集 /// </summary> public class NN_MultipleClassification_Fashion_MNIST { private readonly string TrainImagePath = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\train"; private readonly string TestImagePath = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\test"; private readonly string train_date_path = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\train_data.bin"; private readonly string train_label_path = @"D:\Study\Blogs\TF_Net\Asset\fashion_mnist_png\train_label.bin"; private readonly int img_rows = 28; private readonly int img_cols = 28; private readonly int num_classes = 10; // total classes public void Run() { var model = BuildModel(); model.summary(); model.compile(optimizer: keras.optimizers.Adam(0.001f), loss: keras.losses.SparseCategoricalCrossentropy(), metrics: new[] { "accuracy" }); (NDArray train_x, NDArray train_y) = LoadTrainingData(); model.fit(train_x, train_y, batch_size: 1024, epochs: 20); test(model); } /// <summary> /// 構建網路模型 /// </summary> private Model BuildModel() { // 網路引數 int n_hidden_1 = 128; // 1st layer number of neurons. int n_hidden_2 = 128; // 2nd layer number of neurons. float scale = 1.0f / 255; var model = keras.Sequential(new List<ILayer> { keras.layers.InputLayer((img_rows,img_cols)), keras.layers.Flatten(), keras.layers.Rescaling(scale), keras.layers.Dense(n_hidden_1, activation:keras.activations.Relu), keras.layers.Dense(n_hidden_2, activation:keras.activations.Relu), keras.layers.Dense(num_classes, activation:keras.activations.Softmax) }); return model; } /// <summary> /// 載入訓練資料 /// </summary> /// <param name="total_size"></param> private (NDArray, NDArray) LoadTrainingData() { try { Console.WriteLine("Load data"); IFormatter serializer = new BinaryFormatter(); FileStream loadFile = new FileStream(train_date_path, FileMode.Open, FileAccess.Read); float[,,] arrx = serializer.Deserialize(loadFile) as float[,,]; loadFile = new FileStream(train_label_path, FileMode.Open, FileAccess.Read); int[] arry = serializer.Deserialize(loadFile) as int[]; Console.WriteLine("Load data success"); return (np.array(arrx), np.array(arry)); } catch (Exception ex) { Console.WriteLine($"Load data Exception:{ex.Message}"); return LoadRawData(); } } private (NDArray, NDArray) LoadRawData() { Console.WriteLine("LoadRawData"); int total_size = 60000; float[,,] arrx = new float[total_size, img_rows, img_cols]; int[] arry = new int[total_size]; int count = 0; DirectoryInfo RootDir = new DirectoryInfo(TrainImagePath); foreach (var Dir in RootDir.GetDirectories()) { foreach (var file in Dir.GetFiles("*.png")) { Bitmap bmp = (Bitmap)Image.FromFile(file.FullName); if (bmp.Width != img_cols || bmp.Height != img_rows) { continue; } for (int row = 0; row < img_rows; row++) for (int col = 0; col < img_cols; col++) { var pixel = bmp.GetPixel(col, row); int val = (pixel.R + pixel.G + pixel.B) / 3; arrx[count, row, col] = val; arry[count] = int.Parse(Dir.Name); } count++; } Console.WriteLine($"Load image data count={count}"); } Console.WriteLine("LoadRawData finished"); //Save Data Console.WriteLine("Save data"); IFormatter serializer = new BinaryFormatter(); //開始序列化 FileStream saveFile = new FileStream(train_date_path, FileMode.Create, FileAccess.Write); serializer.Serialize(saveFile, arrx); saveFile.Close(); saveFile = new FileStream(train_label_path, FileMode.Create, FileAccess.Write); serializer.Serialize(saveFile, arry); saveFile.Close(); Console.WriteLine("Save data finished"); return (np.array(arrx), np.array(arry)); } /// <summary> /// 消費模型 /// </summary> private void test(Model model) { Random rand = new Random(1); DirectoryInfo TestDir = new DirectoryInfo(TestImagePath); foreach (var ChildDir in TestDir.GetDirectories()) { Console.WriteLine($"Folder:【{ChildDir.Name}】"); var Files = ChildDir.GetFiles("*.png"); for (int i = 0; i < 10; i++) { int index = rand.Next(1000); var image = Files[index]; var x = LoadImage(image.FullName); var pred_y = model.Apply(x); var result = argmax(pred_y[0].numpy()); Console.WriteLine($"FileName:{image.Name}\tPred:{result}"); } } } private NDArray LoadImage(string filename) { float[,,] arrx = new float[1, img_rows, img_cols]; Bitmap bmp = (Bitmap)Image.FromFile(filename); for (int row = 0; row < img_rows; row++) for (int col = 0; col < img_cols; col++) { var pixel = bmp.GetPixel(col, row); int val = (pixel.R + pixel.G + pixel.B) / 3; arrx[0, row, col] = val; } return np.array(arrx); } private int argmax(NDArray array) { var arr = array.reshape(-1); float max = 0; for (int i = 0; i < 10; i++) { if (arr[i] > max) { max = arr[i]; } } for (int i = 0; i < 10; i++) { if (arr[i] == max) { return i; } } return 0; } }
【相關資源】
原始碼:Git: https://gitee.com/seabluescn/tf_not.git
專案名稱:NN_MultipleClassification_Fashion_MNIST