C#中的深度學習(二):預處理識別硬幣的資料集

碼農譯站發表於2020-12-22

在文章中,我們將對輸入到機器學習模型中的資料集進行預處理。

這裡我們將對一個硬幣資料集進行預處理,以便以後在監督學習模型中進行訓練。在機器學習中預處理資料集通常涉及以下任務:

  1. 清理資料——通過對周圍資料的平均值或使用其他策略來填補資料缺失或損壞造成的漏洞。
  2. 規範資料——將資料縮放值標準化到一個標準範圍,通常是0到1。具有廣泛值範圍的資料可能會導致不規範,因此我們將所有資料都放在一個公共範圍內。
  3. 一種熱編碼標籤——將資料集中物件的標籤或類編碼為N維二進位制向量,其中N是類的總數。陣列元素都被設定為0,除了與物件的類相對應的元素,它被設定為1。這意味著在每個陣列中都有一個值為1的元素。
  4. 將輸入資料集分為訓練集和驗證集——訓練集被用於訓練模型,驗證集是用於檢查我們的訓練結果。

這個例子我們將使用Numpy.NET,它基本上是Python中流行的Numpy庫的.NET版本。

Numpy是一個專注於處理矩陣的庫。

為了實現我們的資料集處理器,我們在PreProcessing資料夾中建立Utils類和DataSet類。Utils類合併了一個靜態Normalize 方法,如下所示:

public class Utils
   {
       public static NDarray Normalize(string path)
       {
           var colorMode = Settings.Channels == 3 ? "rgb" : "grayscale";
           var img = ImageUtil.LoadImg(path, color_mode: colorMode, target_size: (Settings.ImgWidth, Settings.ImgHeight));
           return ImageUtil.ImageToArray(img) / 255;
       }

   }

在這種方法中,我們用給定的顏色模式(RGB或灰度)載入影像,並將其調整為給定的寬度和高度。然後我們返回包含影像的矩陣,每個元素除以255。每個元素除以255是使它們標準化,因為影像中任何畫素的值都在0到255之間,所以通過將它們除以255,我們確保了新的範圍是0到1,包括255。

我們還在程式碼中使用了一個Settings類。該類包含用於跨應用程式使用的許多常量。另一個類DataSet,表示我們將要用來訓練機器學習模型的資料集。這裡我們有以下欄位:

  1. _pathToFolder—包含影像的資料夾的路徑。
  2. _extList—要考慮的副檔名列表。
  3. _labels—_pathToFolder中影像的標籤或類。
  4. _objs -影像本身,表示為Numpy.NDarray。
  5. _validationSplit—用於將總影像數劃分為驗證集和訓練集的百分比,在本例中,百分比將定義驗證集與總影像數之間的大小。
  6. NumberClasses-資料集中唯一類的總數。
  7. TrainX -訓練資料,表示為Numpy.NDarray。
  8. TrainY -訓練標籤,表示為Numpy.NDarray。
  9. ValidationX—驗證資料,表示為Numpy.NDarray。
  10. ValidationY-驗證標籤,表示為Numpy.NDarray。

這是DataSet類:

public class DataSet
    {
        private string _pathToFolder;
        private string[] _extList;
        private List<int> _labels;
        private List<NDarray> _objs;
        private double _validationSplit;
        public int NumberClasses { get; set; }
        public NDarray TrainX { get; set; }
        public NDarray ValidationX { get; set; }
        public NDarray TrainY { get; set; }
        public NDarray ValidationY { get; set; }

        public DataSet(string pathToFolder, string[] extList, int numberClasses, double validationSplit)
        {
            _pathToFolder = pathToFolder;
            _extList = extList;
            NumberClasses = numberClasses;
            _labels = new List<int>();
            _objs = new List<NDarray>();
            _validationSplit = validationSplit;
        }

        public void LoadDataSet()
        {
            // Process the list of files found in the directory.
            string[] fileEntries = Directory.GetFiles(_pathToFolder);
            foreach (string fileName in fileEntries)
                if (IsRequiredExtFile(fileName))
                    ProcessFile(fileName);

            MapToClassRange();
            GetTrainValidationData();
        }

        private bool IsRequiredExtFile(string fileName)
        {
            foreach (var ext in _extList)
            {
                if (fileName.Contains("." + ext))
                {
                    return true;
                }
            }

            return false;
        }

        private void MapToClassRange()
        {
            HashSet<int> uniqueLabels = _labels.ToHashSet();
            var uniqueLabelList = uniqueLabels.ToList();
            uniqueLabelList.Sort();

            _labels = _labels.Select(x => uniqueLabelList.IndexOf(x)).ToList();
        }

        private NDarray OneHotEncoding(List<int> labels)
        {
            var npLabels = np.array(labels.ToArray()).reshape(-1);
            return Util.ToCategorical(npLabels, num_classes: NumberClasses);
        }

        private void ProcessFile(string path)
        {
            _objs.Add(Utils.Normalize(path));
            ProcessLabel(Path.GetFileName(path));
        }

        private void ProcessLabel(string filename)
        {
            _labels.Add(int.Parse(ExtractClassFromFileName(filename)));
        }

        private string ExtractClassFromFileName(string filename)
        {
            return filename.Split('_')[0].Replace("class", "");
        }

        private void GetTrainValidationData()
        {
            var listIndices = Enumerable.Range(0, _labels.Count).ToList();
            var toValidate = _objs.Count * _validationSplit;
            var random = new Random();
            var xValResult = new List<NDarray>();
            var yValResult = new List<int>();
            var xTrainResult = new List<NDarray>();
            var yTrainResult = new List<int>();

            // Split validation data
            for (var i = 0; i < toValidate; i++)
            {
                var randomIndex = random.Next(0, listIndices.Count);
                var indexVal = listIndices[randomIndex];
                xValResult.Add(_objs[indexVal]);
                yValResult.Add(_labels[indexVal]);
                listIndices.RemoveAt(randomIndex);
            }

            // Split rest (training data)
            listIndices.ForEach(indexVal => 
            { 
                xTrainResult.Add(_objs[indexVal]);
                yTrainResult.Add(_labels[indexVal]);
            });

            TrainY = OneHotEncoding(yTrainResult);
            ValidationY = OneHotEncoding(yValResult);
            TrainX = np.array(xTrainResult);
            ValidationX = np.array(xValResult);
        }
}

下面是每個方法的說明:

  1. LoadDataSet()——類的主方法,我們呼叫它來載入_pathToFolder中的資料集。它呼叫下面列出的其他方法來完成此操作。
  2. IsRequiredExtFile(filename) - 檢查給定檔案是否包含至少一個應該為該資料集處理的副檔名(在_extList中列出)。
  3. MapToClassRange() -獲取資料集中唯一標籤的列表。
  4. ProcessFile(path) -使用Utils.Normalize方法對影像進行規格化,並呼叫ProcessLabel方法。
  5. ProcessLabel(filename)——將ExtractClassFromFileName方法的結果新增為標籤。
  6. ExtractClassFromFileName(filename) -從影像的檔名中提取類。
  7. GetTrainValidationData()——將資料集劃分為訓練子資料集和驗證子資料集。

在本系列中,我們將使用https://cvl.tuwien.ac.at/research/cvl-databases/coin-image-dataset/上的硬幣影像資料集。

要載入資料集,我們可以在控制檯應用程式的主類中包含以下內容:

var numberClasses = 60;
var fileExt = new string[] { ".png" };
var dataSetFilePath = @"C:/Users/arnal/Downloads/coin_dataset";
var dataSet = new PreProcessing.DataSet(dataSetFilePath, fileExt, numberClasses, 0.2);
dataSet.LoadDataSet();

我們的資料現在可以輸入到機器學習模型中。下一篇文章將介紹監督機器學習的基礎知識,以及訓練和驗證階段包括哪些內容。它是為沒有AI經驗的讀者準備的。

歡迎關注我的公眾號,如果你有喜歡的外文技術文章,可以通過公眾號留言推薦給我。

原文連結:https://www.codeproject.com/Articles/5284219/Deep-Learning-in-Csharp-Coin-Detection-Using-OpenC

相關文章