ML.NET技術研究系列-2聚類演算法KMeans

Eric zhou發表於2019-07-14

上一篇博文我們介紹了ML.NET 的入門:

ML.NET技術研究系列1-入門篇

本文我們繼續,研究分享一下聚類演算法k-means.

一、k-means演算法簡介

k-means演算法是一種聚類演算法,所謂聚類,即根據相似性原則,將具有較高相似度的資料物件劃分至同一類簇,將具有較高相異度的資料物件劃分至不同類簇。

1. k-means演算法的原理是什麼樣的?參考:https://baijiahao.baidu.com/s?id=1622412414004300046&wfr=spider&for=pc

   k-means演算法中的k代表類簇個數,means代表類簇內資料物件的均值(這種均值是一種對類簇中心的描述),因此,k-means演算法又稱為k-均值演算法。

   k-means演算法是一種基於劃分的聚類演算法,以距離作為資料物件間相似性度量的標準,即資料物件間的距離越小,則它們的相似性越高,則它們越有可能在同一個類簇。

   資料物件間距離的計算有很多種,k-means演算法通常採用歐氏距離來計算資料物件間的距離。演算法詳細的流程描述如下:

   2. k-means演算法的優缺點:

     優點: 演算法簡單易實現; 

     缺點: 需要使用者事先指定類簇個數; 聚類結果對初始類簇中心的選取較為敏感; 容易陷入區域性最優; 只能發現球形類簇;

接下來我們說一下k-means演算法的經典應用場景:鳶尾花

二、鳶尾花

首先,鳶尾花是一種植物,有四個典型的屬性:

  •   花瓣長度
  •   花瓣寬度
  •   花萼長度
  •   花萼寬度

  

 鳶尾花有三大品種setosa、versicolor 或 virginica ,每個品種對應的以上四個屬性各不相同。

 鳶尾花資料集中一共包含了150條記錄,每個樣本的包含它的萼片長度和寬度,花瓣的長度和寬度以及這個樣本所屬的具體品種。每個品種的樣本量為50條。

 鳶尾花樣本資料格式:

5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
4.8,3.1,1.6,0.2,Iris-setosa
6.0,2.2,4.0,1.0,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
5.7,2.5,5.0,2.0,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica

 上述資料中,第一列是鳶尾花花萼長度,第二列是鳶尾花花萼寬度,第三列是鳶尾花花瓣長度,第四列是鳶尾花花瓣寬度。

   基於上述資料做機器學習、訓練,形成一個模型。

三、ML.NET k-means

   基於上述的場景,我們先準備樣本資料,https://github.com/dotnet/machinelearning/blob/master/test/data/iris.data

   另存為iris.data檔案,每個屬性逗號間隔。

   然後,大致梳理了一下實現步驟:

  1.    新建一個.Net  Core Console Project
  2.    新增Microsoft.ML nuget 1.2.0版本
  3.    新增鳶尾花資料、預測類實體類IrisData、ClusterPrediction
  4.    構造MLContext、從iris.data構造IDataView,採用Trainers.KMeans進行模型訓練,形成模型檔案:IrisClusteringModel.zip
  5.    輸入一個測試資料,進行預測。

   好,讓我們開始搞吧:

   1. 新建一個.Net  Core Console Project

    先看下用的VS的版本:

   

    新建一個.Net Core Console的Project KMeansDemo

   2. 新增Microsoft.ML nuget 1.2.0版本

   

  將iris.data檔案放到Project下的Data目錄中,同時右鍵iris.data,設定為:始終複製

  

  

 

 3. 新增鳶尾花資料、預測類實體類IrisData、ClusterPrediction

using System;
using System.Collections.Generic;
using System.Text;

namespace KMeansDemo
{
    using Microsoft.ML.Data;

    /// <summary>
    /// 鳶尾花資料
    /// </summary>
    class IrisData
    {
        /// <summary>
        /// 鳶尾花花萼長度
        /// </summary>
        [LoadColumn(0)]
        public float SepalLength;

        /// <summary>
        /// 鳶尾花花萼寬度
        /// </summary>
        [LoadColumn(1)]
        public float SepalWidth;

        /// <summary>
        /// 鳶尾花花瓣長度
        /// </summary>
        [LoadColumn(2)]
        public float PetalLength;

        /// <summary>
        /// 鳶尾花花瓣寬度
        /// </summary>
        [LoadColumn(3)]
        public float PetalWidth;
    }
}

 

using System;
using System.Collections.Generic;
using System.Text;

namespace KMeansDemo
{
    using Microsoft.ML.Data;

    public class ClusterPrediction
    {
        /// <summary>
        /// 預測的族群
        /// </summary>
        [ColumnName("PredictedLabel")]
        public uint PredictedClusterId;

        [ColumnName("Score")]
        public float[] Distances;
    }
}

4.  構造MLContext、從iris.data構造IDataView,採用Trainers.KMeans進行模型訓練,形成模型檔案:IrisClusteringModel.zip

 在Main函式中,開始編碼 ,首先新增引用

using Microsoft.ML;

 宣告樣本資料檔案和模型檔案的檔案路徑

static readonly string _dataPath = Path.Combine(Environment.CurrentDirectory, "Data", "iris.data");
static readonly string _modelPath = Path.Combine(Environment.CurrentDirectory, "Data", "IrisClusteringModel.zip");

 構造MLContext、IDataView,採用Trainer.KMeans進行模型訓練,形成模型檔案:IrisClusteringModel.zip

var mlContext = new MLContext(seed: 0);
IDataView dataView = mlContext.Data.LoadFromTextFile<IrisData>(_dataPath, hasHeader: false, separatorChar: ',');
string featuresColumnName = "Features";
var pipeline = mlContext.Transforms
                .Concatenate(featuresColumnName, "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
                .Append(mlContext.Clustering.Trainers.KMeans(featuresColumnName, numberOfClusters: 3));
var model = pipeline.Fit(dataView);
using (var fileStream = new FileStream(_modelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
{
      mlContext.Model.Save(model, dataView.Schema, fileStream);
}
Console.WriteLine("完成模型訓練!");
Console.WriteLine("模型檔案:"+ _modelPath);

5.  輸入一個測試資料,進行預測。

 輸入一個測試資料,使用生成的模型,進行預測:

var predictor = mlContext.Model.CreatePredictionEngine<IrisData, ClusterPrediction>(model);
var Setosa = new IrisData
{
                SepalLength = 5.1f,
                SepalWidth = 3.5f,
                PetalLength = 1.4f,
                PetalWidth = 0.2f
};

var prediction = predictor.Predict(Setosa);
Console.WriteLine($"Cluster: {prediction.PredictedClusterId}");
Console.WriteLine($"Distances: {string.Join(" ", prediction.Distances)}");
Console.WriteLine("Press any key!");

 全部的程式碼:

 1 using Microsoft.ML;
 2 using System;
 3 using System.IO;
 4 
 5 namespace KMeansDemo
 6 {
 7     class Program
 8     {
 9         static readonly string _dataPath = Path.Combine(Environment.CurrentDirectory, "Data", "iris.data");
10         static readonly string _modelPath = Path.Combine(Environment.CurrentDirectory, "Data", "IrisClusteringModel.zip");
11 
12         static void Main(string[] args)
13         {
14             var mlContext = new MLContext(seed: 0);
15             IDataView dataView = mlContext.Data.LoadFromTextFile<IrisData>(_dataPath, hasHeader: false, separatorChar: ',');
16             string featuresColumnName = "Features";
17             var pipeline = mlContext.Transforms
18                 .Concatenate(featuresColumnName, "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
19                 .Append(mlContext.Clustering.Trainers.KMeans(featuresColumnName, numberOfClusters: 3));
20             var model = pipeline.Fit(dataView);
21             using (var fileStream = new FileStream(_modelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
22             {
23                 mlContext.Model.Save(model, dataView.Schema, fileStream);
24             }
25             Console.WriteLine("完成模型訓練!");
26             Console.WriteLine("模型檔案:"+ _modelPath);
27             
28             //預測
29             var predictor = mlContext.Model.CreatePredictionEngine<IrisData, ClusterPrediction>(model);
30 
31             var Setosa = new IrisData
32             {
33                 SepalLength = 5.1f,
34                 SepalWidth = 3.5f,
35                 PetalLength = 1.4f,
36                 PetalWidth = 0.2f
37             };
38 
39             var prediction = predictor.Predict(Setosa);
40             Console.WriteLine($"Cluster: {prediction.PredictedClusterId}");
41             Console.WriteLine($"Distances: {string.Join(" ", prediction.Distances)}");
42             Console.WriteLine("Press any key!");
43         }
44     }
45 }

Run,看一下輸出:

 以上就是通過ML.NET 的KMeans演算法,實現聚類。

 上面的資料是一個監督學習的樣本,同時是一個數值型別的資料,比較好奇的是,能不能對文字資料+值資料進行聚類,下一篇,我們將繼續完成文字資料+值資料的聚類分析。

以上,分享給大家。

 

周國慶

2019/7/14

相關文章