寫在前面
準備近期將微軟的machinelearning-samples翻譯成中文,水平有限,如有錯漏,請大家多多指正。
如果有朋友對此感興趣,可以加入我:https://github.com/feiyun0112/machinelearning-samples.zh-cn
鳶尾花分類
ML.NET 版本 | API 型別 | 狀態 | 應用程式型別 | 資料型別 | 場景 | 機器學習任務 | 演算法 |
---|---|---|---|---|---|---|---|
v0.7 | 動態 API | 最新版本 | 控制檯應用程式 | .txt 檔案 | 鳶尾花分類 | 多類分類 | Sdca Multi-class |
在這個介紹性示例中,您將看到如何使用ML.NET來預測鳶尾花的型別。 在機器學習領域,這種型別的預測被稱為多類分類。
問題
這個問題集中在根據花瓣長度,花瓣寬度等花的引數預測鳶尾花(setosa,versicolor或virginica)的型別。
為了解決這個問題,我們將建立一個ML模型,它有4個輸入引數:
- petal length
- petal width
- sepal length
- sepal width
並預測該花屬於哪種鳶尾花型別:
- setosa
- versicolor
- virginica
確切地說,模型將返回花屬於每個型別的概率。
ML 任務 - 多類分類
多類分類的廣義問題是將專案分類為三個或更多類別中的一個。 (將專案分類為兩個類別之一稱為二元分類)。
多類分類的其他例子包括:
- 手寫數字識別:預測影象中包含10個數字(0~9)。
- 問題標記:預測問題屬於哪個類別(UI,後端,文件)。
- 根據患者的測試結果預測疾病階段。
所有這些例子的共同特點是我們要預測的引數可以取幾個(超過兩個)值中的一個。換句話說,這個值由enum
表示,而不是由integer
、float
、double
或boolean
型別表示。
解決方案
為了解決這個問題,首先我們將建立一個ML模型。然後,我們將在現有資料上訓練模型,評估其有多好,最後我們將使用該模型來預測鳶尾花型別。
1. 建立模型
建立模型包括:
- 使用
DataReader
上傳資料(iris-train.txt
) - 建立一個評估器並將資料轉換為一列,以便ML演算法(使用
Concatenate
)可以有效地使用它。 - 選擇學習演算法(
StochasticDualCoordinateAscent
)。
初始程式碼類似以下內容:
// Create MLContext to be shared across the model creation workflow objects
// Set a random seed for repeatable/deterministic results across multiple trainings.
var mlContext = new MLContext(seed: 0);
// STEP 1: Common data loading configuration
var textLoader = IrisTextLoaderFactory.CreateTextLoader(mlContext);
var trainingDataView = textLoader.Read(TrainDataPath);
var testDataView = textLoader.Read(TestDataPath);
// STEP 2: Common data process configuration with pipeline data transformations
var dataProcessPipeline = mlContext.Transforms.Concatenate("Features", "SepalLength",
"SepalWidth",
"PetalLength",
"PetalWidth" );
// STEP 3: Set the training algorithm, then create and config the modelBuilder
var modelBuilder = new Common.ModelBuilder<IrisData, IrisPrediction>(mlContext, dataProcessPipeline);
// We apply our selected Trainer
var trainer = mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent(labelColumn: "Label", featureColumn: "Features");
modelBuilder.AddTrainer(trainer);
2. 訓練
訓練模型是在訓練資料(已知鳶尾花型別)上執行所選演算法以調整模型引數的過程。它在評估器物件中的Fit()
方法中實現。
為了執行訓練,我們只需呼叫方法時傳入在DataView物件中提供的訓練資料集(iris-train.txt檔案)。
// STEP 4: Train the model fitting to the DataSet
modelBuilder.Train(trainingDataView);
[...]
public ITransformer Train(IDataView trainingData)
{
TrainedModel = TrainingPipeline.Fit(trainingData);
return TrainedModel;
}
3. 評估模型
我們需要這一步來總結我們的模型對新資料的準確性。 為此,上一步中的模型針對另一個未在訓練中使用的資料集(iris-test.txt
)執行。 此資料集還包含已知的鳶尾花型別。
MulticlassClassification.Evaluate
計算模型預測的值和已知型別之間差異的各種指標。
var metrics = modelBuilder.EvaluateMultiClassClassificationModel(testDataView, "Label");
Common.ConsoleHelper.PrintMultiClassClassificationMetrics(trainer.ToString(), metrics);
[...]
public MultiClassClassifierEvaluator.Result EvaluateMultiClassClassificationModel(IDataView testData, string label="Label", string score="Score")
{
CheckTrained();
var predictions = TrainedModel.Transform(testData);
var metrics = _mlcontext.MulticlassClassification.Evaluate(predictions, label: label, score: score);
return metrics;
}
要了解關於如何理解指標的更多資訊,請參閱ML.NET指南 中的機器學習詞彙表,或者使用任何有關資料科學和機器學習的可用材料.
如果您對模型的質量不滿意,可以採用多種方法來改進,這將在examples類別中進行介紹。
4. 使用模型
在模型被訓練之後,我們可以使用Predict()
API來預測這種花屬於每個鳶尾花型別的概率。
var modelScorer = new Common.ModelScorer<IrisData, IrisPrediction>(mlContext);
modelScorer.LoadModelFromZipFile(ModelPath);
var prediction = modelScorer.PredictSingle(SampleIrisData.Iris1);
Console.WriteLine($"Actual: setosa. Predicted probability: setosa: {prediction.Score[0]:0.####}");
Console.WriteLine($" versicolor: {prediction.Score[1]:0.####}");
Console.WriteLine($" virginica: {prediction.Score[2]:0.####}");
[...]
public TPrediction PredictSingle(TObservation input)
{
CheckTrainedModelIsLoaded();
return PredictionFunction.Predict(input);
}
在TestIrisData.Iris1
中儲存有關我們想要預測型別的花的資訊。
internal class TestIrisData
{
internal static readonly IrisData Iris1 = new IrisData()
{
SepalLength = 3.3f,
SepalWidth = 1.6f,
PetalLength = 0.2f,
PetalWidth= 5.1f,
}
(...)
}