ML.NET 示例:聚類之鳶尾花

feiyun0112發表於2018-12-15

寫在前面

準備近期將微軟的machinelearning-samples翻譯成中文,水平有限,如有錯漏,請大家多多指正。
如果有朋友對此感興趣,可以加入我:https://github.com/feiyun0112/machinelearning-samples.zh-cn


聚類鳶尾花資料

ML.NET 版本 API 型別 狀態 應用程式型別 資料型別 場景 機器學習任務 演算法
v0.7 動態 API 最新版 控制檯應用程式 .txt 檔案 聚類鳶尾花 聚類 K-means++

在這個介紹性示例中,您將看到如何使用ML.NET將不同型別鳶尾花劃分為不同組。在機器學習的世界中,這個任務被稱為群集

問題

為了演示聚類API的實際作用,我們將使用三種型別的鳶尾花:setosa、versicolor和versicolor。它們都儲存在相同的資料集中。儘管這些花的型別是已知的,我們將不使用它,只對花的引數,如花瓣長度,花瓣寬度等執行聚類演算法。這個任務是把所有的花分成三個不同的簇。我們期望不同型別的花屬於不同的簇。

模型的輸入使用下列鳶尾花引數:

  • petal length
  • petal width
  • sepal length
  • sepal width

ML 任務 - 聚類

聚類的一般問題是將一組物件分組,使得同一組中的物件彼此之間的相似性大於其他組中的物件。

其他一些聚類示例:

  • 將新聞文章分為不同主題:體育,政治,科技等。
  • 按購買偏好對客戶進行分組。
  • 將數字影象劃分為不同的區域以進行邊界檢測或物體識別。

聚類看起來類似於多類分類,但區別在於對於聚類任務,我們不知道過去資料的答案。 因此,沒有“導師”/“主管”可以判斷我們的演算法的預測是對還是錯。 這種型別的ML任務稱為無監督學習

解決方案

要解決這個問題,首先我們將建立並訓練ML模型。 然後我們將使用訓練模型來預測鳶尾花的簇。

1. 建立模型

建立模型包括:上傳資料(使用TextLoader載入iris-full.txt),轉換資料以便ML演算法(使用Concatenate)有效地使用,並選擇學習演算法(KMeans)。 所有這些步驟都儲存在trainingPipeline中:

//Create the MLContext to share across components for deterministic results
MLContext mlContext = new MLContext(seed: 1);  //Seed set to any number so you have a deterministic environment

// STEP 1: Common data loading configuration
TextLoader textLoader = mlContext.Data.TextReader(new TextLoader.Arguments()
                                {
                                    Separator = "\t",
                                    HasHeader = true,
                                    Column = new[]
                                                {
                                                    new TextLoader.Column("Label", DataKind.R4, 0),
                                                    new TextLoader.Column("SepalLength", DataKind.R4, 1),
                                                    new TextLoader.Column("SepalWidth", DataKind.R4, 2),
                                                    new TextLoader.Column("PetalLength", DataKind.R4, 3),
                                                    new TextLoader.Column("PetalWidth", DataKind.R4, 4),
                                                }
                                });

IDataView fullData = textLoader.Read(DataPath);

//STEP 2: Process data transformations in pipeline
var dataProcessPipeline = mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth");

// STEP 3: Create and train the model     
var trainer = mlContext.Clustering.Trainers.KMeans(features: "Features", clustersCount: 3);
var trainingPipeline = dataProcessPipeline.Append(trainer);

2. 訓練模型

訓練模型是在給定資料上執行所選演算法的過程。 要執行訓練,您需要呼叫Fit()方法。

var trainedModel = trainingPipeline.Fit(trainingDataView);

3. 使用模型

在建立和訓練模型之後,我們可以使用Predict()API來預測鳶尾花的簇,並計算從給定花引數到每個簇(簇的每個質心)的距離。

                // Test with one sample text 
                var sampleIrisData = new IrisData()
                {
                    SepalLength = 3.3f,
                    SepalWidth = 1.6f,
                    PetalLength = 0.2f,
                    PetalWidth = 5.1f,
                };

                // Create prediction engine related to the loaded trained model
                var predFunction = trainedModel.MakePredictionFunction<IrisData, IrisPrediction>(mlContext);

                //Score
                var resultprediction = predFunction.Predict(sampleIrisData);
                
                Console.WriteLine($"Cluster assigned for setosa flowers:" + resultprediction.SelectedClusterId);

相關文章