機器學習框架ML.NET學習筆記【2】入門之二元分類

seabluescn發表於2019-05-29

一、準備樣本

接上一篇文章提到的問題:根據一個人的身高、體重來判斷一個人的身材是否很好。但我手上沒有樣本資料,只能偽造一批資料了,偽造的資料比較標準,用來學習還是蠻合適的。

下面是我用來偽造資料的程式碼:

           string Filename = "./figure_full.csv";
            StreamWriter sw = new StreamWriter(Filename, false);
            sw.WriteLine("Height,Weight,Result");

            Random random = new Random();
            float height, weight;
            Result result;

            for (int i = 0; i < 2000; i++)
            {
                height = random.Next(150, 195);
                weight = random.Next(70, 200);

                if (height > 170 && weight < 120)
                    result = Result.Good;
                else
                    result = Result.Bad;
               
                sw.WriteLine($"{height},{weight},{(int)result}");
            }


   enum Result
    {
        Bad=0,
        Good=1
    }
View Code

製造成功後的資料如下:

 用記事本開啟:

 

二、原始碼

資料準備好了,我們就用準備好的資料進行學習了,先貼出全部程式碼,然後再逐一解釋:

namespace BinaryClassification_Figure
{
    class Program
    {
        static readonly string DataPath = Path.Combine(Environment.CurrentDirectory, "Data", "figure_full.csv");
        static readonly string ModelPath = Path.Combine(Environment.CurrentDirectory, "Data", "FastTree_Model.zip");

        static void Main(string[] args)
        {
            TrainAndSave();
            LoadAndPrediction();

            Console.WriteLine("Press any to exit!");
            Console.ReadKey();
        }

        static void TrainAndSave()
        {
            MLContext mlContext = new MLContext();           

            //準備資料
            var fulldata = mlContext.Data.LoadFromTextFile<FigureData>(path: DataPath, hasHeader: true, separatorChar: ',');           
            var trainTestData = mlContext.Data.TrainTestSplit(fulldata,testFraction:0.2);
            var trainData = trainTestData.TrainSet;
            var testData = trainTestData.TestSet;

            //訓練 
            IEstimator<ITransformer> dataProcessPipeline = mlContext.Transforms.Concatenate("Features", new[] { "Height", "Weight" })
                .Append(mlContext.Transforms.NormalizeMeanVariance(inputColumnName: "Features", outputColumnName: "FeaturesNormalizedByMeanVar"));
            IEstimator<ITransformer> trainer = mlContext.BinaryClassification.Trainers.FastTree(labelColumnName: "Result", featureColumnName: "FeaturesNormalizedByMeanVar");
            IEstimator<ITransformer> trainingPipeline = dataProcessPipeline.Append(trainer); 
            ITransformer model = trainingPipeline.Fit(trainData);

            //評估
            var predictions = model.Transform(testData);
            var metrics = mlContext.BinaryClassification.Evaluate(data: predictions, labelColumnName: "Result", scoreColumnName: "Score");
            PrintBinaryClassificationMetrics(trainer.ToString(), metrics);

            //儲存模型
            mlContext.Model.Save(model, trainData.Schema, ModelPath);
            Console.WriteLine($"Model file saved to :{ModelPath}");
        }      

        static void LoadAndPrediction()
        {
            var mlContext = new MLContext();
            ITransformer model = mlContext.Model.Load(ModelPath, out var inputSchema);
            var predictionEngine = mlContext.Model.CreatePredictionEngine<FigureData, FigureDatePredicted>(model);

            FigureData test = new FigureData();
            test.Weight = 115;
            test.Height = 171;

            var prediction = predictionEngine.Predict(test);
            Console.WriteLine($"Predict Result :{prediction.PredictedLabel}");
        }      
    }

    public class FigureData
    {
        [LoadColumn(0)]
        public float Height { get; set; }

        [LoadColumn(1)]
        public float Weight { get; set; }

        [LoadColumn(2)]
        public bool Result { get; set; }       
    }

    public class FigureDatePredicted : FigureData
    {
        public bool PredictedLabel;
    }
}
View Code

 

三、對程式碼的解釋

1、讀取樣本資料

        string DataPath = Path.Combine(Environment.CurrentDirectory, "Data", "figure_full.csv");
        MLContext mlContext = new MLContext();          

            //準備資料
            var fulldata = mlContext.Data.LoadFromTextFile<FigureData>(path: DataPath, hasHeader: true, separatorChar: ',');           
            var trainTestData = mlContext.Data.TrainTestSplit(fulldata,testFraction:0.2);
            var trainData = trainTestData.TrainSet;
            var testData = trainTestData.TestSet;    

 LoadFromTextFile<FigureData>(path: DataPath, hasHeader: true, separatorChar: ',')用來讀取資料到DataView

FigureData類是和樣本資料對應的實體類,LoadColumn特性指示該屬性對應該條資料中的第幾個資料。

    public class FigureData
    {
        [LoadColumn(0)]
        public float Height { get; set; }

        [LoadColumn(1)]
        public float Weight { get; set; }

        [LoadColumn(2)]
        public bool Result { get; set; }       
    }

 path:檔案路徑

hasHeader:文字檔案是否包含標題

separatorChar:用來分割資料的字元,我們用的是逗號,常用的還有跳格符‘\t’

TrainTestSplit(fulldata,testFraction:0.2)用來隨機分割資料,分成學習資料和評估用的資料,通常情況,如果資料較多,測試資料取20%左右比較合適,如果資料量較少,測試資料取10%左右比較合適。

如果不通過分割,準備兩個資料檔案,一個用來訓練、一個用來評估,效果是一樣的。

 

2、訓練 

            //訓練 
            IEstimator<ITransformer> dataProcessPipeline = mlContext.Transforms.Concatenate("Features", new[] { "Height", "Weight" })
                .Append(mlContext.Transforms.NormalizeMeanVariance(inputColumnName: "Features", outputColumnName: "FeaturesNormalizedByMeanVar"));
            IEstimator<ITransformer> trainer = mlContext.BinaryClassification.Trainers.FastTree(labelColumnName: "Result", featureColumnName: "FeaturesNormalizedByMeanVar");
            IEstimator<ITransformer> trainingPipeline = dataProcessPipeline.Append(trainer); 
            ITransformer model = trainingPipeline.Fit(trainData);

  IDataView這個資料集就類似一個表格,它的列(Column)是可以動態增加的,一開始我們通過LoadFromTextFile獲得的資料集包括:Height、Weight、Result這幾個列,在進行訓練之前,我們還要對這個資料集進行處理,形成符合我們要求的資料集。

Concatenate這個方法是把多個列,組合成一個列,因為二元分類的機器學習演算法只接收一個特徵列,所以要把多個特徵列(Height、Weight)組合成一個特徵列Features(組合的結果應該是個float陣列)。

NormalizeMeanVariance是對列進行歸一化處理,這裡輸入列為:Features,輸出列為:FeaturesNormalizedByMeanVar,歸一化的含義見本文最後一節介紹。

資料集就緒以後,就要選擇學習演算法,針對二元分類,我們選擇了快速決策樹演算法FastTree,我們需要告訴這個演算法特徵值放在哪個列裡面(FeaturesNormalizedByMeanVar),標籤值放在哪個列裡面(Result)。

連結資料處理管道和演算法形成學習管道,將資料集中的資料逐一通過學習管道進行學習,形成機器學習模型。

有了這個模型我們就可以通過它進行實際應用了。但我們一般不會現在就使用這個模型,我們需要先評估一下這個模型,然後把模型儲存下來。以後應用時再通過檔案讀取出模型,然後進行應用,這樣就不用等待學習的時間了,通常學習的時間都比較長。

 

3、評估 

            //評估
            var predictions = model.Transform(testData);
            var metrics = mlContext.BinaryClassification.Evaluate(data: predictions, labelColumnName: "Result");
            PrintBinaryClassificationMetrics(trainer.ToString(), metrics);

  評估的過程就是對測試資料集進行批量轉換(Transform),轉換過的資料集會多出一個“PredictedLabel”的列,這個就是模型評估的結果,逐條將這個結果和實際結果(Result)進行比較,就最終形成了效果評估資料。

我們可以列印這個評估結果,檢視其成功率,一般成功率大於97%就是比較好的模型了。由於我們偽造的資料比較整齊,所以我們這次評估的成功率為100%。

注意:評估過程不會提升現有的模型能力,只是對現有模型的一種檢測。

 

4、儲存模型 

//儲存模型
           string ModelPath = Path.Combine(Environment.CurrentDirectory, "Data", "FastTree_Model.zip");
            mlContext.Model.Save(model, trainData.Schema, ModelPath);
            Console.WriteLine($"Model file saved to :{ModelPath}");

 這個沒啥好解釋的。

 

5、讀取模型並建立預測引擎 

           //讀取模型
            var mlContext = new MLContext();
            ITransformer model = mlContext.Model.Load(ModelPath, out var inputSchema);

            //建立預測引擎
            var predictionEngine = mlContext.Model.CreatePredictionEngine<FigureData, FigureDatePredicted>(model);

 建立預測引擎的功能和Transform是類似的,不過Transform是處理批量記錄,這裡只處理一條資料,而且這裡的輸入輸出是實體物件,定義如下:

   public class FigureData
    {
        [LoadColumn(0)]
        public float Height { get; set; }

        [LoadColumn(1)]
        public float Weight { get; set; }

        [LoadColumn(2)]
        public bool Result { get; set; }       
    }

    public class FigureDatePredicted : FigureData
    {
        public bool PredictedLabel;
    }

 由於預測結果裡放在“PredictedLabel”欄位中,所以FigureDatePredicted類必須要包含PredictedLabel屬性,目前FigureDatePredicted 類是從FigureData類繼承的,由於我們只用到PredictedLabel屬性,所以不繼承也沒有關係,如果繼承的話,後面要除錯的話會方便一點。

 

6、應用 

            FigureData test = new FigureData
            {
                Weight = 115,
                Height = 171
            };

            var prediction = predictionEngine.Predict(test);
            Console.WriteLine($"Predict Result :{prediction.PredictedLabel}");

 這部分程式碼就比較簡單,test是我們要預測的物件,預測後列印出預測結果。

 

四、附:資料歸一化

 機器學習的演算法中一般會有很多的乘法運算,當運算的數字過大時,很容易在多次運算後溢位,為了防止這種情況,就要對資料進行歸一化處理。歸一化的目標就是把參與運算的特徵數變為(0,1)或(-1,1)之間的浮點數,常見的處理方式有:min-max標準化、Log函式轉換、對數函式轉換等。

我們這次採用的是平均方差歸一化方法。

 

五、資源

原始碼下載地址:https://github.com/seabluescn/Study_ML.NET

工程名稱:BinaryClassification_Figure

相關文章