ML.NET 示例:多類分類之鳶尾花分類

feiyun0112發表於2018-12-07

寫在前面

準備近期將微軟的machinelearning-samples翻譯成中文,水平有限,如有錯漏,請大家多多指正。
如果有朋友對此感興趣,可以加入我:https://github.com/feiyun0112/machinelearning-samples.zh-cn

鳶尾花分類

ML.NET 版本 API 型別 狀態 應用程式型別 資料型別 場景 機器學習任務 演算法
v0.7 動態 API 最新版本 控制檯應用程式 .txt 檔案 鳶尾花分類 多類分類 Sdca Multi-class

在這個介紹性示例中,您將看到如何使用ML.NET來預測鳶尾花的型別。 在機器學習領域,這種型別的預測被稱為多類分類

問題

這個問題集中在根據花瓣長度,花瓣寬度等花的引數預測鳶尾花(setosa,versicolor或virginica)的型別。

為了解決這個問題,我們將建立一個ML模型,它有4個輸入引數:

  • petal length
  • petal width
  • sepal length
  • sepal width

並預測該花屬於哪種鳶尾花型別:

  • setosa
  • versicolor
  • virginica

確切地說,模型將返回花屬於每個型別的概率。

ML 任務 - 多類分類

多類分類的廣義問題是將專案分類為三個或更多類別中的一個。 (將專案分類為兩個類別之一稱為二元分類)。

多類分類的其他例子包括:

  • 手寫數字識別:預測影象中包含10個數字(0~9)。
  • 問題標記:預測問題屬於哪個類別(UI,後端,文件)。
  • 根據患者的測試結果預測疾病階段。

所有這些例子的共同特點是我們要預測的引數可以取幾個(超過兩個)值中的一個。換句話說,這個值由enum表示,而不是由integerfloatdoubleboolean型別表示。

解決方案

為了解決這個問題,首先我們將建立一個ML模型。然後,我們將在現有資料上訓練模型,評估其有多好,最後我們將使用該模型來預測鳶尾花型別。

Build -> Train -> Evaluate -> Consume

1. 建立模型

建立模型包括:

  • 使用DataReader上傳資料(iris-train.txt
  • 建立一個評估器並將資料轉換為一列,以便ML演算法(使用Concatenate)可以有效地使用它。
  • 選擇學習演算法(StochasticDualCoordinateAscent)。

初始程式碼類似以下內容:

// Create MLContext to be shared across the model creation workflow objects 
// Set a random seed for repeatable/deterministic results across multiple trainings.
var mlContext = new MLContext(seed: 0);

// STEP 1: Common data loading configuration
var textLoader = IrisTextLoaderFactory.CreateTextLoader(mlContext);
var trainingDataView = textLoader.Read(TrainDataPath);
var testDataView = textLoader.Read(TestDataPath);

// STEP 2: Common data process configuration with pipeline data transformations
var dataProcessPipeline = mlContext.Transforms.Concatenate("Features", "SepalLength",
                                                                       "SepalWidth",
                                                                       "PetalLength",
                                                                       "PetalWidth" );

// STEP 3: Set the training algorithm, then create and config the modelBuilder                            
var modelBuilder = new Common.ModelBuilder<IrisData, IrisPrediction>(mlContext, dataProcessPipeline);
// We apply our selected Trainer 
var trainer = mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent(labelColumn: "Label", featureColumn: "Features");
modelBuilder.AddTrainer(trainer);

2. 訓練

訓練模型是在訓練資料(已知鳶尾花型別)上執行所選演算法以調整模型引數的過程。它在評估器物件中的Fit() 方法中實現。

為了執行訓練,我們只需呼叫方法時傳入在DataView物件中提供的訓練資料集(iris-train.txt檔案)。

// STEP 4: Train the model fitting to the DataSet            
modelBuilder.Train(trainingDataView);

[...]
public ITransformer Train(IDataView trainingData)
{
    TrainedModel = TrainingPipeline.Fit(trainingData);
    return TrainedModel;
}

3. 評估模型

我們需要這一步來總結我們的模型對新資料的準確性。 為此,上一步中的模型針對另一個未在訓練中使用的資料集(iris-test.txt)執行。 此資料集還包含已知的鳶尾花型別。
MulticlassClassification.Evaluate計算模型預測的值和已知型別之間差異的各種指標。

var metrics = modelBuilder.EvaluateMultiClassClassificationModel(testDataView, "Label");
Common.ConsoleHelper.PrintMultiClassClassificationMetrics(trainer.ToString(), metrics);
    
[...]
public MultiClassClassifierEvaluator.Result EvaluateMultiClassClassificationModel(IDataView testData, string label="Label", string score="Score")
{
    CheckTrained();
    var predictions = TrainedModel.Transform(testData);
    var metrics = _mlcontext.MulticlassClassification.Evaluate(predictions, label: label, score: score);
    return metrics;
}

要了解關於如何理解指標的更多資訊,請參閱ML.NET指南 中的機器學習詞彙表,或者使用任何有關資料科學和機器學習的可用材料.

如果您對模型的質量不滿意,可以採用多種方法來改進,這將在examples類別中進行介紹。

4. 使用模型

在模型被訓練之後,我們可以使用Predict() API來預測這種花屬於每個鳶尾花型別的概率。

var modelScorer = new Common.ModelScorer<IrisData, IrisPrediction>(mlContext);
modelScorer.LoadModelFromZipFile(ModelPath);

var prediction = modelScorer.PredictSingle(SampleIrisData.Iris1);
Console.WriteLine($"Actual: setosa.     Predicted probability: setosa:      {prediction.Score[0]:0.####}");
Console.WriteLine($"                                           versicolor:  {prediction.Score[1]:0.####}");
Console.WriteLine($"                                           virginica:   {prediction.Score[2]:0.####}");

[...]
public TPrediction PredictSingle(TObservation input)
{
    CheckTrainedModelIsLoaded();
    return PredictionFunction.Predict(input);
}

TestIrisData.Iris1中儲存有關我們想要預測型別的花的資訊。

internal class TestIrisData
{
    internal static readonly IrisData Iris1 = new IrisData()
    {
        SepalLength = 3.3f,
        SepalWidth = 1.6f,
        PetalLength = 0.2f,
        PetalWidth= 5.1f,
    }
    (...)
}

相關文章