Alink Walkthrough (4): The Ins and Outs of Models
0x00 Abstract
Alink is a next-generation machine learning algorithm platform developed by Alibaba on top of the real-time computing engine Flink, and the first platform in the industry to support both batch and streaming algorithms. This article dives into Alink once more, this time from the model's perspective.
Because there is very little public material on Alink, what follows is my own deduction. There are surely omissions and mistakes; please point them out and I will keep updating.
0x01 The Model
In the previous articles we never examined Alink's model closely, so this time we dig in. To borrow a line from the comedian Fan Wei: I want to know how the model came to be, and I also want to know how it goes away.
1.1 What a Model Contains
Let's first think about what a trained machine learning model should contain:
- The pipeline: a model may span several stages, e.g. transformation and prediction, which together form a pipeline.
- The algorithm: this is bound to the concrete machine learning platform; on Flink it is some specific Java algorithm class.
- The parameters: these are a must; a large part of machine learning is precisely about producing them.
- The data: really a kind of parameter as well, also produced by training; for example, the cluster centers produced by KMeans.
1.2 Alink's Model File
Let's open an Alink model file to verify:
-1,"{""schema"":["""",""model_id BIGINT,model_info VARCHAR""],""param"":[""{\""outputCol\"":\""\\\""features\\\""\"",\""selectedCols\"":\""[\\\""sepal_length\\\"",\\\""sepal_width\\\"",\\\""petal_length\\\"",\\\""petal_width\\\""]\""}"",""{\""vectorCol\"":\""\\\""features\\\""\"",\""maxIter\"":\""100\"",\""reservedCols\"":\""[\\\""category\\\""]\"",\""k\"":\""3\"",\""predictionCol\"":\""\\\""prediction_result\\\""\"",\""predictionDetailCol\"":\""\\\""prediction_detail\\\""\""}""],""clazz"":[""com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler"",""com.alibaba.alink.pipeline.clustering.KMeansModel""]}"
1,"0^{""vectorCol"":""\""features\"""",""latitudeCol"":null,""longitudeCol"":null,""distanceType"":""\""EUCLIDEAN\"""",""k"":""3"",""vectorSize"":""4""}"
1,"1048576^{""clusterId"":0,""weight"":39.0,""vec"":{""data"":[6.8538461538461535,3.0769230769230766,5.7153846153846155,2.0538461538461545]}}"
1,"2097152^{""clusterId"":1,""weight"":61.0,""vec"":{""data"":[5.883606557377049,2.740983606557377,4.388524590163936,1.4344262295081969]}}"
1,"3145728^{""clusterId"":2,""weight"":50.0,""vec"":{""data"":[5.006,3.418,1.4640000000000002,0.24400000000000005]}}"
We see two class names:
com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler
com.alibaba.alink.pipeline.clustering.KMeansModel
These are the algorithms we mentioned: at runtime Alink can instantiate the Java classes from these two names, and the two algorithm classes apparently combine into a pipeline. We can also see the parameters and the data.
But a few things look odd:
- What do the strange numbers 1048576 and 2097152 mean?
- Why is the first number in the file -1, while the first number of the next line is 1? Where did 0 go?
- How exactly does Alink generate and load the model?
Let's track these down one by one.
0x02 Flow Diagram
To help you follow along, here is a flow diagram first. It is only a logical sketch and differs from the real execution, because in practice the execution plan is generated first and the concrete operations run afterwards; take it merely as an overview.
 *
 *
 *   Pipeline.fit (training)
 *      |
 *      |
 *      +-----> KMeansTrainModelData [ centroids, params -- centroid data and parameters ]
 *      |       // produced in KMeansOutputModel.calc(), which computes the centroid data and parameters
 *      |
 *      |
 *      +-----> Tuple2<Params, Iterable<String>> [ "Params" is the model meta data, "Iterable<String>" is the concrete model data ]
 *      |       // KMeansModelDataConverter.serializeModel() serializes the model: it converts the data to JSON and calls KMeansTrainModelData.toParams to set the parameters
 *      |
 *      |
 *      +-----> Collector<Row> [ a Row can hold arbitrary fields, accessed by zero-based position ]
 *      |       // ModelConverterUtils.appendMetaRow, ModelConverterUtils.appendDataRows
 *      |
 *      |
 *      +-----> List<Row> model [ collector.getRows() ]
 *      |       // List<Row> model = completeResult.calc(context);
 *      |
 *      |
 *      +-----> DataSet<Row> [ the serialized result of the operator computation ]
 *      |       // BaseComQueue.exec --- serializeModel(clearObjs(loopEnd))
 *      |
 *      |
 *      +-----> Table output [ AlgoOperator.output, i.e. the operator's output table ]
 *      |       // KMeansTrainBatchOp.linkFrom --- setOutput
 *      |
 *      |
 *      +-----> KMeansModel [ the model; finds the closest cluster center for every point ]
 *      |       // createModel(train(input).getOutputTable()) sets the model parameters
 *      |       // KMeansModel.setModelData(Table modelData) sets the model data
 *      |
 *      |
 *      +-----> TransformerBase[] [ PipelineModel.transformers ]
 *      |       // the final trained pipeline model; KMeansModel is one of its transformers, and KMeansModelMapper is KMeansModel's business component
 *      |
 *      |
 *   PipelineModel.save (saving)
 *      |
 *      |
 *      +-----> BatchOperator [ the transformers array packed into a BatchOperator ]
 *      |       // ModelExporterUtils.packTransformersArray
 *      |
 *      |
 *      +-----> saved model file [ a CSV file ]
 *      |       // PipelineModel.save --- CsvSinkBatchOp(path)
 *      |
 *   PipelineModel.load (loading)
 *      |
 *      |
 *      +-----> saved model file [ a CSV file ]
 *      |       // PipelineModel.load --- CsvSourceBatchOp(path)
 *      |
 *      |
 *      +-----> KMeansModel [ the model; finds the closest cluster center for every point ]
 *      |       // the model is rebuilt from the file: (TransformerBase) clazz.getConstructor(Params.class)
 *      |       // the data is set via ((ModelBase) transformers[i]).setModelData(data.getOutputTable())
 *      |
 *      +-----> TransformerBase[] [ the transformers read back and restored from the CSV file ]
 *      |       // ModelExporterUtils.unpackTransformersArray(batchOp)
 *      |
 *      |
 *      +-----> PipelineModel [ the pipeline model ]
 *      |       // new PipelineModel(ModelExporterUtils.unpackTransformersArray(batchOp));
 *      |
 *      |
 *   PipelineModel.transform(data) (prediction)
 *      |
 *      |
 *      |
 *      +-----> ModelSource [ load model data from ModelSource upon open() ]
 *      |       // ModelMapperAdapter.open --- List<Row> modelRows = this.modelSource.getModelRows(getRuntimeContext());
 *      |
 *      +-----> Tuple2<Params, Iterable<String>> [ metaAndData ]
 *      |       // SimpleModelDataConverter.load
 *      |
 *      |
 *      +-----> KMeansTrainModelData [ deserialized ]
 *      |       // KMeansModelDataConverter.deserializeModel(Params params, Iterable<String> data)
 *      |
 *      |
 *      +-----> KMeansTrainModelData [ KMeansTrainModelData loaded from the saved model ]
 *      |       // KMeansModelMapper.loadModel
 *      |       // KMeansUtil.loadModelForTrain(Params params, Iterable<String> data)
 *      |
 *      |
 *      +-----> KMeansPredictModelData [ model data for KMeans prediction ]
 *      |       // converts the training model data into prediction model data, which contains the centroids
 *      |       // KMeansUtil.transformTrainDataToPredictData(trainModelData);
 *      |
 *      |
 *      +-----> Row row [ "5.0,3.2,1.2,0.2,Iris-setosa,5.0 3.2 1.2 0.2" ]
 *      |       // row is the input record to be predicted; ModelMapperAdapter.map
 *      |
 *      |
 *      +-----> Row row [ "0|0.4472728134421832 0.35775115900088217 0.19497602755693455" ]
 *      |       // the prediction result; KMeansModelMapper.map
 *      |
 *      |
0x03 Generating the Model
We again use KMeans as the example, to see what the model data looks like and how it is converted into the form Alink needs.
VectorAssembler va = new VectorAssembler()
.setSelectedCols(new String[]{"sepal_length", "sepal_width", "petal_length", "petal_width"})
.setOutputCol("features");
KMeans kMeans = new KMeans().setVectorCol("features").setK(3)
.setPredictionCol("prediction_result")
.setPredictionDetailCol("prediction_detail")
.setReservedCols("category")
.setMaxIter(100);
Pipeline pipeline = new Pipeline().add(va).add(kMeans);
pipeline.fit(data);
From the previous articles we know that the key class in KMeans training is KMeansTrainBatchOp. When the iteration ends, KMeansTrainBatchOp outputs the model with .closeWith(new KMeansOutputModel(distanceType, vectorColName, latitudeColName, longitudeColName)).
3.1 Generating the Model
So the class to study is KMeansOutputModel. Its calc function turns the centroids and parameters into the model:
- First, serializeModel serializes the centroids into JSON. We mark this step (1); the code comments below point out where it happens.
- Next, the save function serializes everything into a Tuple2<Params, Iterable<String>>: Params holds the parameters, and Iterable<String> holds the concrete model data, i.e. the set of centroids. We mark this step (2).
- Then the save function stores the parameters and the data separately. We mark this step (3).
- Finally, the collector holds the model data. We mark this step (4).
/**
 * Transform the centroids to KMeansModel.
 */
public class KMeansOutputModel extends CompleteResultFunction {
    private DistanceType distanceType;
    private String vectorColName;
    private String latitudeColName;
    private String longtitudeColName;

    @Override
    public List <Row> calc(ComContext context) {
        KMeansTrainModelData modelData = new KMeansTrainModelData();
        ... // various assignments
        modelData.params = new KMeansTrainModelData.ParamSummary();
        modelData.params.k = k;
        modelData.params.vectorColName = vectorColName;
        ...
// At this point the computed centroids and all parameters have been added to KMeansTrainModelData:
modelData = {KMeansTrainModelData@11319}
 centroids = {ArrayList@11327} size = 3
  0 = {KMeansTrainModelData$ClusterSummary@11330}
   clusterId = 0
   weight = 38.0
   vec = {DenseVector@11333} "6.849999999999999 3.0736842105263156 5.742105263157895 2.071052631578947"
  1 = {KMeansTrainModelData$ClusterSummary@11331}
  2 = {KMeansTrainModelData$ClusterSummary@11332}
 params = {KMeansTrainModelData$ParamSummary@11328}
  k = 3
  vectorSize = 4
  distanceType = {DistanceType@11287} "EUCLIDEAN"
  vectorColName = "features"
  latitudeColName = null
  longtitudeColName = null

        RowCollector collector = new RowCollector();
        // save performs steps (1)(2)(3); the code below marks exactly where.
        // KMeansModelDataConverter extends SimpleModelDataConverter, so this calls
        // KMeansModelDataConverter.save, which invokes serializeModel to turn the
        // centroids into JSON, yielding a Tuple2<Params, Iterable<String>>.
        new KMeansModelDataConverter().save(modelData, collector);

        // (4) the collector now holds the model data.
        return collector.getRows();
// The model rows already resemble the model file's content, strange numbers included:
collector = {RowCollector@11321}
 rows = {ArrayList@11866} size = 4
  0 = {Row@11737} "0,{"vectorCol":"\"features\"","latitudeCol":null,"longitudeCol":null,"distanceType":"\"EUCLIDEAN\"","k":"3","vectorSize":"4"}"
  1 = {Row@11801} "1048576,{"clusterId":0,"weight":38.0,"vec":{"data":[6.849999999999999,3.0736842105263156,5.742105263157895,2.071052631578947]}}"
  2 = {Row@11868} "2097152,{"clusterId":1,"weight":50.0,"vec":{"data":[5.006,3.4179999999999997,1.4640000000000002,0.24400000000000002]}}"
  3 = {Row@11869} "3145728,{"clusterId":2,"weight":62.0,"vec":{"data":[5.901612903225806,2.7483870967741937,4.393548387096773,1.4338709677419355]}}"
    }
}
The actual conversion happens in KMeansModelDataConverter and its base class SimpleModelDataConverter. First, serializeModel serializes each centroid into JSON, producing a list of JSON strings:
/**
* KMeans Model.
* Save the id, center point and point number of clusters.
*/
public class KMeansModelDataConverter extends SimpleModelDataConverter<KMeansTrainModelData, KMeansPredictModelData> {
public KMeansModelDataConverter() {}
@Override
public Tuple2<Params, Iterable<String>> serializeModel(KMeansTrainModelData modelData) {
List <String> data = new ArrayList <>();
for (ClusterSummary centroid : modelData.centroids) {
data.add(JsonConverter.toJson(centroid)); // (1) convert each centroid to JSON
}
return Tuple2.of(modelData.params.toParams(), data);
}
@Override
public KMeansPredictModelData deserializeModel(Params params, Iterable<String> data) {
KMeansTrainModelData trainModelData = KMeansUtil.loadModelForTrain(params, data);
return KMeansUtil.transformTrainDataToPredictData(trainModelData);
}
}
Next, save serializes this into a Tuple2<Params, Iterable<String>>:
/**
 * The abstract class for a kind of {@link ModelDataConverter} where the model data can serialize to
 * "Tuple2<Params, Iterable<String>>". Here "Params" is the meta data of the model, and "Iterable<String>" is
 * the concrete data of the model.
*/
public abstract class SimpleModelDataConverter<M1, M2> implements ModelDataConverter<M1, M2> {
@Override
public M2 load(List<Row> rows) {
Tuple2<Params, Iterable<String>> metaAndData = ModelConverterUtils.extractModelMetaAndData(rows);
return deserializeModel(metaAndData.f0, metaAndData.f1);
}
@Override
    public void save(M1 modelData, Collector<Row> collector) {
        // (2) serialize the model data into a Tuple2
        Tuple2<Params, Iterable<String>> model = serializeModel(modelData);
// At this point the model data is a tuple Tuple2<Params, Iterable<String>>:
model = {Tuple2@11504} "(Params {vectorCol="features", latitudeCol=null, longitudeCol=null, distanceType="EUCLIDEAN", k=3, vectorSize=4},[{"clusterId":0,"weight":38.0,"vec":{"data":[6.849999999999999,3.0736842105263156,5.742105263157895,2.071052631578947]}}, {"clusterId":1,"weight":50.0,"vec":{"data":[5.006,3.4179999999999997,1.4640000000000002,0.24400000000000002]}}, {"clusterId":2,"weight":62.0,"vec":{"data":[5.901612903225806,2.7483870967741937,4.393548387096773,1.4338709677419355]}}])"
        // (3) emit the parameters and the data separately
        ModelConverterUtils.appendMetaRow(model.f0, collector, 2);
        ModelConverterUtils.appendDataRows(model.f1, collector, 2);
}
}
Then the parameters and the data are stored separately; the RowCollector simply buffers the rows:
/**
 * Collector of Row type data.
 */
public class RowCollector implements Collector<Row> {
    private List<Row> rows = new ArrayList<>();

    @Override
    public void collect(Row row) {
        rows.add(row); // buffer the row
    }
}
// The call stack:
collect:37, RowCollector (com.alibaba.alink.common.utils)
collect:12, RowCollector (com.alibaba.alink.common.utils)
appendStringData:270, ModelConverterUtils (com.alibaba.alink.common.model)
appendMetaRow:35, ModelConverterUtils (com.alibaba.alink.common.model)
save:57, SimpleModelDataConverter (com.alibaba.alink.common.model)
calc:76, KMeansOutputModel (com.alibaba.alink.operator.common.clustering.kmeans)
mapPartition:287, BaseComQueue$4 (com.alibaba.alink.common.comqueue)
3.2 Converting to a DataSet
The model data has to become a DataSet, i.e. a collection of rows, so that it can be transported and consumed conveniently inside Alink.
When converting a model string into Row data, the string can be too long, so Alink splits one string across multiple rows. The getModelId and getStringIndex functions of ModelConverterUtils encode and decode the split.
The resulting model id is the computed 1048576, the strange number from the model file: judging from the ids visible in the file, the meta string occupies string index 0, the data strings start from index 1, and id = stringIndex * 1048576 + sliceIndex (a small sketch after the call stack below makes the arithmetic concrete).
When the model is loaded later, the same scheme turns the rows back into the model strings.
// A utility class for converting model data to a collection of rows.
class ModelConverterUtils {
/**
* Maximum number of slices a string can split to.
*/
static final long MAX_NUM_SLICES = 1024L * 1024L;
private static long getModelId(int stringIndex, int sliceIndex) {
return MAX_NUM_SLICES * stringIndex + sliceIndex;
}
private static int getStringIndex(long modelId) {
return (int) ((modelId) / MAX_NUM_SLICES);
}
}
row = {Row@11714} "1048576,{"clusterId":0,"weight":62.0,"vec":{"data":[5.901612903225806,2.7483870967741932,4.393548387096773,1.4338709677419355]}}"
fields = {Object[2]@11724}
0 = {Long@11725} 1048576
1 = "{"clusterId":0,"weight":62.0,"vec":{"data":[5.901612903225806,2.7483870967741932,4.393548387096773,1.4338709677419355]}}"
// The relevant call stack:
appendStringData:270, ModelConverterUtils (com.alibaba.alink.common.model)
appendDataRows:52, ModelConverterUtils (com.alibaba.alink.common.model)
save:58, SimpleModelDataConverter (com.alibaba.alink.common.model)
calc:76, KMeansOutputModel (com.alibaba.alink.operator.common.clustering.kmeans)
mapPartition:287, BaseComQueue$4 (com.alibaba.alink.common.comqueue)
run:103, MapPartitionDriver (org.apache.flink.runtime.operators)
...
run:748, Thread (java.lang)
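To make the id arithmetic concrete, here is a minimal sketch that mirrors the two utility methods above. That the meta string takes string index 0 while the data strings take indices 1, 2, 3 is inferred from the ids visible in the model file:
public class ModelIdDemo {
    // Mirrors ModelConverterUtils.MAX_NUM_SLICES: at most 1024 * 1024 slices per string.
    static final long MAX_NUM_SLICES = 1024L * 1024L; // = 1048576

    // Mirrors ModelConverterUtils.getModelId: slice j of string i gets id i * 1048576 + j.
    static long getModelId(int stringIndex, int sliceIndex) {
        return MAX_NUM_SLICES * stringIndex + sliceIndex;
    }

    public static void main(String[] args) {
        System.out.println(getModelId(0, 0)); // 0       -> the meta (Params) row
        System.out.println(getModelId(1, 0)); // 1048576 -> first centroid JSON, slice 0
        System.out.println(getModelId(2, 0)); // 2097152 -> second centroid JSON, slice 0
        System.out.println(getModelId(3, 0)); // 3145728 -> third centroid JSON, slice 0
        // A centroid JSON longer than one row would continue as 1048577, 1048578, ...
    }
}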
3.3 Storing as a Table
KMeansOutputModel ultimately returns a DataSet; here that DataSet is converted into a Table and kept in the pipeline.
public final class KMeansTrainBatchOp extends BatchOperator <KMeansTrainBatchOp> {
    public KMeansTrainBatchOp linkFrom(BatchOperator <?>... inputs) {
        DataSet <Row> finalCentroid = iterateICQ(initCentroid, data,
            vectorSize, maxIter, tol, distance, distanceType, vectorColName, null, null);
        // Store the result as a Table.
        this.setOutput(finalCentroid, new KMeansModelDataConverter().getModelSchema());
        return this;
    }
}
this = {KMeansTrainBatchOp@5130} "UnnamedTable$1"
params = {Params@5143} "Params {vectorCol="features", maxIter=100, reservedCols=["category"], k=3, predictionCol="prediction_result", predictionDetailCol="prediction_detail"}"
output = {TableImpl@5188} "UnnamedTable$1"
tableEnvironment = {BatchTableEnvironmentImpl@5190}
operationTree = {DataSetQueryOperation@5191}
operationTreeBuilder = {OperationTreeBuilder@5192}
lookupResolver = {LookupCallResolver@5193}
tableName = "UnnamedTable$1"
sideOutputs = null
We can see that at runtime Alink uniformly converts model data into the Table type. Part of the reason is probably that Alink wants to treat DataSet and DataStream, i.e. batch and streaming, with one approach and one code path; since Flink already unifies the two under Table, Alink standardizes on Table as well. See:
public abstract class ModelBase<M extends ModelBase<M>> extends TransformerBase<M>
implements Model<M> {
protected Table modelData;
}
public abstract class AlgoOperator<T extends AlgoOperator<T>>
implements WithParams<T>, HasMLEnvironmentId<T>, Serializable {
// Params for algorithms.
private Params params;
// The table held by operator.
private Table output = null;
// The side outputs of operator that be similar to the stream's side outputs.
private Table[] sideOutputs = null;
}
0x04 Saving the Model
4.1 The Save Code
Let's modify the code to call the save function and persist the pipeline model. Alink currently saves the model as a CSV file in a special format.
Pipeline pipeline = new Pipeline().add(va).add(kMeans);
pipeline.fit(data).save("./kmeans.csv");
The pipeline save code is as follows:
public class PipelineModel extends ModelBase<PipelineModel> implements LocalPredictable {
// Pack the pipeline model to a BatchOperator.
public BatchOperator save() {
return ModelExporterUtils.packTransformersArray(transformers);
}
}
The pipeline eventually calls ModelExporterUtils.packTransformersArray, so let's focus on that function. It answers the questions about the model file: why is the first index -1, why is the next one 1, and where did 0 go? (A small inspection sketch follows the code below.)
The first number on each line of the model file is the transformer's index. The config row is special, so its index is set to -1; the code below points this out.
The 1 in the model file means that the second transformer, KMeansModel, has model data; the concrete data sits on the rows with index 1.
There is no 0 in the model file because the first transformer, VectorAssembler, has no model data of its own, so it is simply omitted.
class ModelExporterUtils {
//Pack an array of transformers to a BatchOperator.
static BatchOperator packTransformersArray(TransformerBase[] transformers) {
int numTransformers = transformers.length;
String[] clazzNames = new String[numTransformers];
String[] params = new String[numTransformers];
String[] schemas = new String[numTransformers];
for (int i = 0; i < numTransformers; i++) {
clazzNames[i] = transformers[i].getClass().getCanonicalName();
params[i] = transformers[i].getParams().toJson();
schemas[i] = "";
if (transformers[i] instanceof PipelineModel) {
schemas[i] = CsvUtil.schema2SchemaStr(PIPELINE_MODEL_SCHEMA);
} else if (transformers[i] instanceof ModelBase) {
long envId = transformers[i].getMLEnvironmentId();
BatchOperator data = BatchOperator.fromTable(((ModelBase) transformers[i]).getModelData());
data.setMLEnvironmentId(envId);
data = data.link(new VectorSerializeBatchOp().setMLEnvironmentId(envId));
schemas[i] = CsvUtil.schema2SchemaStr(data.getSchema());
}
}
Map<String, Object> config = new HashMap<>();
config.put("clazz", clazzNames);
config.put("param", params);
config.put("schema", schemas);
        // This corresponds to the -1 in the model file: the config row's index is -1.
        Row row = Row.of(-1L, JsonConverter.toJson(config));
// Here schema, param and clazz already match the model file's content; one step closer to the goal:
config = {HashMap@5432} size = 3
"schema" -> {String[2]@5431}
key = "schema"
value = {String[2]@5431}
0 = ""
1 = "model_id BIGINT,model_info VARCHAR"
"param" -> {String[2]@5430}
key = "param"
value = {String[2]@5430}
0 = "{"outputCol":"\"features\"","selectedCols":"[\"sepal_length\",\"sepal_width\",\"petal_length\",\"petal_width\"]"}"
1 = "{"vectorCol":"\"features\"","maxIter":"100","reservedCols":"[\"category\"]","k":"3","predictionCol":"\"prediction_result\"","predictionDetailCol":"\"prediction_detail\""}"
"clazz" -> {String[2]@5429}
key = "clazz"
value = {String[2]@5429}
0 = "com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler"
1 = "com.alibaba.alink.pipeline.clustering.KMeansModel"
BatchOperator packed = new MemSourceBatchOp(Collections.singletonList(row), PIPELINE_MODEL_SCHEMA)
.setMLEnvironmentId(transformers.length > 0 ? transformers[0].getMLEnvironmentId() :
MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID);
for (int i = 0; i < numTransformers; i++) {
BatchOperator data = null;
final long envId = transformers[i].getMLEnvironmentId();
if (transformers[i] instanceof PipelineModel) {
data = packTransformersArray(((PipelineModel) transformers[i]).transformers);
} else if (transformers[i] instanceof ModelBase) {
data = BatchOperator.fromTable(((ModelBase) transformers[i]).getModelData())
.setMLEnvironmentId(envId);
data = data.link(new VectorSerializeBatchOp().setMLEnvironmentId(envId));
}
if (data != null) {
                // This corresponds to the 1 in the model file. There is no 0 row because
                // VectorAssembler has no model data of its own, so it is skipped.
packed = new UnionAllBatchOp().setMLEnvironmentId(envId).linkFrom(packed, packBatchOp(data, i));
}
}
return packed;
}
}
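If you want to see this layout for yourself, the saved file can be inspected with a few lines of plain Java. The sketch below is a throwaway helper, not Alink code; it naively splits each line at the first comma and ignores CSV quoting, which is crude but good enough for a peek:
import java.nio.file.Files;
import java.nio.file.Paths;

public class PeekModelFile {
    public static void main(String[] args) throws Exception {
        for (String line : Files.readAllLines(Paths.get("./kmeans.csv"))) {
            int comma = line.indexOf(',');
            String index = line.substring(0, comma);   // -1 = config row, i >= 0 = data rows of transformer i
            String payload = line.substring(comma + 1);
            System.out.println("transformer index " + index + " -> "
                + payload.substring(0, Math.min(60, payload.length())) + "...");
        }
    }
}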
0x05 Loading the Model
The following code loads the model and then runs a transformation:
BatchOperator data = new CsvSourceBatchOp().setFilePath(URL).setSchemaStr(SCHEMA_STR);
PipelineModel pipeline = PipelineModel.load("./kmeans.csv");
pipeline.transform(data).print();
It reads the model file and converts it into a PipelineModel:
public class PipelineModel extends ModelBase<PipelineModel> implements LocalPredictable {
//Load the pipeline model from a path.
public static PipelineModel load(String path) {
return load(new CsvSourceBatchOp(path, PIPELINE_MODEL_SCHEMA));
}
//Load the pipeline model from a BatchOperator.
public static PipelineModel load(BatchOperator batchOp) {
return new PipelineModel(ModelExporterUtils.unpackTransformersArray(batchOp));
}
public PipelineModel(TransformerBase[] transformers) {
super(null);
if (null == transformers) {
this.transformers = new TransformerBase[]{};
} else {
List<TransformerBase> flattened = new ArrayList<>();
flattenTransformers(transformers, flattened);
this.transformers = flattened.toArray(new TransformerBase[0]);
}
}
}
// The relevant call stack:
unpackTransformersArray:91, ModelExporterUtils (com.alibaba.alink.pipeline)
load:149, PipelineModel (com.alibaba.alink.pipeline)
load:142, PipelineModel (com.alibaba.alink.pipeline)
main:22, KMeansExample2 (com.alibaba.alink)
The following utility class handles exporting and importing, e.g. packing and unpacking transformers. Its behavior is roughly:
- Read the config info from the row whose index is -1.
- Extract the algorithm classes, parameters, schemas, etc. from the config.
- Instantiate every transformer from its algorithm class.
- Each time a new transformer is created, read the corresponding rows from the file, unpack them into the model data, and assign that data to the transformer. Note that the parsed data is wrapped as a BatchOperator.
class ModelExporterUtils {
// Unpack transformers array from a BatchOperator.
static TransformerBase[] unpackTransformersArray(BatchOperator batchOp) {
String configStr;
try {
// Read the config info from the row whose index is -1.
List<Row> rows = batchOp.as(new String[]{"f1", "f2"}).where("f1=-1").collect();
Preconditions.checkArgument(rows.size() == 1, "Invalid model.");
configStr = (String) rows.get(0).getField(1);
} catch (Exception e) {
throw new RuntimeException("Fail to collect model config.");
}
// Extract the algorithm classes, parameters, schemas, etc. from the config.
String[] clazzNames = JsonConverter.fromJson(JsonPath.read(configStr, "$.clazz").toString(), String[].class);
String[] params = JsonConverter.fromJson(JsonPath.read(configStr, "$.param").toString(), String[].class);
String[] schemas = JsonConverter.fromJson(JsonPath.read(configStr, "$.schema").toString(), String[].class);
// Iterate and instantiate every transformer.
int numTransformers = clazzNames.length;
TransformerBase[] transformers = new TransformerBase[numTransformers];
for (int i = 0; i < numTransformers; i++) {
try {
Class clazz = Class.forName(clazzNames[i]);
transformers[i] = (TransformerBase) clazz.getConstructor(Params.class).newInstance(
Params.fromJson(params[i])
.set(HasMLEnvironmentId.ML_ENVIRONMENT_ID, batchOp.getMLEnvironmentId()));
} catch (Exception e) {
throw new RuntimeException("Fail to re construct transformer.", e);
}
BatchOperator packed = batchOp.as(new String[]{"f1", "f2"}).where("f1=" + i);
if (transformers[i] instanceof PipelineModel) {
BatchOperator data = unpackBatchOp(packed, CsvUtil.schemaStr2Schema(schemas[i]));
transformers[i] = new PipelineModel(unpackTransformersArray(data))
.setMLEnvironmentId(batchOp.getMLEnvironmentId());
} else if (transformers[i] instanceof ModelBase) {
BatchOperator data = unpackBatchOp(packed, CsvUtil.schemaStr2Schema(schemas[i]));
// Set the model data here.
((ModelBase) transformers[i]).setModelData(data.getOutputTable());
}
}
return transformers;
}
}
The resulting transformers look like this:
transformers = {TransformerBase[2]@9340}
0 = {VectorAssembler@9383}
mapperBuilder = {VectorAssembler$lambda@9385}
params = {Params@9386} "Params {outputCol="features", selectedCols=["sepal_length","sepal_width","petal_length","petal_width"], MLEnvironmentId=0}"
1 = {KMeansModel@9384}
mapperBuilder = {KMeansModel$lambda@9388}
modelData = {TableImpl@9389} "UnnamedTable$1"
params = {Params@9390} "Params {vectorCol="features", maxIter=100, reservedCols=["category"], k=3, MLEnvironmentId=0, predictionCol="prediction_result", predictionDetailCol="prediction_detail"}"
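The reconstruction above hinges on a plain Java reflection pattern: look up the class by name, grab the constructor that takes the parameters, and instantiate it. Here is a minimal standalone sketch of that pattern; the Greeter class is hypothetical, standing in for a TransformerBase subclass with its Params constructor:
import java.lang.reflect.Constructor;

public class ReflectionDemo {
    // Hypothetical stand-in for a TransformerBase subclass whose constructor takes Params.
    public static class Greeter {
        private final String name;
        public Greeter(String name) { this.name = name; }
        public String greet() { return "hello " + name; }
    }

    public static void main(String[] args) throws Exception {
        // Same pattern as: (TransformerBase) clazz.getConstructor(Params.class).newInstance(params)
        Class<?> clazz = Class.forName("ReflectionDemo$Greeter");
        Constructor<?> ctor = clazz.getConstructor(String.class);
        Greeter greeter = (Greeter) ctor.newInstance("Alink");
        System.out.println(greeter.greet()); // prints "hello Alink"
    }
}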
0x06 Prediction
The prediction code is:
pipeline.transform(data).print();
6.1 Generating the Runtime Wrapper
The prediction algorithm has to be wrapped into a RichMapFunction before Flink can run it.
VectorAssembler converts the input columns; KMeansModel performs the prediction. At prediction time KMeansModel.transform is invoked, which in turn calls linkFrom, where the runtime wrapper is generated.
public abstract class MapModel<T extends MapModel<T>>
extends ModelBase<T> implements LocalPredictable {
@Override
public BatchOperator transform(BatchOperator input) {
return new ModelMapBatchOp(this.mapperBuilder, this.params)
.linkFrom(BatchOperator.fromTable(this.getModelData())
.setMLEnvironmentId(input.getMLEnvironmentId()), input);
}
}
// this.getModelData() is the model data, corresponding to linkFrom's input inputs[0].
// input is the data to be processed, corresponding to linkFrom's input inputs[1].
// The model data is exactly what was read from the CSV file and set earlier.
public abstract class ModelBase<M extends ModelBase<M>> extends TransformerBase<M>
implements Model<M> {
public Table getModelData() {
return this.modelData;
}
}
In ModelMapBatchOp.linkFrom, a ModelMapperAdapter is created and the model rows are stashed as a broadcast variable, so that prediction can first load the model data (a standalone sketch of the broadcast pattern follows the code below).
public class ModelMapBatchOp<T extends ModelMapBatchOp<T>> extends BatchOperator<T> {
private static final String BROADCAST_MODEL_TABLE_NAME = "broadcastModelTable";
// (modelScheme, dataSchema, params) -> ModelMapper
private final TriFunction<TableSchema, TableSchema, Params, ModelMapper> mapperBuilder;
public ModelMapBatchOp(TriFunction<TableSchema, TableSchema, Params, ModelMapper> mapperBuilder, Params params) {
super(params);
this.mapperBuilder = mapperBuilder;
}
@Override
public T linkFrom(BatchOperator<?>... inputs) {
BroadcastVariableModelSource modelSource = new BroadcastVariableModelSource(BROADCAST_MODEL_TABLE_NAME);
ModelMapper mapper = this.mapperBuilder.apply(
inputs[0].getSchema(),
inputs[1].getSchema(),
this.getParams());
DataSet<Row> modelRows = inputs[0].getDataSet().rebalance();
        // Broadcast the model rows here.
DataSet<Row> resultRows = inputs[1].getDataSet()
.map(new ModelMapperAdapter(mapper, modelSource))
.withBroadcastSet(modelRows, BROADCAST_MODEL_TABLE_NAME);
TableSchema outputSchema = mapper.getOutputSchema();
this.setOutput(resultRows, outputSchema);
return (T) this;
}
}
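The broadcast mechanism here is plain Flink DataSet API, nothing Alink-specific. Below is a minimal self-contained sketch of the same pattern outside Alink: a "model" DataSet is attached to a map operator and fetched once in open(), exactly as ModelMapperAdapter does. The class name and the toy arithmetic are made up for illustration:
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.configuration.Configuration;

import java.util.List;

public class BroadcastDemo {
    public static void main(String[] args) throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        DataSet<Integer> model = env.fromElements(10);      // stands in for the model rows
        DataSet<Integer> data = env.fromElements(1, 2, 3);  // stands in for the input rows

        data.map(new RichMapFunction<Integer, Integer>() {
                private int offset;

                @Override
                public void open(Configuration parameters) {
                    // Same idea as ModelMapperAdapter.open(): fetch the broadcast "model" once.
                    List<Integer> rows = getRuntimeContext().getBroadcastVariable("model");
                    offset = rows.get(0);
                }

                @Override
                public Integer map(Integer value) {
                    return value + offset; // "predict" with the model
                }
            })
            .withBroadcastSet(model, "model")
            .print(); // 11, 12, 13
    }
}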
6.2 Loading the Model
At prediction time, ModelMapperAdapter first loads the model in its open function:
public class ModelMapperAdapter extends RichMapFunction<Row, Row> implements Serializable {
@Override
public void open(Configuration parameters) throws Exception {
List<Row> modelRows = this.modelSource.getModelRows(getRuntimeContext());
this.mapper.loadModel(modelRows);
}
}
// An example of the loaded model rows:
modelRows = {ArrayList@10100} size = 4
0 = {Row@10103} "2097152,{"clusterId":1,"weight":62.0,"vec":{"data":[5.901612903225806,2.7483870967741932,4.393548387096773,1.4338709677419355]}}"
1 = {Row@10104} "0,{"vectorCol":"\"features\"","latitudeCol":null,"longitudeCol":null,"distanceType":"\"EUCLIDEAN\"","k":"3","vectorSize":"4"}"
2 = {Row@10105} "3145728,{"clusterId":2,"weight":50.0,"vec":{"data":[5.005999999999999,3.418,1.4639999999999997,0.24400000000000002]}}"
3 = {Row@10106} "1048576,{"clusterId":0,"weight":38.0,"vec":{"data":[6.85,3.0736842105263156,5.742105263157894,2.0710526315789477]}}"
this.mapper.loadModel(modelRows) invokes KMeansModelMapper.loadModel, which in the end calls:
- ModelConverterUtils.extractModelMetaAndData to deserialize, turning the model rows back into a tuple;
- KMeansUtil.loadModelForTrain to build the KMeansTrainModelData used for prediction.
/**
 * The abstract class for a kind of {@link ModelDataConverter} where the model data can serialize to "Tuple2<Params, Iterable<String>>". Here "Params" is the meta data of the model, and "Iterable<String>" is the concrete data of the model.
*/
public abstract class SimpleModelDataConverter<M1, M2> implements ModelDataConverter<M1, M2> {
@Override
public M2 load(List<Row> rows) {
Tuple2<Params, Iterable<String>> metaAndData = ModelConverterUtils.extractModelMetaAndData(rows);
return deserializeModel(metaAndData.f0, metaAndData.f1);
}
}
metaAndData = {Tuple2@10267} "(Params {vectorCol="features", latitudeCol=null, longitudeCol=null, distanceType="EUCLIDEAN", k=3, vectorSize=4},com.alibaba.alink.common.model.ModelConverterUtils$StringDataIterable@7e9c1b42)"
f0 = {Params@10252} "Params {vectorCol="features", latitudeCol=null, longitudeCol=null, distanceType="EUCLIDEAN", k=3, vectorSize=4}"
params = {HashMap@10273} size = 6
"vectorCol" -> ""features""
"latitudeCol" -> null
"longitudeCol" -> null
"distanceType" -> ""EUCLIDEAN""
"k" -> "3"
"vectorSize" -> "4"
f1 = {ModelConverterUtils$StringDataIterable@10262}
iterator = {ModelConverterUtils$StringDataIterator@10272}
modelRows = {ArrayList@10043} size = 4
order = {Integer[4]@10388}
curr = "{"clusterId":0,"weight":38.0,"vec":{"data":[6.85,3.0736842105263156,5.742105263157894,2.0710526315789477]}}"
listPos = 2
As you can see, getModelRows simply reads the rows back from the broadcast variable:
public class BroadcastVariableModelSource implements ModelSource {
public List<Row> getModelRows(RuntimeContext runtimeContext) {
return runtimeContext.getBroadcastVariable(modelVariableName);
}
}
6.3 預測
最後預測是在ModelMapperAdapter的map函式。這實際上是 flink根據使用者程式碼生成的執行計劃進行相應處理後自己執行的。
/**
* Adapt a {@link ModelMapper} to run within flink.
* <p>
 * This adapter class holds the target {@link ModelMapper} and its {@link ModelSource}. Upon open(),
* it will load model rows from {@link ModelSource} into {@link ModelMapper}.
*/
public class ModelMapperAdapter extends RichMapFunction<Row, Row> implements Serializable {
@Override
public Row map(Row row) throws Exception {
return this.mapper.map(row);
}
}
The mapper here is actually KMeansModelMapper, and this is where the model data gets used:
// Find the closest cluster center for every point.
public class KMeansModelMapper extends ModelMapper {
@Override
public Row map(Row row){
Vector record = KMeansUtil.getKMeansPredictVector(colIdx, row);
......
if(isPredDetail){
double[] probs = KMeansUtil.getProbArrayFromDistanceArray(clusterDistances);
DenseVector vec = new DenseVector(probs.length);
for(int i = 0; i < this.modelData.params.k; i++){
                // The model data is used here for the prediction.
vec.set((int)this.modelData.getClusterId(i), probs[i]);
}
res.add(vec.toString());
}
return outputColsHelper.getResultRow(row, Row.of(res.toArray(new Object[0])));
}
}
// The model data:
this = {KMeansModelMapper@10822}
modelData = {KMeansPredictModelData@10828}
centroids = {FastDistanceMatrixData@10842}
vectors = {DenseMatrix@10843} "mat[4,3]:\n 5.006,6.85,5.901612903225807\n 3.418,3.0736842105263156,2.7483870967741937\n 1.4639999999999997,5.742105263157894,4.393548387096774\n 0.24400000000000002,2.0710526315789473,1.4338709677419355\n"
label = {DenseMatrix@10844} "mat[1,3]:\n 38.945592000000005,93.63106648199445,63.74191987513008\n"
rows = {Row[3]@10845}
params = {KMeansTrainModelData$ParamSummary@10829}
k = 3
vectorSize = 4
distanceType = {DistanceType@10849} "EUCLIDEAN"
vectorColName = "features"
latitudeColName = null
longtitudeColName = null
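"Find the closest cluster center for every point" boils down to an argmin over distances. Here is a minimal plain-Java sketch of that core logic, using squared Euclidean distance to match the model's EUCLIDEAN distanceType; it is only an illustration, not the FastDistance-optimized code Alink actually runs:
public class NearestCentroidDemo {
    // Squared Euclidean distance is enough for the argmin; no sqrt needed.
    static int nearestCluster(double[] point, double[][] centroids) {
        int best = -1;
        double bestDist = Double.MAX_VALUE;
        for (int c = 0; c < centroids.length; c++) {
            double d = 0;
            for (int i = 0; i < point.length; i++) {
                double diff = point[i] - centroids[c][i];
                d += diff * diff;
            }
            if (d < bestDist) { bestDist = d; best = c; }
        }
        return best;
    }

    public static void main(String[] args) {
        double[][] centroids = {            // centers rounded from the model file
            {5.006, 3.418, 1.464, 0.244},
            {5.884, 2.741, 4.389, 1.434},
            {6.854, 3.077, 5.715, 2.054}};
        // The setosa-like sample from the article lands in the first cluster.
        System.out.println(nearestCluster(new double[]{5.0, 3.2, 1.2, 0.2}, centroids)); // 0
    }
}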
0x07 Streaming Prediction
We know that Alink supports both batch and streaming prediction; let's see how streaming prediction works. Below is streaming prediction for KMeans:
public class KMeansExampleStream {
AlgoOperator getData(boolean isBatch) {
Row[] array = new Row[] {
Row.of(0, "0 0 0"),
Row.of(1, "0.1,0.1,0.1"),
Row.of(2, "0.2,0.2,0.2"),
Row.of(3, "9 9 9"),
Row.of(4, "9.1 9.1 9.1"),
Row.of(5, "9.2 9.2 9.2")
};
if (isBatch) {
return new MemSourceBatchOp(
Arrays.asList(array), new String[] {"id", "vec"});
} else {
return new MemSourceStreamOp(
Arrays.asList(array), new String[] {"id", "vec"});
}
}
public static void main(String[] args) throws Exception {
KMeansExampleStream ks = new KMeansExampleStream();
BatchOperator inOp1 = (BatchOperator)ks.getData(true);
StreamOperator inOp2 = (StreamOperator)ks.getData(false);
KMeansTrainBatchOp trainBatch = new KMeansTrainBatchOp().setVectorCol("vec").setK(2);
KMeansPredictBatchOp predictBatch = new KMeansPredictBatchOp().setPredictionCol("pred");
trainBatch.linkFrom(inOp1);
KMeansPredictStreamOp predictStream = new KMeansPredictStreamOp(trainBatch).setPredictionCol("pred");
predictStream.linkFrom(inOp2);
predictStream.print(-1,5);
StreamOperator.execute();
}
}
predictStream.linkFrom is the key point here; it leads into ModelMapStreamOp. ModelMapStreamOp contains little code, but it is very clearly organized and well worth studying.
- First, the inheritance relationship is: KMeansPredictStreamOp extends ModelMapStreamOp.
- Second, the model that streaming prediction depends on is still one produced by batch processing: BatchOperator model.
- mapperBuilder builds the business model operator; it is constructed from (modelSchema, dataSchema, params), which are exactly the key ingredients of machine learning.
- KMeansModelMapper is the concrete model operator: KMeansModelMapper extends ModelMapper.
// Find the closest cluster center for every point.
public final class KMeansPredictStreamOp extends ModelMapStreamOp <KMeansPredictStreamOp>
implements KMeansPredictParams <KMeansPredictStreamOp> {
// @param model trained from kMeansBatchOp
public KMeansPredictStreamOp(BatchOperator model) {
this(model, new Params());
}
public KMeansPredictStreamOp(BatchOperator model, Params params) {
super(model, KMeansModelMapper::new, params);
}
}
Digging deeper into the code, we can see:
- First, all of the model DataSet's data is fetched in one go. Pulling everything out can strain memory, which is why the DataSet.collect Javadoc warns: "Convenience method to get the elements of a DataSet as a List. As DataSet can contain a lot of data, this method should be used with caution."
- Second, the business model KMeansModelMapper is built via this.mapperBuilder.apply(modelSchema, in.getSchema(), this.getParams()).
- Then, new ModelMapperAdapter(mapper, modelSource) creates a RichFunction that serves as the runtime adaptation layer.
- Finally, the streaming input in is predicted through in.getDataStream().map(new ModelMapperAdapter(mapper, modelSource)).
- Note that at this point only the stream graph is generated; the actual computation happens later, when Flink processes the graph.
public class ModelMapStreamOp<T extends ModelMapStreamOp <T>> extends StreamOperator<T> {
private final BatchOperator model;
// (modelScheme, dataSchema, params) -> ModelMapper
private final TriFunction<TableSchema, TableSchema, Params, ModelMapper> mapperBuilder;
public ModelMapStreamOp(BatchOperator model,
TriFunction<TableSchema, TableSchema, Params, ModelMapper> mapperBuilder,
Params params) {
super(params);
this.model = model;
this.mapperBuilder = mapperBuilder;
}
@Override
public T linkFrom(StreamOperator<?>... inputs) {
StreamOperator<?> in = checkAndGetFirst(inputs);
TableSchema modelSchema = this.model.getSchema();
try {
            // Fetch all of the model data.
DataBridge modelDataBridge = DirectReader.collect(model);
DataBridgeModelSource modelSource = new DataBridgeModelSource(modelDataBridge);
ModelMapper mapper = this.mapperBuilder.apply(modelSchema, in.getSchema(), this.getParams());
            // Generate the runtime adapter and the prediction operator, and return the prediction result.
            // At this point only the stream graph is generated; the actual computation happens later, when Flink processes the graph.
DataStream <Row> resultRows = in.getDataStream().map(new ModelMapperAdapter(mapper, modelSource));
TableSchema resultSchema = mapper.getOutputSchema();
this.setOutput(resultRows, resultSchema);
return (T) this;
} catch (Exception ex) {
throw new RuntimeException(ex);
}
}
}
0x08 Summary
We have now traced the full life cycle of an Alink model. Let's bring out the model file one more time to verify:
- The first row is the meta data, containing the schemas, the algorithm class names, and the meta parameters. From these, Alink can instantiate the pipeline's transformers.
- The following rows hold the model data the algorithm classes need; each such row belongs to one algorithm class, and Alink reads the data out and sets it on the corresponding transformer.
- The content of those data rows is specific to each algorithm.
- The first row is special in that its index is -1. The data rows' indices start from 0; if a transformer has no data of its own, it gets no rows and its index is skipped.
With this, Alink can reconstruct the pipeline model from the model file.
-1,"{""schema"":["""",""model_id BIGINT,model_info VARCHAR""],""param"":[""{\""outputCol\"":\""\\\""features\\\""\"",\""selectedCols\"":\""[\\\""sepal_length\\\"",\\\""sepal_width\\\"",\\\""petal_length\\\"",\\\""petal_width\\\""]\""}"",""{\""vectorCol\"":\""\\\""features\\\""\"",\""maxIter\"":\""100\"",\""reservedCols\"":\""[\\\""category\\\""]\"",\""k\"":\""3\"",\""predictionCol\"":\""\\\""prediction_result\\\""\"",\""predictionDetailCol\"":\""\\\""prediction_detail\\\""\""}""],""clazz"":[""com.alibaba.alink.pipeline.dataproc.vector.VectorAssembler"",""com.alibaba.alink.pipeline.clustering.KMeansModel""]}"
1,"0^{""vectorCol"":""\""features\"""",""latitudeCol"":null,""longitudeCol"":null,""distanceType"":""\""EUCLIDEAN\"""",""k"":""3"",""vectorSize"":""4""}"
1,"1048576^{""clusterId"":0,""weight"":39.0,""vec"":{""data"":[6.8538461538461535,3.0769230769230766,5.7153846153846155,2.0538461538461545]}}"
1,"2097152^{""clusterId"":1,""weight"":61.0,""vec"":{""data"":[5.883606557377049,2.740983606557377,4.388524590163936,1.4344262295081969]}}"
1,"3145728^{""clusterId"":2,""weight"":50.0,""vec"":{""data"":[5.006,3.418,1.4640000000000002,0.24400000000000005]}}"