Alink漫談(十三) :線上學習演算法FTRL 之 具體實現
0x00 摘要
Alink 是阿里巴巴基於實時計算引擎 Flink 研發的新一代機器學習演算法平臺,是業界首個同時支援批式演算法、流式演算法的機器學習平臺。本文和上文一起介紹了線上學習演算法 FTRL 在Alink中是如何實現的,希望對大家有所幫助。
0x01 回顧
書接上回 Alink漫談(十二) :線上學習演算法FTRL 之 整體設計 。到目前為止,已經處理完畢輸入,接下來就是線上訓練。訓練優化的主要目標是找到一個方向,引數朝這個方向移動之後使得損失函式的值能夠減小,這個方向往往由一階偏導或者二階偏導各種組合求得。
為了讓大家更好理解,我們再次貼出整體流程圖:
0x02 線上訓練
線上訓練主要邏輯是:
- 1)載入初始化模型到 dataBridge;dataBridge = DirectReader.collect(model);
- 2)獲取相關引數。比如vectorSize預設是30000,是否 hasInterceptItem;
- 3)獲取切分資訊。splitInfo = getSplitInfo(featureSize, hasInterceptItem, parallelism); 下面馬上會用到。
- 4)切分高維向量。初始化資料做了特徵雜湊,會產生高維向量,這裡需要進行切割。 initData.flatMap(new SplitVector(splitInfo, hasInterceptItem, vectorSize,vectorTrainIdx, featureIdx, labelIdx));
- 5)構建一個 IterativeStream.ConnectedIterativeStreams iteration,這樣會構建(或者說連線)兩個資料流:反饋流和訓練流;
- 6)用iteration來構建迭代體 iterativeBody,其包括兩部分:CalcTask,ReduceTask;
- 6.1)CalcTask分成兩個部分。flatMap1 是分佈計算FTRL迭代需要的predict,flatMap2 是FTRL的更新引數部分;
- 6.2)ReduceTask分為兩個功能:“歸併這些predict計算結果“ / ”如果滿足條件則歸併模型 & 向下遊運算元輸出模型“;
- 7)result = iterativeBody.filter;基本是以時間間隔為標準來判斷(也可以認為是時間驅動),"時間未過期&向量有意義" 的資料將被髮送回反饋資料流,繼續迭代,回到步驟 6),進入flatMap2;
- 8)output = iterativeBody.filter;符合標準(時間過期了)的資料將跳出迭代,然後演算法會呼叫WriteModel將LineModelData轉換為多條Row,轉發給下游operator(也就是線上預測階段);即定時把模型更新給線上預測階段。
2.1 預置模型
前面說到,FTRL先要訓練出一個邏輯迴歸模型作為FTRL演算法的初始模型,這是為了系統冷啟動的需要。
2.1.1 訓練模型
具體邏輯迴歸模型設定/訓練是 :
// train initial batch model
LogisticRegressionTrainBatchOp lr = new LogisticRegressionTrainBatchOp()
.setVectorCol(vecColName)
.setLabelCol(labelColName)
.setWithIntercept(true)
.setMaxIter(10);
BatchOperator<?> initModel = featurePipelineModel.transform(trainBatchData).link(lr);
訓練好之後,模型資訊是DataSet
2.1.2 載入模型
FtrlTrainStreamOp將initModel作為初始化引數。
FtrlTrainStreamOp model = new FtrlTrainStreamOp(initModel)
在FtrlTrainStreamOp建構函式中會載入這個模型;
dataBridge = DirectReader.collect(initModel);
具體載入時通過MemoryDataBridge直接獲取初始化模型DataSet中的資料。
public MemoryDataBridge generate(BatchOperator batchOperator, Params globalParams) {
return new MemoryDataBridge(batchOperator.collect());
}
2.2 分割高維向量
從前文可知,Alink的FTRL演算法設定的特徵向量維度是30000。所以演算法第一步就是切分高維度向量,以便分散式計算。
String vecColName = "vec";
int numHashFeatures = 30000;
首先要獲取切分資訊,程式碼如下,就是將特徵數目featureSize 除以 並行度parallelism,然後得到了每個task對應係數的初始位置。
private static int[] getSplitInfo(int featureSize, boolean hasInterceptItem, int parallelism) {
int coefSize = (hasInterceptItem) ? featureSize + 1 : featureSize;
int subSize = coefSize / parallelism;
int[] poses = new int[parallelism + 1];
int offset = coefSize % parallelism;
for (int i = 0; i < offset; ++i) {
poses[i + 1] = poses[i] + subSize + 1;
}
for (int i = offset; i < parallelism; ++i) {
poses[i + 1] = poses[i] + subSize;
}
return poses;
}
//程式執行時變數如下
featureSize = 30000
hasInterceptItem = true
parallelism = 4
coefSize = 30001
subSize = 7500
poses = {int[5]@11660}
0 = 0
1 = 7501
2 = 15001
3 = 22501
4 = 30001
offset = 1
然後根據切分資訊對高維向量進行切割。
// Tuple5<SampleId, taskId, numSubVec, SubVec, label>
DataStream<Tuple5<Long, Integer, Integer, Vector, Object>> input
= initData.flatMap(new SplitVector(splitInfo, hasInterceptItem, vectorSize,
vectorTrainIdx, featureIdx, labelIdx))
.partitionCustom(new CustomBlockPartitioner(), 1);
具體切分在SplitVector.flatMap函式完成,結果就是把一個高維度向量分割給各個CalcTask。
程式碼摘要如下:
public void flatMap(Row row, Collector<Tuple5<Long, Integer, Integer, Vector, Object>> collector) throws Exception {
long sampleId = counter;
counter += parallelism;
Vector vec;
if (vectorTrainIdx == -1) {
.....
} else {
// 輸入row的第vectorTrainIdx個field就是那個30000大小的係數向量
vec = VectorUtil.getVector(row.getField(vectorTrainIdx));
}
if (vec instanceof SparseVector) {
Map<Integer, Vector> tmpVec = new HashMap<>();
for (int i = 0; i < indices.length; ++i) {
.....
// 此處迭代完成後,tmpVec中就是task number個元素,每一個元素是分割好的係數向量。
}
for (Integer key : tmpVec.keySet()) {
//此處遍歷,給後面所有CalcTask傳送五元組資料。
collector.collect(Tuple5.of(sampleId, key, subNum, tmpVec.get(key), row.getField(labelIdx)));
}
} else {
......
}
}
}
這個Tuple5.of(sampleId, key, subNum, tmpVec.get(key), row.getField(labelIdx) )就是後面CalcTask的輸入。
2.3 迭代訓練
此處理論上有以下幾個重點:
-
預測方法:在每一輪t中,針對特徵樣本xt,以及迭代後(第一次則是給定初值)的模型引數wt,我們可以預測該樣本的標記值:pt=σ(wt,xt),其中σ(a)=1/(1+exp(−a))是一個sigmoid函式。
-
損失函式:對一個特徵樣本xt,其對應的標記為yt ∈ 0,1,則通過 logistic loss 來作為損失函式。
-
迭代公式:我們的目的是使得損失函式儘可能的小,即可以採用極大似然估計來求解引數。首先求梯度,然後使用FTRL進行迭代。
虛擬碼思路大致如下
double p = learner.predict(x); //預測
learner.updateModel(x, p, y); //更新模型
double loss = LogLossEvalutor.calLogLoss(p, y); //計算損失
evalutor.addLogLoss(loss); //更新損失
totalLoss += loss;
trainedNum += 1;
具體實施上Alink有自己的特點和調整。
2.3.1 Flink Stream迭代功能
機器學習都需要迭代訓練,Alink這裡利用了Flink Stream的迭代功能。
IterativeStream的例項是通過DataStream的iterate方法建立的˙。iterate方法存在兩個過載形式:
- 一種是無參的,表示不限定最大等待時間;
- 一種提供一個長整型maxWaitTimeMillis引數,允許使用者指定等待反饋邊的下一個輸入元素的最大時間間隔。
Alink選擇了第二種。
在建立ConnectedIterativeStreams時候,用迭代流的初始輸入作為第一個輸入流,用反饋流作為第二個輸入。
每一種資料流(DataStream)都會有與之對應的流轉換(StreamTransformation)。IterativeStream對應的轉換是FeedbackTransformation。
迭代流(IterativeStream)對應的轉換是反饋轉換(FeedbackTransformation),它表示拓撲中的一個反饋點(也即迭代頭)。一個反饋點包含一個輸入邊以及若干個反饋邊,且Flink要求每個反饋邊的並行度必須跟輸入邊的並行度一致,這一點在往該轉換中加入反饋邊時會進行校驗。
當IterativeStream物件被構造時,FeedbackTransformation的例項會被建立並傳遞給DataStream的構造方法。
迭代的關閉是通過呼叫IterativeStream的例項方法closeWith來實現的。這個函式指定了某個流將成為迭代程式的結束,並且這個流將作為輸入的第二部分(second input)被反饋回迭代。
2.3.2 迭代構建
對於Alink來說,迭代構建程式碼是:
// train data format = <sampleId, subSampleTaskId, subNum, SparseVector(subSample), label>
// feedback format = Tuple7<sampleId, subSampleTaskId, subNum, SparseVector(subSample), label, wx, timeStamps>
IterativeStream.ConnectedIterativeStreams<
Tuple5<Long, Integer, Integer, Vector, Object>,
Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>
iteration = input.iterate(Long.MAX_VALUE)
.withFeedbackType(TypeInformation
.of(new TypeHint<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>() {}));
// 即iteration是一個 IterativeStream.ConnectedIterativeStreams<...>
2.3.2.1 迭代的輸入
從程式碼和註釋可以看出,迭代的兩種輸入是:
- train data format = <sampleId, subSampleTaskId, subNum, SparseVector(subSample), label>;這種其實是訓練資料;
- Tuple7<sampleId, subSampleTaskId, subNum, SparseVector(subSample), label, wx, timeStamps>;這種其實是反饋資料,就是“迭代的反饋流”作為這個第二輸入 (second input);
2.3.2.2 迭代的反饋
反饋流的設定是通過呼叫IterativeStream的例項方法closeWith來實現的。Alink這裡是
DataStream<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>
result = iterativeBody.filter(
return (t3.f0 > 0 && t3.f2 > 0); // 這裡是省略版本程式碼
);
iteration.closeWith(result);
前面已經提到過,result filter 的判斷是 return (t3.f0 > 0 && t3.f2 > 0)
,如果滿足條件,則說明時間未過期&向量有意義,所以此時應該反饋回去,繼續訓練。
反饋流的格式是:
- Tuple7<sampleId, subSampleTaskId, subNum, SparseVector(subSample), label, wx, timeStamps>;
2.3.3 迭代體 CalcTask / ReduceTask
迭代體由兩部分構成:CalcTask / ReduceTask。
CalcTask每一個例項都擁有初始化模型dataBridge。
DataStream iterativeBody = iteration.flatMap(
new CalcTask(dataBridge, splitInfo, getParams()))
2.3.3.1 迭代初始化
迭代是由 CalcTask.open 函式開始,主要做如下幾件事
- 設定各種引數,比如
- 工作task個數,numWorkers = getRuntimeContext().getNumberOfParallelSubtasks();
- 本task的id,workerId = getRuntimeContext().getIndexOfThisSubtask();
- 讀取初始化模型
- List
modelRows = DirectReader.directRead(dataBridge);
- 把Row型別資料轉換為線性模型 LinearModelData model = new LinearModelDataConverter().load(modelRows);
- List
- 讀取本task對應的係數 coef[i - startIdx],這裡就是把整個模型切分到numWorkers這麼多的Task中,並行更新。
- 指定本task的開始時間 startTime = System.currentTimeMillis();
2.3.3.2 處理輸入資料
CalcTask.flatMap1主要實現的是FTRL演算法中的predict部分(注意,不是FTRL預測)。
解釋:pt=σ(Xt⋅w)是LR的預測函式,求出pt的唯一目的是為了求出目標函式(在LR中採用交叉熵損失函式作為目標函式)對引數w的一階導數g,gi=(pt−yt)xi。此步驟同樣適用於FTRL優化其他目標函式,唯一的不同就是求次梯度g(次梯度是左導和右導之間的集合,函式可導--左導等於右導時,次梯度就等於一階梯度)的方法不同。
函式的輸入是 "訓練輸入資料",即SplitVector.flatMap的輸出 ----> CalcCalcTask的輸入
。輸入資料是一個五元組,其格式為 train data format = <sampleId, subSampleTaskId, subNum, SparseVector(subSample), label>;
有三點需要注意:
- 是如果是第一次進入,則需要savedFristModel;
- 這裡是有輸入就處理,然後立即輸出(和flatMap2不同,flatMap2有輸入就處理,但不是立即輸出,而是當時間到期了再輸出);
- predict的實現:
((SparseVector)vec).getValues()[i] * coef[indices[i] - startIdx];
大家會說,不對!predict函式應該是 sigmoid = 1.0 / (1.0 + np.exp(-w.dot(x)))
。是的,這裡還沒有做 sigmoid 操作。當ReduceTask做了聚合之後,會把聚合好的 p 反饋回迭代體,然後在 CalcTask.flatMap2 中才會做 sigmoid 操作。
public void flatMap1(Tuple5<Long, Integer, Integer, Vector, Object> value,
Collector<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>> out) throws Exception {
if (!savedFristModel) { //第一次進入需要存模型
out.collect(Tuple7.of(-1L, 0, getRuntimeContext().getIndexOfThisSubtask(),
new DenseVector(coef), labelValues, -1.0, modelId++));
savedFristModel = true;
}
Long timeStamps = System.currentTimeMillis();
double wx = 0.0;
Long sampleId = value.f0;
Vector vec = value.f3;
if (vec instanceof SparseVector) {
int[] indices = ((SparseVector)vec).getIndices();
// 這裡就是具體的Predict
for (int i = 0; i < indices.length; ++i) {
wx += ((SparseVector)vec).getValues()[i] * coef[indices[i] - startIdx];
}
} else {
......
}
//處理了就輸出
out.collect(Tuple7.of(sampleId, value.f1, value.f2, value.f3, value.f4, wx, timeStamps));
}
2.3.3.3 歸併資料
ReduceTask.flatMap 負責歸併資料。
public static class ReduceTask extends
RichFlatMapFunction<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>,
Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>> {
private int parallelism;
private int[] poses;
private Map<Long, List<Object>> buffer;
private Map<Long, List<Tuple2<Integer, DenseVector>>> models = new HashMap<>();
}
flatMap函式大致完成如下功能,即兩種歸併:
- 為了輸出模型使用。判斷是否時間過期 if (value.f0 < 0),如果過期,則歸併模型:
- 生成一個List<Tuple2<Integer, DenseVector>> model = models.get(value.f6); 以value.f6,即時間戳為key,插入到HashMap中。
- 如果全部收集完成,則向下遊運算元輸出模型,並且從HashMap中刪除暫存的模型。
- 為了歸併predict使用。歸併每個CalcTask計算的predict,形成一個 lable y;
- 用 label y 更新 Tuple7的f5,即Tuple7<sampleId, subSampleTaskId, subNum, SparseVector(subSample), label, wx, timeStamps> 中的 label,也就是預測的 y。
- 給每個下游運算元(就是每個CalcTask了,不過是作為flatMap2的輸入)傳送這個新Tuple7;
當具體用作輸出模型使用時,其變數如下:
models = {HashMap@13258} size = 1
{Long@13456} 1 -> {ArrayList@13678} size = 1
key = {Long@13456} 1
value = {ArrayList@13678} size = 1
0 = {Tuple2@13698} "(1,0.0 -8.244533295515879E-5 0.0 -1.103997743166529E-4 0.0 -3.336931546279811E-5....."
2.3.3.4 判斷是否反饋
這個 filter result 是用來判斷是否反饋的。這裡t3.f0 是sampleId, t3.f2是subNum。
DataStream<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>
result = iterativeBody.filter(
new FilterFunction<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>() {
@Override
public boolean filter(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> t3)
throws Exception {
// if t3.f0 > 0 && t3.f2 > 0 then feedback
return (t3.f0 > 0 && t3.f2 > 0);
}
});
對於 t3.f0,有兩處程式碼會設定為負值。
-
會在savedFirstModel 這裡設定一次"-1";即
if (!savedFristModel) { out.collect(Tuple7.of(-1L, 0, getRuntimeContext().getIndexOfThisSubtask(), new DenseVector(coef), labelValues, -1.0, modelId++)); savedFristModel = true; }
-
也會在時間過期時候設定為 "-1"。
if (System.currentTimeMillis() - startTime > modelSaveTimeInterval) { startTime = System.currentTimeMillis(); out.collect(Tuple7.of(-1L, 0, getRuntimeContext().getIndexOfThisSubtask(), new DenseVector(coef), labelValues, -1.0, modelId++)); }
對於 t3.f2,如果 subNum 大於零,說明在高維向量切分時候,是得到了有意義的數值。
因此 return (t3.f0 > 0 && t3.f2 > 0)
說明時間未過期&向量有意義,所以此時應該反饋回去,繼續訓練。
2.3.3.5 判斷是否輸出模型
這裡是filter output。
value.f0 < 0
說明時間到期了,應該輸出模型。
DataStream<Row> output = iterativeBody.filter(
new FilterFunction<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>>() {
@Override
public boolean filter(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> value)
{
/* if value.f0 small than 0, then output */
return value.f0 < 0;
}
}).flatMap(new WriteModel(labelType, getVectorCol(), featureCols, hasInterceptItem));
2.3.3.6 處理反饋資料/更新引數
CalcTask.flatMap2實際完成的是FTRL演算法的其餘部分,即更新引數部分。主要邏輯如下:
- 計算時間間隔 timeInterval = System.currentTimeMillis() - value.f6;
- 正式計算predict, p = 1 / (1 + Math.exp(-p)); 即sigmoid 操作;
- 計算梯度 g = (p - label) * values[i] / Math.sqrt(timeInterval); 這裡除以了時間間隔;
- 更新引數;
- 輸入。注意,這裡是有輸入就處理,但 不是立即輸出,而是累積引數,當時間到期了再輸出,也就是做到了定期輸出模型;
在 Logistic Regression 中,sigmoid函式是σ(a) = 1 / (1 + exp(-a)) ,預估 pt = σ(xt . wt), 則 LogLoss 函式是
直接計算可以得到
具體 LR + FTRL 演算法實現如下:
@Override
public void flatMap2(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> value,
Collector<Tuple7<Long, Integer, Integer, Vector, Object, Double, Long>> out)
throws Exception {
double p = value.f5;
// 計算時間間隔
long timeInterval = System.currentTimeMillis() - value.f6;
Vector vec = value.f3;
/* eta */
// 正式計算predict,之前只是計算了一半,這裡計算後半部,即
p = 1 / (1 + Math.exp(-p));
.....
if (vec instanceof SparseVector) {
// 這裡是更新引數
int[] indices = ((SparseVector)vec).getIndices();
double[] values = ((SparseVector)vec).getValues();
for (int i = 0; i < indices.length; ++i) {
// update zParam nParam
int id = indices[i] - startIdx;
// values[i]是xi
// 下面的計算基本和Google虛擬碼一致
double g = (p - label) * values[i] / Math.sqrt(timeInterval);
double sigma = (Math.sqrt(nParam[id] + g * g) - Math.sqrt(nParam[id])) / alpha;
zParam[id] += g - sigma * coef[id];
nParam[id] += g * g;
// update model coefficient
if (Math.abs(zParam[id]) <= l1) {
coef[id] = 0.0;
} else {
coef[id] = ((zParam[id] < 0 ? -1 : 1) * l1 - zParam[id])
/ ((beta + Math.sqrt(nParam[id]) / alpha + l2));
}
}
} else {
......
}
// 當時間到期了再輸出,即做到了定期輸出模型
if (System.currentTimeMillis() - startTime > modelSaveTimeInterval) {
startTime = System.currentTimeMillis();
out.collect(Tuple7.of(-1L, 0, getRuntimeContext().getIndexOfThisSubtask(),
new DenseVector(coef), labelValues, -1.0, modelId++));
}
}
2.4 輸出模型
WriteModel 類實現了輸出模型功能,大致邏輯如下:
- 生成一個LinearModelData,用訓練好的Tuple7來填充這個 LinearModelData。其中兩個重要點:
- modelData.coefVector = (DenseVector)value.f3;
- modelData.labelValues = (Object[])value.f4;
- 把模型資料轉換成List
rows。LinearModelDataConverter().save(modelData, listCollector);
- 序列化,傳送給下游運算元。因為模型可能會很大,所以這裡打散之後分佈傳送給下游運算元。
public void flatMap(Tuple7<Long, Integer, Integer, Vector, Object, Double, Long> value, Collector<Row> out){
//輸入value變數列印如下:
value = {Tuple7@13296}
f0 = {Long@13306} -1
f1 = {Integer@13307} 0
f2 = {Integer@13308} 2
f3 = {DenseVector@13309} "-0.7383426732137565 0.0 0.0 0.0 1.5885293675862715E-4 -4.834608575902742E-5 0.0 0.0 -6.754208708318647E-5 ......"
data = {double[30001]@13314}
f4 = {Object[2]@13310}
f5 = {Double@13311} -1.0
f6 = {Long@13312} 0
//生成模型
LinearModelData modelData = new LinearModelData();
......
modelData.coefVector = (DenseVector)value.f3;
modelData.labelValues = (Object[])value.f4;
//把模型資料轉換成List<Row> rows
RowCollector listCollector = new RowCollector();
new LinearModelDataConverter().save(modelData, listCollector);
List<Row> rows = listCollector.getRows();
for (Row r : rows) {
int rowSize = r.getArity();
for (int j = 0; j < rowSize; ++j) {
.....
//序列化
}
out.collect(row);
}
iter++;
}
}
0x03 線上預測
預測功能是在 FtrlPredictStreamOp 完成的。
// ftrl predict
FtrlPredictStreamOp predictResult = new FtrlPredictStreamOp(initModel)
.setVectorCol(vecColName)
.setPredictionCol("pred")
.setReservedCols(new String[]{labelColName})
.setPredictionDetailCol("details")
.linkFrom(model, featurePipelineModel.transform(splitter.getSideOutput(0)));
從上面程式碼我們可以看到
- FtrlPredict 功能同樣需要初始模型 initModel,我們也是把邏輯迴歸模型賦予它。這樣也是為了冷啟動,即當FTRL訓練模組還沒有產生模型之前,FTRL預測模組也是可以對其輸入資料做預測的。
- model 是 FtrlTrainStreamOp 的輸出,即 FTRL 的訓練輸出。所以 WriteModel 就直接把輸出傳給了 FtrlPredict功能。
- splitter.getSideOutput(0) 這裡是前面提到的測試輸入,就是測試資料集。
linkFrom函式完成了業務邏輯,大致功能如下:
- 使用
inputs[0].getDataStream().flatMap ------> partition ----> map ----> flatMap(new CollectModel())
得到了模型 LinearModelData modelstr; - 使用 DataStream.connect 把輸入的測試資料集 和 模型 LinearModelData modelstr關聯起來,這樣每個task都擁有了線上模型 modelstr,就可以通過
flatMap(new PredictProcess(...)
進行分散式預測; - 使用 setOutputTable 和 LinearModelMapper 把預測結果輸出;
即 FTRL的預測功能有三個輸入:
- 初始模型 initModel -----> 最後被 PredictProcess.open 載入,作為冷啟動的預測模型;
- 測試資料流 -----> 被 PredictProcess.flatMap1處理,進行預測;
- FTRL訓練階段產生的模型資料流 ----> 被 PredictProcess.flatMap2 處理,進行線上模型更新;
3.1 初始化
建構函式中完成了初始化,即獲取事先訓練好的邏輯迴歸模型。
public FtrlPredictStreamOp(BatchOperator model) {
super(new Params());
if (model != null) {
dataBridge = DirectReader.collect(model);
} else {
throw new IllegalArgumentException("Ftrl algo: initial model is null. Please set a valid initial model.");
}
}
3.2 獲取線上訓練模型
CollectModel完成了 獲取線上訓練模型 功能。
其邏輯主要是:模型被分成若干塊,其中 (long)inRow.getField(1) 這裡記錄了具體有多少塊。所以 flatMap 函式會把這些塊累積起來,最後組裝成模型,統一傳送給下游運算元。
具體是通過一個 HashMap<> buffers 來完成臨時拼裝/最後組裝的。
public static class CollectModel implements FlatMapFunction<Row, LinearModelData> {
private Map<Long, List<Row>> buffers = new HashMap<>(0);
@Override
public void flatMap(Row inRow, Collector<LinearModelData> out) throws Exception {
// 輸入引數如下
inRow = {Row@13389} "0,19,0,{"hasInterceptItem":"true","vectorCol":"\"vec\"","modelName":"\"Logistic Regression\"","labelCol":null,"linearModelType":"\"LR\"","vectorSize":"30000"},null"
fields = {Object[5]@13405}
0 = {Long@13406} 0
1 = {Long@13403} 19
2 = {Long@13406} 0
3 = "{"hasInterceptItem":"true","vectorCol":"\"vec\"","modelName":"\"Logistic Regression\"","labelCol":null,"linearModelType":"\"LR\"","vectorSize":"30000"}"
"
long id = (long)inRow.getField(0);
Long nTab = (long)inRow.getField(1);
Row row = new Row(inRow.getArity() - 2);
for (int i = 0; i < row.getArity(); ++i) {
row.setField(i, inRow.getField(i + 2));
}
if (buffers.containsKey(id) && buffers.get(id).size() == nTab.intValue() - 1) {
buffers.get(id).add(row);
// 如果累積完成,則組裝成模型
LinearModelData ret = new LinearModelDataConverter().load(buffers.get(id));
buffers.get(id).clear();
// 傳送給下游運算元。
out.collect(ret);
} else {
if (buffers.containsKey(id)) {
//如果有key。則往list新增。
buffers.get(id).add(row);
} else {
// 如果沒有key,則新增list
List<Row> buffer = new ArrayList<>(0);
buffer.add(row);
buffers.put(id, buffer);
}
}
}
}
//變數類似這種
this = {FtrlPredictStreamOp$CollectModel@13388}
buffers = {HashMap@13393} size = 1
{Long@13406} 0 -> {ArrayList@13431} size = 2
key = {Long@13406} 0
value = 0
value = {ArrayList@13431} size = 2
0 = {Row@13409} "0,{"hasInterceptItem":"true","vectorCol":"\"vec\"","modelName":"\"Logistic Regression\"","labelCol":null,"linearModelType":"\"LR\"","vectorSize":"30000"},null"
1 = {Row@13471} "1048576,{"featureColNames":null,"featureColTypes":null,"coefVector":{"data":[-0.7383426732137549,0.0,0.0,0.0,1.5885293675862704E-4,-4.834608575902738E-5,0.0,0.0,-6.754208708318643E-5,-1.5904172331763155E-4,0.0,-1.315219790338925E-4,0.0,-4.994749246390495E-4,0.0,2.755456604395511E-4,-9.616429481614131E-4,-9.601054004112163E-5,0.0,-1.6679174640370486E-4,0.0,......"
3.3 線上預測
PredictProcess 完成了線上預測功能,LinearModelMapper 是具體預測實現。
public static class PredictProcess extends RichCoFlatMapFunction<Row, LinearModelData, Row> {
private LinearModelMapper predictor = null;
private String modelSchemaJson;
private String dataSchemaJson;
private Params params;
private int iter = 0;
private DataBridge dataBridge;
}
3.3.1 載入預設定模型
其建構函式獲得了 FtrlPredictStreamOp 類的 dataBridge,即事先訓練好的邏輯迴歸模型。每一個Task都擁有完整的模型。
open函式會載入邏輯迴歸模型。
public void open(Configuration parameters) throws Exception {
this.predictor = new LinearModelMapper(TableUtil.fromSchemaJson(modelSchemaJson),
TableUtil.fromSchemaJson(dataSchemaJson), this.params);
if (dataBridge != null) {
// read init model
List<Row> modelRows = DirectReader.directRead(dataBridge);
LinearModelData model = new LinearModelDataConverter().load(modelRows);
this.predictor.loadModel(model);
}
}
3.3.2 線上預測
FtrlPredictStreamOp.flatMap1 函式完成了線上預測。
public void flatMap1(Row row, Collector<Row> collector) throws Exception {
collector.collect(this.predictor.map(row));
}
呼叫棧如下:
predictWithProb:157, LinearModelMapper (com.alibaba.alink.operator.common.linear)
predictResultDetail:114, LinearModelMapper (com.alibaba.alink.operator.common.linear)
map:90, RichModelMapper (com.alibaba.alink.common.mapper)
flatMap1:174, FtrlPredictStreamOp$PredictProcess (com.alibaba.alink.operator.stream.onlinelearning)
flatMap1:143, FtrlPredictStreamOp$PredictProcess (com.alibaba.alink.operator.stream.onlinelearning)
processElement1:53, CoStreamFlatMap (org.apache.flink.streaming.api.operators.co)
processRecord1:135, StreamTwoInputProcessor (org.apache.flink.streaming.runtime.io)
具體是通過 LinearModelMapper 完成。
public abstract class RichModelMapper extends ModelMapper {
public Row map(Row row) throws Exception {
if (isPredDetail) {
// 我們的示例程式碼在這裡
Tuple2<Object, String> t2 = predictResultDetail(row);
return this.outputColsHelper.getResultRow(row, Row.of(t2.f0, t2.f1));
} else {
return this.outputColsHelper.getResultRow(row, Row.of(predictResult(row)));
}
}
}
預測程式碼如下,可以看出來使用了sigmoid。
/**
* Predict the label information with the probability of each label.
*/
public Tuple2 <Object, Double[]> predictWithProb(Vector vector) {
double dotValue = MatVecOp.dot(vector, model.coefVector);
switch (model.linearModelType) {
case LR:
case SVM:
double prob = sigmoid(dotValue);
return new Tuple2 <>(dotValue >= 0 ? model.labelValues[0] : model.labelValues[1],
new Double[] {prob, 1 - prob});
}
}
3.3.3 線上更新模型
FtrlPredictStreamOp.flatMap2 函式完成了處理線上訓練輸出的模型資料流,線上更新模型。
LinearModelData引數是由CollectModel完成載入並且傳輸出來的。
在模型載入過程中,是不能預測的,沒有看到相關保護機制。如果我疏漏請大家指出。
public void flatMap2(LinearModelData linearModel, Collector<Row> collector) throws Exception {
this.predictor.loadModel(linearModel);
}
0x04 問題解答
針對之前我們提出的問題,現在總結歸納如下:
- 訓練階段和預測階段都有預製模型以應對"冷啟動"嘛?都有預製模型;
- 訓練階段和預測階段是如何關聯起來的?用 linkFrom 直接把訓練階段和預測階段的運算元連在一起;
- 如何把訓練出來的模型傳給預測階段?訓練階段用 Flink collector.collect 把模型發給下游運算元;
- 輸出模型時候,模型過大怎麼處理?線上訓練會 模型打散 之後分佈傳送給下游運算元;
- 線上訓練的模型通過什麼機制實現更新?是定時驅動更新嘛?定時更新;
- 預測階段載入模型過程中,還可以預測嘛?有沒有機制保證這段時間內也能預測?目前沒有發現類似保護機制;
- 訓練階段中,有哪些階段用到了並行處理?訓練過程中主要是FTRL演算法的"預測predict" 和 "更新引數"兩個部分,以及傳送模型;
- 預測階段中,有哪些階段用到了並行處理?預測過程中主要是分散式接受模型和分散式預測;
- 遇到高維向量如何處理?切分開嘛?切分處理;
0xFF 參考
線上機器學習FTRL(Follow-the-regularized-Leader)演算法介紹