ML.NET Tutorial: Taxi Fare Prediction (a Regression Problem)

Posted by Ken.W on 2018-12-24

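Understanding the problem

A taxi fare depends not only on distance but also on factors such as the number of passengers and whether a credit card is used (the taxis here are New York City taxis). So this is not a problem that a simple single-variable equation can solve.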

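Preparing the data

Create a console application project, add a Data folder, and place the taxi-fare-train.csv and taxi-fare-test.csv files in it. Don't forget to change their Copy to Output Directory property to Copy if newer. Then add the Microsoft.ML NuGet package.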
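
If you would rather edit the project file by hand than use the Properties window, the Copy if newer setting is stored as PreserveNewest. The fragment below is a sketch that assumes an SDK-style .csproj and the Data folder layout described above; adjust the paths if your layout differs.

```xml
<ItemGroup>
  <!-- "Copy if newer" in the Properties window is written as PreserveNewest -->
  <None Update="Data\taxi-fare-train.csv">
    <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
  </None>
  <None Update="Data\taxi-fare-test.csv">
    <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
  </None>
</ItemGroup>
```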

Loading the data

Create an MLContext object, then create a TextLoader object. The TextLoader can be used to read data from a file.

MLContext mlContext = new MLContext(seed: 0);

_textLoader = mlContext.Data.TextReader(new TextLoader.Arguments()
{
    Separator = ",",
    HasHeader = true,
    Column = new[]
    {
        new TextLoader.Column("VendorId", DataKind.Text, 0),
        new TextLoader.Column("RateCode", DataKind.Text, 1),
        new TextLoader.Column("PassengerCount", DataKind.R4, 2),
        new TextLoader.Column("TripTime", DataKind.R4, 3),
        new TextLoader.Column("TripDistance", DataKind.R4, 4),
        new TextLoader.Column("PaymentType", DataKind.Text, 5),
        new TextLoader.Column("FareAmount", DataKind.R4, 6)
    }
});
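
To make the column definitions above concrete, here is a minimal plain-C# sketch of what reading one row with that layout amounts to. It is not how TextLoader is implemented. TripRow and CsvRowSketch are hypothetical names introduced for illustration, the sample line is assembled from the trip values used later in this article and not copied from the CSV file, and the sketch assumes no field contains a quoted comma.

```csharp
using System;
using System.Globalization;

// Hypothetical plain-C# holder for one parsed row; not an ML.NET type.
public sealed class TripRow
{
    public string VendorId;
    public string RateCode;
    public float PassengerCount;
    public float TripTime;
    public float TripDistance;
    public string PaymentType;
    public float FareAmount;
}

public static class CsvRowSketch
{
    // Splits one comma-separated line and converts each field using the same
    // column order and types as the TextLoader definition above.
    public static TripRow Parse(string line)
    {
        string[] fields = line.Split(',');
        if (fields.Length != 7)
            throw new FormatException($"Expected 7 columns but found {fields.Length}.");

        // Invariant culture so that '.' is the decimal separator on any machine.
        var inv = CultureInfo.InvariantCulture;
        return new TripRow
        {
            VendorId = fields[0],                          // column 0, text
            RateCode = fields[1],                          // column 1, text
            PassengerCount = float.Parse(fields[2], inv),  // column 2, float
            TripTime = float.Parse(fields[3], inv),        // column 3, float
            TripDistance = float.Parse(fields[4], inv),    // column 4, float
            PaymentType = fields[5],                       // column 5, text
            FareAmount = float.Parse(fields[6], inv)       // column 6, float (the label)
        };
    }

    public static void Main()
    {
        // Values taken from the sample trip shown later in the article.
        TripRow row = Parse("VTS,1,1,1140,3.75,CRD,15.5");
        Console.WriteLine($"{row.VendorId} {row.TripDistance} {row.FareAmount}");
    }
}
```

The header row is not handled here; in the article's code that is what HasHeader = true is for.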

Extracting features

The dataset files have seven columns. The first six serve as features and the last one is the label.

public class TaxiTrip
{
    [Column("0")]
    public string VendorId;

    [Column("1")]
    public string RateCode;

    [Column("2")]
    public float PassengerCount;

    [Column("3")]
    public float TripTime;

    [Column("4")]
    public float TripDistance;

    [Column("5")]
    public string PaymentType;

    [Column("6")]
    public float FareAmount;
}

public class TaxiTripFarePrediction
{
    [ColumnName("Score")]
    public float FareAmount;
}

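Training the model

First read the training dataset, then build the pipeline. The first step in the pipeline copies the FareAmount column to the Label column, which serves as the label. The second step uses OneHotEncoding to convert the three string columns VendorId, RateCode, and PaymentType into numeric columns. The third step merges the six data columns into a single Features column. The final step selects the FastTreeRegressionTrainer algorithm as the trainer.
Once the pipeline is complete, train the model.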

IDataView dataView = _textLoader.Read(dataPath);
var pipeline = mlContext.Transforms.CopyColumns("FareAmount", "Label")
    .Append(mlContext.Transforms.Categorical.OneHotEncoding("VendorId"))
    .Append(mlContext.Transforms.Categorical.OneHotEncoding("RateCode"))
    .Append(mlContext.Transforms.Categorical.OneHotEncoding("PaymentType"))
    .Append(mlContext.Transforms.Concatenate("Features", "VendorId", "RateCode", "PassengerCount", "TripTime", "TripDistance", "PaymentType"))
    .Append(mlContext.Regression.Trainers.FastTree());
var model = pipeline.Fit(dataView);
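
To illustrate what the one-hot encoding and concatenation steps above do to a row, here is a small plain-C# sketch. It shows the general idea only and is not ML.NET's implementation. OneHotSketch and its methods are hypothetical names. "VTS", "1" and "CRD" come from the sample trip used later in this article, while "CMT", "2" and "CSH" are assumed extra categories added so that each vector has more than one slot.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class OneHotSketch
{
    // Assigns each distinct category an index, in order of first appearance.
    public static Dictionary<string, int> BuildVocabulary(IEnumerable<string> values)
    {
        var vocab = new Dictionary<string, int>();
        foreach (string v in values)
            if (!vocab.ContainsKey(v))
                vocab.Add(v, vocab.Count);
        return vocab;
    }

    // Returns a vector with a single 1 at the category's index.
    // A value missing from the vocabulary maps to all zeros (a choice made for this sketch).
    public static float[] Encode(Dictionary<string, int> vocab, string value)
    {
        var vector = new float[vocab.Count];
        if (vocab.TryGetValue(value, out int index))
            vector[index] = 1f;
        return vector;
    }

    // Joins several vectors end to end, which is the idea behind the Concatenate step.
    public static float[] Concat(params float[][] parts)
    {
        return parts.SelectMany(p => p).ToArray();
    }

    public static void Main()
    {
        var vendors = BuildVocabulary(new[] { "VTS", "CMT" });
        var rateCodes = BuildVocabulary(new[] { "1", "2" });
        var payments = BuildVocabulary(new[] { "CRD", "CSH" });

        // Same column order as the Concatenate call in the pipeline.
        float[] features = Concat(
            Encode(vendors, "VTS"),
            Encode(rateCodes, "1"),
            new[] { 1f, 1140f, 3.75f },   // PassengerCount, TripTime, TripDistance
            Encode(payments, "CRD"));

        // features now holds: 1, 0 | 1, 0 | 1, 1140, 3.75 | 1, 0
        Console.WriteLine(features.Length);
    }
}
```

The trainer only ever sees numbers, which is why the string columns have to be encoded before they are merged into Features.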

Evaluating the model

Use the test dataset here, and evaluate with the Evaluate method for regression problems.

IDataView dataView = _textLoader.Read(_testDataPath);
var predictions = model.Transform(dataView);
var metrics = mlContext.Regression.Evaluate(predictions, "Label", "Score");
Console.WriteLine();
Console.WriteLine($"*************************************************");
Console.WriteLine($"*       Model quality metrics evaluation         ");
Console.WriteLine($"*------------------------------------------------");
Console.WriteLine($"*       R2 Score:      {metrics.RSquared:0.##}");
Console.WriteLine($"*       RMS loss:      {metrics.Rms:#.##}");
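
For readers unfamiliar with the two metrics printed above, the following plain-C# sketch computes them from their textbook definitions. It assumes ML.NET's RSquared and Rms follow those standard definitions, and it is not the library's code. RegressionMetricsSketch is a hypothetical name and the fares in Main are made up for illustration.

```csharp
using System;
using System.Linq;

public static class RegressionMetricsSketch
{
    // Sum of squared differences between actual and predicted values.
    private static double SquaredError(double[] actual, double[] predicted)
    {
        if (actual.Length != predicted.Length || actual.Length == 0)
            throw new ArgumentException("Arrays must be non-empty and of equal length.");
        return actual.Zip(predicted, (a, p) => (a - p) * (a - p)).Sum();
    }

    // Root of the mean squared error: the typical size of an error, in the label's units.
    public static double Rms(double[] actual, double[] predicted)
    {
        return Math.Sqrt(SquaredError(actual, predicted) / actual.Length);
    }

    // R squared: 1 - (model's squared error) / (squared error of always predicting the mean).
    // Values close to 1 mean the model explains most of the variation in the label.
    public static double RSquared(double[] actual, double[] predicted)
    {
        double mean = actual.Average();
        double total = actual.Sum(a => (a - mean) * (a - mean));
        return 1.0 - SquaredError(actual, predicted) / total;
    }

    public static void Main()
    {
        // Made-up fares, not taken from the taxi dataset.
        double[] actual = { 10.0, 20.0, 30.0, 40.0 };
        double[] predicted = { 12.0, 18.0, 32.0, 38.0 };

        // Every error is +2 or -2, so RMS = 2 and R squared = 1 - 16/500 = 0.968.
        Console.WriteLine(Rms(actual, predicted));
        Console.WriteLine(RSquared(actual, predicted));
    }
}
```

Read this way, the R2 score reported for the taxi model says how much of the variation in fares the model accounts for, and the RMS loss is the typical prediction error in fare units.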

Saving the model

A trained model can be saved as a zip file for later use.

using (var fileStream = new FileStream(_modelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
    mlContext.Model.Save(model, fileStream);

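Using the model

First load the saved model. Then create a prediction function object, with TaxiTrip as its input type and TaxiTripFarePrediction as its output type. Finally, call the Predict method and pass in the data to predict.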

ITransformer loadedModel;
using (var stream = new FileStream(_modelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
{
    loadedModel = mlContext.Model.Load(stream);
}

var predictionFunction = loadedModel.MakePredictionFunction<TaxiTrip, TaxiTripFarePrediction>(mlContext);

var taxiTripSample = new TaxiTrip()
{
    VendorId = "VTS",
    RateCode = "1",
    PassengerCount = 1,
    TripTime = 1140,
    TripDistance = 3.75f,
    PaymentType = "CRD",
    FareAmount = 0 // To predict. Actual/Observed = 15.5
};

var prediction = predictionFunction.Predict(taxiTripSample);

Console.WriteLine($"**********************************************************************");
Console.WriteLine($"Predicted fare: {prediction.FareAmount:0.####}, actual fare: 15.5");
Console.WriteLine($"**********************************************************************");

Complete sample code

using Microsoft.ML;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Data;
using System;
using System.IO;

namespace TaxiFarePredictor
{
    class Program
    {
        static readonly string _trainDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "taxi-fare-train.csv");
        static readonly string _testDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "taxi-fare-test.csv");
        static readonly string _modelPath = Path.Combine(Environment.CurrentDirectory, "Data", "Model.zip");
        static TextLoader _textLoader;

        static void Main(string[] args)
        {
            MLContext mlContext = new MLContext(seed: 0);

            _textLoader = mlContext.Data.TextReader(new TextLoader.Arguments()
            {
                Separator = ",",
                HasHeader = true,
                Column = new[]
                {
                    new TextLoader.Column("VendorId", DataKind.Text, 0),
                    new TextLoader.Column("RateCode", DataKind.Text, 1),
                    new TextLoader.Column("PassengerCount", DataKind.R4, 2),
                    new TextLoader.Column("TripTime", DataKind.R4, 3),
                    new TextLoader.Column("TripDistance", DataKind.R4, 4),
                    new TextLoader.Column("PaymentType", DataKind.Text, 5),
                    new TextLoader.Column("FareAmount", DataKind.R4, 6)
                }
            });

            var model = Train(mlContext, _trainDataPath);

            Evaluate(mlContext, model);

            TestSinglePrediction(mlContext);

            Console.Read();
        }

        public static ITransformer Train(MLContext mlContext, string dataPath)
        {
            IDataView dataView = _textLoader.Read(dataPath);
            var pipeline = mlContext.Transforms.CopyColumns("FareAmount", "Label")
                .Append(mlContext.Transforms.Categorical.OneHotEncoding("VendorId"))
                .Append(mlContext.Transforms.Categorical.OneHotEncoding("RateCode"))
                .Append(mlContext.Transforms.Categorical.OneHotEncoding("PaymentType"))
                .Append(mlContext.Transforms.Concatenate("Features", "VendorId", "RateCode", "PassengerCount", "TripTime", "TripDistance", "PaymentType"))
                .Append(mlContext.Regression.Trainers.FastTree());
            var model = pipeline.Fit(dataView);
            SaveModelAsFile(mlContext, model);
            return model;
        }

        private static void SaveModelAsFile(MLContext mlContext, ITransformer model)
        {
            using (var fileStream = new FileStream(_modelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
                mlContext.Model.Save(model, fileStream);
        }

        private static void Evaluate(MLContext mlContext, ITransformer model)
        {
            IDataView dataView = _textLoader.Read(_testDataPath);
            var predictions = model.Transform(dataView);
            var metrics = mlContext.Regression.Evaluate(predictions, "Label", "Score");
            Console.WriteLine();
            Console.WriteLine($"*************************************************");
            Console.WriteLine($"*       Model quality metrics evaluation         ");
            Console.WriteLine($"*------------------------------------------------");
            Console.WriteLine($"*       R2 Score:      {metrics.RSquared:0.##}");
            Console.WriteLine($"*       RMS loss:      {metrics.Rms:#.##}");
        }

        private static void TestSinglePrediction(MLContext mlContext)
        {
            ITransformer loadedModel;
            using (var stream = new FileStream(_modelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
            {
                loadedModel = mlContext.Model.Load(stream);
            }

            var predictionFunction = loadedModel.MakePredictionFunction<TaxiTrip, TaxiTripFarePrediction>(mlContext);

            var taxiTripSample = new TaxiTrip()
            {
                VendorId = "VTS",
                RateCode = "1",
                PassengerCount = 1,
                TripTime = 1140,
                TripDistance = 3.75f,
                PaymentType = "CRD",
                FareAmount = 0 // To predict. Actual/Observed = 15.5
            };

            var prediction = predictionFunction.Predict(taxiTripSample);

            Console.WriteLine($"**********************************************************************");
            Console.WriteLine($"Predicted fare: {prediction.FareAmount:0.####}, actual fare: 15.5");
            Console.WriteLine($"**********************************************************************");
        }
    }
}

Output when the program runs:

*************************************************
*       Model quality metrics evaluation
*------------------------------------------------
*       R2 Score:      0.92
*       RMS loss:      2.81
**********************************************************************
Predicted fare: 15.7855, actual fare: 15.5
**********************************************************************

The final prediction is fairly close to the actual value: 15.7855 against 15.5, a difference of about 0.29.
