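Understanding the problem
A taxi fare depends on more than the trip distance. It also depends on factors such as the number of passengers and whether the fare is paid by credit card (the taxis here are New York City taxis). Predicting the fare is therefore a regression problem with several input features, not a simple equation in a single variable.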
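Preparing the data
Create a console application project and add a Data folder to it. Put the taxi-fare-train.csv and taxi-fare-test.csv files in that folder, and don't forget to set the Copy to Output Directory property of both files to Copy if newer. Then add the Microsoft.ML NuGet package.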
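The program builds its file paths from Environment.CurrentDirectory. If the Copy to Output Directory setting is missing, the usual symptom is a file-not-found error at run time. Below is a minimal sketch of a pre-flight check that assumes the Data folder and file names above. DataFileCheck and MissingFiles are hypothetical names and are not part of ML.NET.

```csharp
using System;
using System.IO;

// Hypothetical helper: reports which of the expected CSV files are missing
// from <baseDir>/Data, assuming the folder and file names used in this article.
public static class DataFileCheck
{
    public static string[] MissingFiles(string baseDir)
    {
        string[] expected = { "taxi-fare-train.csv", "taxi-fare-test.csv" };
        return Array.FindAll(expected, name => !File.Exists(Path.Combine(baseDir, "Data", name)));
    }

    public static void Main()
    {
        // Same base directory the full program uses for its data paths.
        string[] missing = MissingFiles(Environment.CurrentDirectory);
        Console.WriteLine(missing.Length == 0
            ? "All data files found."
            : "Missing: " + string.Join(", ", missing));
    }
}
```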
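Loading the data
Create an MLContext object, then create a TextLoader object to read data from the files. Passing seed: 0 to MLContext fixes its random seed, so repeated runs produce the same results.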
MLContext mlContext = new MLContext(seed: 0);
// Comma-separated file with a header row.
// DataKind.Text columns are strings; DataKind.R4 columns are 32-bit floats.
_textLoader = mlContext.Data.TextReader(new TextLoader.Arguments()
{
    Separator = ",",
    HasHeader = true,
    Column = new[]
    {
        new TextLoader.Column("VendorId", DataKind.Text, 0),
        new TextLoader.Column("RateCode", DataKind.Text, 1),
        new TextLoader.Column("PassengerCount", DataKind.R4, 2),
        new TextLoader.Column("TripTime", DataKind.R4, 3),
        new TextLoader.Column("TripDistance", DataKind.R4, 4),
        new TextLoader.Column("PaymentType", DataKind.Text, 5),
        new TextLoader.Column("FareAmount", DataKind.R4, 6)
    }
});
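The TextLoader arguments above declare a comma-separated file with a header row and seven typed columns. To make that schema concrete, here is a rough plain-C# sketch that parses one line the same way. Columns 0, 1 and 5 are read as strings (DataKind.Text) and the rest as 32-bit floats (DataKind.R4). It illustrates the schema only and is not how TextLoader is implemented. TaxiRow, Parse and ReadAll are hypothetical names, and the sample line reuses the values from the prediction example later in this article.

```csharp
using System;
using System.Collections.Generic;
using System.Globalization;

// Hypothetical row type mirroring the seven columns declared for TextLoader.
public sealed class TaxiRow
{
    public string VendorId, RateCode, PaymentType;
    public float PassengerCount, TripTime, TripDistance, FareAmount;

    // Parses one comma-separated data line (no quoting support in this sketch).
    public static TaxiRow Parse(string line)
    {
        string[] f = line.Split(',');
        if (f.Length != 7) throw new FormatException($"Expected 7 columns, got {f.Length}.");
        // Read column i as a 32-bit float, independent of the machine's culture.
        float R4(int i) => float.Parse(f[i], CultureInfo.InvariantCulture);
        return new TaxiRow
        {
            VendorId = f[0], RateCode = f[1],
            PassengerCount = R4(2), TripTime = R4(3), TripDistance = R4(4),
            PaymentType = f[5], FareAmount = R4(6)
        };
    }

    // Mirrors HasHeader = true: skips the first line, then parses non-blank lines.
    public static List<TaxiRow> ReadAll(IEnumerable<string> lines, bool hasHeader = true)
    {
        var rows = new List<TaxiRow>();
        bool skip = hasHeader;
        foreach (string line in lines)
        {
            if (skip) { skip = false; continue; }
            if (!string.IsNullOrWhiteSpace(line)) rows.Add(Parse(line));
        }
        return rows;
    }
}
```

For example, TaxiRow.ReadAll(File.ReadLines(path)) would load a whole CSV file into memory. This is fine for a quick look at the data, whereas TextLoader streams rows into an IDataView.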
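Extracting features
The data file has seven columns. The first six are the features and the last one (FareAmount) is the label. The TaxiTrip class describes one row of input, and each [Column] attribute gives that field's column position in the file. TaxiTripFarePrediction holds the model's output: [ColumnName("Score")] maps the trainer's Score column, which is the predicted fare, onto its FareAmount field.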
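// In this ML.NET version the Column and ColumnName attributes are in Microsoft.ML.Runtime.Api.
using Microsoft.ML.Runtime.Api;

// One row of the CSV file; the attribute value is the column's position in the file.
public class TaxiTrip
{
    [Column("0")]
    public string VendorId;
    [Column("1")]
    public string RateCode;
    [Column("2")]
    public float PassengerCount;
    [Column("3")]
    public float TripTime;
    [Column("4")]
    public float TripDistance;
    [Column("5")]
    public string PaymentType;
    [Column("6")]
    public float FareAmount;
}

// Model output: the regression trainer writes its prediction to the "Score" column.
public class TaxiTripFarePrediction
{
    [ColumnName("Score")]
    public float FareAmount;
}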
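Training the model
First read the training data set, then build the pipeline:
1. Copy the FareAmount column to a Label column, which serves as the label.
2. Convert the three string columns VendorId, RateCode and PaymentType into numeric columns with one-hot encoding.
3. Concatenate the six data columns into a single Features column.
4. Select FastTreeRegressionTrainer as the training algorithm. It reads the Label and Features columns by default.
Once the pipeline is built, train the model by calling Fit on the training data.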
IDataView dataView = _textLoader.Read(dataPath);
var pipeline = mlContext.Transforms.CopyColumns("FareAmount", "Label")
    .Append(mlContext.Transforms.Categorical.OneHotEncoding("VendorId"))
    .Append(mlContext.Transforms.Categorical.OneHotEncoding("RateCode"))
    .Append(mlContext.Transforms.Categorical.OneHotEncoding("PaymentType"))
    .Append(mlContext.Transforms.Concatenate("Features", "VendorId", "RateCode", "PassengerCount", "TripTime", "TripDistance", "PaymentType"))
    .Append(mlContext.Regression.Trainers.FastTree());
var model = pipeline.Fit(dataView);
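A small sketch helps picture steps 2 and 3 of the pipeline. Each string column becomes a vector with a single 1 in the slot for its category. Those vectors are then concatenated with the three numeric columns into one Features vector. The sketch rests on its own assumptions: the vocabulary is built in first-seen order, and unseen values map to an all-zero vector. ML.NET's OneHotEncoding may order and size the slots differently. FeatureSketch and its methods are hypothetical, and the category values in the usage note are illustrative only.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical illustration of one-hot encoding followed by concatenation.
public static class FeatureSketch
{
    // Assigns each distinct category a slot index in first-seen order.
    public static Dictionary<string, int> BuildVocabulary(IEnumerable<string> values)
    {
        var vocab = new Dictionary<string, int>();
        foreach (var v in values)
            if (!vocab.ContainsKey(v)) vocab[v] = vocab.Count;
        return vocab;
    }

    // Returns a vector with 1 in the category's slot; unseen categories give all zeros here.
    public static float[] OneHot(string value, Dictionary<string, int> vocab)
    {
        var vec = new float[vocab.Count];
        if (vocab.TryGetValue(value, out int slot)) vec[slot] = 1f;
        return vec;
    }

    // Concatenates in the same column order as the Concatenate("Features", ...) call.
    public static float[] Features(
        string vendorId, string rateCode, float passengerCount, float tripTime,
        float tripDistance, string paymentType,
        Dictionary<string, int> vendors, Dictionary<string, int> rates, Dictionary<string, int> payments)
    {
        return OneHot(vendorId, vendors)
            .Concat(OneHot(rateCode, rates))
            .Concat(new[] { passengerCount, tripTime, tripDistance })
            .Concat(OneHot(paymentType, payments))
            .ToArray();
    }
}
```

Suppose the vocabularies are built from illustrative values such as { "VTS", "CMT" }, { "1", "2" } and { "CRD", "CSH" }. Then Features("VTS", "1", 1, 1140, 3.75f, "CRD", ...) produces a 2 + 2 + 3 + 2 = 9 slot vector: [1, 0, 1, 0, 1, 1140, 3.75, 1, 0].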
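Evaluating the model
Evaluation uses the test data set. The trained model transforms the test data, and the regression Evaluate method compares the Label column (actual fare) with the Score column (predicted fare). R2, the coefficient of determination, gets closer to 1 as the fit improves. RMS loss, the root mean squared error in the same units as the fare, is lower for better models.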
IDataView dataView = _textLoader.Read(_testDataPath);
var predictions = model.Transform(dataView);
var metrics = mlContext.Regression.Evaluate(predictions, "Label", "Score");
Console.WriteLine();
Console.WriteLine($"*************************************************");
Console.WriteLine($"* Model quality metrics evaluation ");
Console.WriteLine($"*------------------------------------------------");
Console.WriteLine($"* R2 Score: {metrics.RSquared:0.##}");
Console.WriteLine($"* RMS loss: {metrics.Rms:#.##}");
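Both printed metrics follow standard definitions. The sketch below computes them from arrays of labels and scores, assuming unweighted data. RMS loss is the square root of the mean squared difference between label and score. R2 is one minus the ratio of the squared error to the total variance of the labels. RegressionMetricsSketch is a hypothetical helper for illustration and is not ML.NET's evaluator.

```csharp
using System;

// Hypothetical helper computing RMS loss and R^2 from labels and scores.
public static class RegressionMetricsSketch
{
    // sqrt(mean((label - score)^2))
    public static double Rms(double[] labels, double[] scores)
    {
        Check(labels, scores);
        double sse = 0;
        for (int i = 0; i < labels.Length; i++)
        {
            double d = labels[i] - scores[i];
            sse += d * d;
        }
        return Math.Sqrt(sse / labels.Length);
    }

    // 1 - SSE / SST; not meaningful when all labels are identical (SST = 0).
    public static double RSquared(double[] labels, double[] scores)
    {
        Check(labels, scores);
        double mean = 0;
        foreach (double y in labels) mean += y;
        mean /= labels.Length;

        double sse = 0, sst = 0;
        for (int i = 0; i < labels.Length; i++)
        {
            sse += (labels[i] - scores[i]) * (labels[i] - scores[i]);
            sst += (labels[i] - mean) * (labels[i] - mean);
        }
        return 1 - sse / sst;
    }

    static void Check(double[] labels, double[] scores)
    {
        if (labels.Length == 0 || labels.Length != scores.Length)
            throw new ArgumentException("labels and scores must be non-empty and of equal length");
    }
}
```

A model that always predicts the mean label scores R2 = 0, and a perfect model scores R2 = 1 with RMS loss 0. The 0.92 reported below is close to the perfect end.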
Saving the model
The trained model can be saved as a zip file for later use.
using (var fileStream = new FileStream(_modelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
    mlContext.Model.Save(model, fileStream);
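The saved model is an ordinary zip archive, so the standard System.IO.Compression API can inspect its contents. A minimal sketch follows. ModelZipInspector is a hypothetical helper. Entry names depend on the ML.NET version and the pipeline, so the code makes no assumptions about them.

```csharp
using System.Collections.Generic;
using System.IO.Compression;
using System.Linq;

// Hypothetical helper: lists the entry names stored in a zip file such as Model.zip.
public static class ModelZipInspector
{
    public static List<string> ListEntries(string zipPath)
    {
        using (ZipArchive archive = ZipFile.OpenRead(zipPath))
        {
            return archive.Entries.Select(e => e.FullName).ToList();
        }
    }
}
```

For example, calling ModelZipInspector.ListEntries(_modelPath) after training prints what Model.Save wrote. This can help confirm that the file was created and is not empty.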
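Using the model
First load the saved model. Next create a prediction function with TaxiTrip as its input type and TaxiTripFarePrediction as its output type. Then call its Predict method with the data to predict.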
ITransformer loadedModel;
using (var stream = new FileStream(_modelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
{
    loadedModel = mlContext.Model.Load(stream);
}
var predictionFunction = loadedModel.MakePredictionFunction<TaxiTrip, TaxiTripFarePrediction>(mlContext);
var taxiTripSample = new TaxiTrip()
{
    VendorId = "VTS",
    RateCode = "1",
    PassengerCount = 1,
    TripTime = 1140,
    TripDistance = 3.75f,
    PaymentType = "CRD",
    FareAmount = 0 // To predict. Actual/Observed = 15.5
};
var prediction = predictionFunction.Predict(taxiTripSample);
Console.WriteLine($"**********************************************************************");
Console.WriteLine($"Predicted fare: {prediction.FareAmount:0.####}, actual fare: 15.5");
Console.WriteLine($"**********************************************************************");
Complete example code
using Microsoft.ML;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Data;
using System;
using System.IO;

// The TaxiTrip and TaxiTripFarePrediction classes from "Extracting features" must also be part of the project.
namespace TexiFarePredictor
{
    class Program
    {
        static readonly string _trainDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "taxi-fare-train.csv");
        static readonly string _testDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "taxi-fare-test.csv");
        static readonly string _modelPath = Path.Combine(Environment.CurrentDirectory, "Data", "Model.zip");
        static TextLoader _textLoader;

        static void Main(string[] args)
        {
            // Fixed seed so that repeated runs give the same results.
            MLContext mlContext = new MLContext(seed: 0);
            _textLoader = mlContext.Data.TextReader(new TextLoader.Arguments()
            {
                Separator = ",",
                HasHeader = true,
                Column = new[]
                {
                    new TextLoader.Column("VendorId", DataKind.Text, 0),
                    new TextLoader.Column("RateCode", DataKind.Text, 1),
                    new TextLoader.Column("PassengerCount", DataKind.R4, 2),
                    new TextLoader.Column("TripTime", DataKind.R4, 3),
                    new TextLoader.Column("TripDistance", DataKind.R4, 4),
                    new TextLoader.Column("PaymentType", DataKind.Text, 5),
                    new TextLoader.Column("FareAmount", DataKind.R4, 6)
                }
            });

            var model = Train(mlContext, _trainDataPath);
            Evaluate(mlContext, model);
            TestSinglePrediction(mlContext);

            // Wait for input so the console window stays open.
            Console.Read();
        }

        // Builds the pipeline, trains it on the training set and saves the model.
        public static ITransformer Train(MLContext mlContext, string dataPath)
        {
            IDataView dataView = _textLoader.Read(dataPath);
            var pipeline = mlContext.Transforms.CopyColumns("FareAmount", "Label")
                .Append(mlContext.Transforms.Categorical.OneHotEncoding("VendorId"))
                .Append(mlContext.Transforms.Categorical.OneHotEncoding("RateCode"))
                .Append(mlContext.Transforms.Categorical.OneHotEncoding("PaymentType"))
                .Append(mlContext.Transforms.Concatenate("Features", "VendorId", "RateCode", "PassengerCount", "TripTime", "TripDistance", "PaymentType"))
                .Append(mlContext.Regression.Trainers.FastTree());
            var model = pipeline.Fit(dataView);
            SaveModelAsFile(mlContext, model);
            return model;
        }

        private static void SaveModelAsFile(MLContext mlContext, ITransformer model)
        {
            using (var fileStream = new FileStream(_modelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
                mlContext.Model.Save(model, fileStream);
        }

        // Scores the test set and prints R2 and RMS loss.
        private static void Evaluate(MLContext mlContext, ITransformer model)
        {
            IDataView dataView = _textLoader.Read(_testDataPath);
            var predictions = model.Transform(dataView);
            var metrics = mlContext.Regression.Evaluate(predictions, "Label", "Score");
            Console.WriteLine();
            Console.WriteLine($"*************************************************");
            Console.WriteLine($"* Model quality metrics evaluation ");
            Console.WriteLine($"*------------------------------------------------");
            Console.WriteLine($"* R2 Score: {metrics.RSquared:0.##}");
            Console.WriteLine($"* RMS loss: {metrics.Rms:#.##}");
        }

        // Loads the saved model and predicts the fare for one sample trip.
        private static void TestSinglePrediction(MLContext mlContext)
        {
            ITransformer loadedModel;
            using (var stream = new FileStream(_modelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
            {
                loadedModel = mlContext.Model.Load(stream);
            }
            var predictionFunction = loadedModel.MakePredictionFunction<TaxiTrip, TaxiTripFarePrediction>(mlContext);
            var taxiTripSample = new TaxiTrip()
            {
                VendorId = "VTS",
                RateCode = "1",
                PassengerCount = 1,
                TripTime = 1140,
                TripDistance = 3.75f,
                PaymentType = "CRD",
                FareAmount = 0 // To predict. Actual/Observed = 15.5
            };
            var prediction = predictionFunction.Predict(taxiTripSample);
            Console.WriteLine($"**********************************************************************");
            Console.WriteLine($"Predicted fare: {prediction.FareAmount:0.####}, actual fare: 15.5");
            Console.WriteLine($"**********************************************************************");
        }
    }
}
Output of the program:
*************************************************
* Model quality metrics evaluation
*------------------------------------------------
* R2 Score: 0.92
* RMS loss: 2.81
**********************************************************************
Predicted fare: 15.7855, actual fare: 15.5
**********************************************************************
The predicted fare of 15.7855 is quite close to the actual fare of 15.5.
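To put "close" in numbers, the single prediction is off by 15.7855 - 15.5 = 0.2855. That is about 1.8% of the actual fare and well below the test-set RMS loss of 2.81. The tiny sketch below does that arithmetic. PredictionError is a hypothetical helper, and the inputs are the two values printed above.

```csharp
using System;

// Hypothetical helper: absolute and relative error of a single prediction.
public static class PredictionError
{
    public static double Absolute(double predicted, double actual) => Math.Abs(predicted - actual);

    public static double Relative(double predicted, double actual) => Math.Abs(predicted - actual) / Math.Abs(actual);

    public static void Main()
    {
        double predicted = 15.7855, actual = 15.5;
        Console.WriteLine($"Absolute error: {Absolute(predicted, actual):0.####}");
        Console.WriteLine($"Relative error: {Relative(predicted, actual) * 100:0.##}%");
    }
}
```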