Alink漫談(十九) :原始碼解析 之 分位點離散化Quantile
0x00 摘要
Alink 是阿里巴巴基於實時計算引擎 Flink 研發的新一代機器學習演算法平臺,是業界首個同時支援批式演算法、流式演算法的機器學習平臺。本文將帶領大家來分析Alink中 Quantile 的實現。
因為Alink的公開資料太少,所以以下均為自行揣測,肯定會有疏漏錯誤,希望大家指出,我會隨時更新。
本文緣由是因為想分析GBDT,發現GBDT涉及到Quantile的使用,所以只能先分析Quantile 。
0x01 背景概念
1.1 離散化
離散化:就是把無限空間中有限的個體對映到有限的空間中(分箱處理)。資料離散化操作大多是針對連續資料進行的,處理之後的資料值域分佈將從連續屬性變為離散屬性。
離散化方式會影響後續資料建模和應用效果:
- 使用決策樹往往傾向於少量的離散化區間,過多的離散化將使得規則過多受到碎片區間的影響。
- 關聯規則需要對所有特徵一起離散化,關聯規則關注的是所有特徵的關聯關係,如果對每個列單獨離散化將失去整體規則性。
連續資料的離散化結果可以分為兩類:
- 一類是將連續資料劃分為特定區間的集合,例如{(0,10], (10,20], (20,50],(50,100]};
- 一類是將連續資料劃分為特定類,例如類1、類2、類3;
1.2 分位數
分位數(Quantile),亦稱分位點,是指將一個隨機變數的概率分佈範圍分為幾個等份的數值點,常用的有中位數(即二分位數)、四分位數、百分位數等。
假如有1000個數字(正數),這些數字的5%, 30%, 50%, 70%, 99%分位數分別是 [3.0,5.0,6.0,9.0,12.0],這表明
- 有5%的數字分佈在0-3.0之間
- 有25%的數字分佈在3.0-5.0之間
- 有20%的數字分佈在5.0-6.0之間
- 有20%的數字分佈在6.0-9.0之間
- 有29%的數字分佈在9.0-12.0之間
- 有1%的數字大於12.0
這就是分位數的統計學理解。
因此求解某一組數字中某個數的分位數,只需要將該組數字進行排序,然後再統計小於等於該數的個數,除以總的數字個數即可。
確定p分位數位置的兩種方法
- position = (n+1)p
- position = 1 + (n-1)p
1.3 四分位數
這裡我們用四分位數做進一步說明。
四分位數 概念:把給定的亂序數值由小到大排列並分成四等份,處於三個分割點位置的數值就是四分位數。
第1四分位數 (Q1),又稱“較小四分位數”,等於該樣本中所有數值由小到大排列後第25%的數字。
第2四分位數 (Q2),又稱“中位數”,等於該樣本中所有數值由小到大排列後第50%的數字。
第3四分位數 (Q3),又稱“較大四分位數”,等於該樣本中所有數值由小到大排列後第75%的數字。
四分位距(InterQuartile Range, IQR)= 第3四分位數與第1四分位數的差距。
0x02 示例程式碼
Alink中完成分位數功能的是QuantileDiscretizer
。QuantileDiscretizer
輸入連續的特徵列,輸出分箱的類別特徵。
- 分位點離散可以計算選定列的分位點,然後使用這些分位點進行離散化。生成選中列對應的q-quantile,其中可以所有列指定一個,也可以每一列對應一個。
- 分箱數(所需離散的數目,即分為幾段)是通過引數
numBuckets
(桶數目)來指定的。 箱的範圍是通過使用近似演算法來得到的。
本文示例程式碼如下。
public class QuantileDiscretizerExample {
public static void main(String[] args) throws Exception {
NumSeqSourceBatchOp numSeqSourceBatchOp = new NumSeqSourceBatchOp(1001, 2000, "col0"); // 就是把1001 ~ 2000 這個連續數值分段
Pipeline pipeline = new Pipeline()
.add(new QuantileDiscretizer()
.setNumBuckets(6) // 指定分箱數數目
.setSelectedCols(new String[]{"col0"}));
List<Row> result = pipeline.fit(numSeqSourceBatchOp).transform(numSeqSourceBatchOp).collect();
System.out.println(result);
}
}
輸出
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
.....
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
.....
5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5]
0x03 總體邏輯
我們首先給出總體邏輯圖例
-------------------------------- 準備階段 --------------------------------
│
│
│
┌───────────────────┐
│ getSelectedCols │ 獲取需要分位的列名字
└───────────────────┘
│
│
│
┌─────────────────────┐
│ quantileNum │ 獲取分箱數
└─────────────────────┘
│
│
│
┌──────────────────────┐
│ Preprocessing.select │ 從輸入中根據列名字select出資料
└──────────────────────┘
│
│
│
-------------------------------- 預處理階段 --------------------------------
│
│
│
┌──────────────────────┐
│ quantile │ 後續步驟 就是 計算分位數
└──────────────────────┘
│
│
│
┌────────────────────────────────┐
│ countElementsPerPartition │ 在每一個partition中獲取該分割槽的所有元素個數
└────────────────────────────────┘
│ <task id, count in this task>
│
│
┌──────────────────────┐
│ sum(1) │ 這裡對第二個引數,即"count in this task"進行累積,得出所有元素的個數
└──────────────────────┘
│
│
│
┌──────────────────────┐
│ map │ 取出所有元素個數,cnt在後續會使用
└──────────────────────┘
│
│
│
│
┌──────────────────────┐
│ missingCount │ 分割槽查詢應選的列中,有哪些資料沒有被查到,比如zeroAsMissing, null, isNaN
└──────────────────────┘
│
│
│
┌────────────────┐
│ mapPartition │ 把輸入資料Row打散,對於Row中的子元素按照Row內順序一一傳送出來
└────────────────┘
│ <idx in row, item in row>, 即<row中第幾個元素,元素>
│
│
┌──────────────┐
│ pSort │ 將flatten資料進行排序
└──────────────┘
│ 返回的是二元組
│ f0: dataset which is indexed by partition id
│ f1: dataset which has partition id and count
│
│
-------------------------------- 計算階段 --------------------------------
│
│
│
┌─────────────────┐
│ MultiQuantile │ 後續都是具體計算步驟
└─────────────────┘
│
│
│
┌─────────────────┐
│ open │ 從廣播中獲取變數,初步處理counts(排序),totalCnt,missingCounts(排序)
└─────────────────┘
│
│
│
┌─────────────────┐
│ mapPartition │ 具體計算
└─────────────────┘
│
│
│
┌─────────────────┐
│ groupBy(0) │ 依據 列idx 分組
└─────────────────┘
│
│
│
┌─────────────────┐
│ reduceGroup │ 歸併排序
└─────────────────┘
│set(Tuple2<column idx, 真實資料值>)
│
│
-------------------------------- 序列化模型 --------------------------------
│
│
│
┌──────────────┐
│ reduceGroup │ 分組歸併
└──────────────┘
│
│
│
┌─────────────────┐
│ SerializeModel │ 序列化模型
└─────────────────┘
下面圖片是為了在手機上縮放適配展示。
QuantileDiscretizerTrainBatchOp.linkFrom如下:
public QuantileDiscretizerTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
BatchOperator<?> in = checkAndGetFirst(inputs);
// 示例中設定了 .setSelectedCols(new String[]{"col0"}));, 所以這裡 quantileColNames 的數值是"col0
String[] quantileColNames = getSelectedCols();
int[] quantileNum = null;
// 示例中設定了 .setNumBuckets(6),所以這裡 quantileNum 是 quantileNum = {int[1]@2705} 0 = 6
if (getParams().contains(QuantileDiscretizerTrainParams.NUM_BUCKETS)) {
quantileNum = new int[quantileColNames.length];
Arrays.fill(quantileNum, getNumBuckets());
} else {
quantileNum = Arrays.stream(getNumBucketsArray()).mapToInt(Integer::intValue).toArray();
}
/* filter the selected column from input */
// 獲取了 選擇的列 "col0"
DataSet<Row> input = Preprocessing.select(in, quantileColNames).getDataSet();
// 計算分位數
DataSet<Row> quantile = quantile(
input, quantileNum,
getParams().get(HasRoundMode.ROUND_MODE),
getParams().get(Preprocessing.ZERO_AS_MISSING)
);
// 序列化模型
quantile = quantile.reduceGroup(
new SerializeModel(
getParams(),
quantileColNames,
TableUtil.findColTypesWithAssertAndHint(in.getSchema(), quantileColNames),
BinTypes.BinDivideType.QUANTILE
)
);
/* set output */
setOutput(quantile, new QuantileDiscretizerModelDataConverter().getModelSchema());
return this;
}
其總體邏輯如下:
- 獲取需要分位的列名字
- 獲取分箱數
- 從輸入中根據列名字select出資料
- 呼叫 quantile 計算分位數
- 呼叫 countElementsPerPartition 在每一個partition中獲取該分割槽的所有元素個數,返回<task id, count in this task>,然後 對於元素個數進行累積 sum(1) ,即"count in this task"進行累積,得出所有元素的個數 cnt;
- 分割槽查詢應選的列中,有哪些資料沒有被查到,從程式碼看,是zeroAsMissing, null, isNaN這幾種情況,然後依據 partition id 進行分組 groupBy(0) 累積求和,得到 missingCount;
- 把輸入資料Row打散,對於Row中的子元素按照Row內順序一一傳送出來,這就做到了把Row型別給flatten了, 返回flatten = <idx in row, item in row>, 即<row中第幾個元素,元素>;
- 將flatten資料進行排序,pSort是大規模分割槽排序,此時還沒有分類。pSort返回的是二元組sortedData,f0: dataset which is indexed by partition id, f1: dataset which has partition id and count;
- 呼叫 MultiQuantile ,對 sortedData.f0(f0: dataset which is indexed by partition id) 進行計算分位數;具體是分割槽計算 mapPartition:
- 累積,得到當前 task 的起始位置,即 n 個輸入資料中從哪個資料開始計算;
- 根據 taskId 從 counts 中得到了本 task 應該處理哪些資料,即資料的start,end位置;
- 把資料插入 allRows.add(value); value 可認為是 <partition id, 真實資料>;
- 呼叫 QIndex 計算分位數後設資料;quantileNum是分成幾段,q1就是每一段的大小。如果分成6段,則每一段的大小是1/6;
- 遍歷一直到分箱數,每次迴圈 呼叫 qIndex.genIndex(j) 獲取每個分箱的index。然後依據這個分箱的index從輸入資料中獲取真實資料值,這個 真實資料值 就是 真實資料的index。比如連續區域是 1001 ~ 2000,分成 6 份,則第一份呼叫 qIndex.genIndex(j) 得到 167,則根據167,獲取真實資料是 1001 + 167 = 1168,即在 1001 ~ 2000 中,第一個分位index 是 1168.
- 依據 列idx 分組,得到 set(Tuple2<column idx, 真實資料值>);
- 序列化模型
0x04 訓練
4.1 quantile
訓練是通過 quantile 完成的,大致包含以下步驟。
- 呼叫 countElementsPerPartition 在每一個partition中獲取該分割槽的所有元素個數,返回<task id, count in this task>,然後 對於元素個數進行累積 sum(1) ,即"count in this task"進行累積,得出所有元素的個數 cnt;
- 分割槽查詢應選的列中,有哪些資料沒有被查到,從程式碼看,是zeroAsMissing, null, isNaN這幾種情況,然後依據 partition id 進行分組 groupBy(0) 累積求和,得到 missingCount;
- 把輸入資料Row打散,對於Row中的子元素按照Row內順序一一傳送出來,這就做到了把Row型別給flatten了,返回flatten = <idx in row, item in row>, 即<row中第幾個元素,元素>;
- 將flatten資料進行排序,pSort是大規模分割槽排序,此時還沒有分類。pSort返回的是二元組sortedData,f0: dataset which is indexed by partition id, f1: dataset which has partition id and count;
- 呼叫 MultiQuantile ,對 sortedData.f0(f0: dataset which is indexed by partition id) 進行計算分位數。
具體如下
public static DataSet<Row> quantile(
DataSet<Row> input,
final int[] quantileNum,
final HasRoundMode.RoundMode roundMode,
final boolean zeroAsMissing) {
/* instance count of dataset */
// countElementsPerPartition 的作用是:在每一個partition中獲取該分割槽的所有元素個數,返回<task id, count in this task>。
DataSet<Long> cnt = DataSetUtils
.countElementsPerPartition(input)
.sum(1) // 這裡對第二個引數,即"count in this task"進行累積,得出所有元素的個數。
.map(new MapFunction<Tuple2<Integer, Long>, Long>() {
@Override
public Long map(Tuple2<Integer, Long> value) throws Exception {
return value.f1; // 取出所有元素個數
}
}); // cnt在後續會使用
/* missing count of columns */
// 會查詢應選的列中,有哪些資料沒有被查到,從程式碼看,是zeroAsMissing, null, isNaN這幾種情況
DataSet<Tuple2<Integer, Long>> missingCount = input
.mapPartition(new RichMapPartitionFunction<Row, Tuple2<Integer, Long>>() {
public void mapPartition(Iterable<Row> values, Collector<Tuple2<Integer, Long>> out) {
StreamSupport.stream(values.spliterator(), false)
.flatMap(x -> {
long[] counts = new long[x.getArity()];
Arrays.fill(counts, 0L);
// 如果發現有資料沒有查到,就增加counts
for (int i = 0; i < x.getArity(); ++i) {
if (x.getField(i) == null
|| (zeroAsMissing && ((Number) x.getField(i)).doubleValue() == 0.0)
|| Double.isNaN(((Number)x.getField(i)).doubleValue())) {
counts[i]++;
}
}
return IntStream.range(0, x.getArity())
.mapToObj(y -> Tuple2.of(y, counts[y]));
})
.collect(Collectors.groupingBy(
x -> x.f0,
Collectors.mapping(x -> x.f1, Collectors.reducing((a, b) -> a + b))
)
)
.entrySet()
.stream()
.map(x -> Tuple2.of(x.getKey(), x.getValue().get()))
.forEach(out::collect);
}
})
.groupBy(0) //按第一個元素分組
.reduce(new RichReduceFunction<Tuple2<Integer, Long>>() {
@Override
public Tuple2<Integer, Long> reduce(Tuple2<Integer, Long> value1, Tuple2<Integer, Long> value2) {
return Tuple2.of(value1.f0, value1.f1 + value2.f1); //累積求和
}
});
/* flatten dataset to 1d */
// 把輸入資料打散。
DataSet<PairComparable> flatten = input
.mapPartition(new RichMapPartitionFunction<Row, PairComparable>() {
PairComparable pairBuff;
public void mapPartition(Iterable<Row> values, Collector<PairComparable> out) {
for (Row value : values) { // 遍歷分割槽內所有輸入元素
for (int i = 0; i < value.getArity(); ++i) { // 如果輸入元素Row本身包含多個子元素
pairBuff.first = i; // 則對於這些子元素按照Row內順序一一傳送出來,這就做到了把Row型別給flatten了
if (value.getField(i) == null
|| (zeroAsMissing && ((Number) value.getField(i)).doubleValue() == 0.0)
|| Double.isNaN(((Number)value.getField(i)).doubleValue())) {
pairBuff.second = null;
} else {
pairBuff.second = (Number) value.getField(i);
}
out.collect(pairBuff); // 返回<idx in row, item in row>, 即<row中第幾個元素,元素>
}
}
}
});
/* sort data */
// 將flatten資料進行排序,pSort是大規模分割槽排序,此時還沒有分類
// pSort返回的是二元組,f0: dataset which is indexed by partition id, f1: dataset which has partition id and count.
Tuple2<DataSet<PairComparable>, DataSet<Tuple2<Integer, Long>>> sortedData
= SortUtilsNext.pSort(flatten);
/* calculate quantile */
return sortedData.f0 //f0: dataset which is indexed by partition id
.mapPartition(new MultiQuantile(quantileNum, roundMode))
.withBroadcastSet(sortedData.f1, "counts") //f1: dataset which has partition id and count
.withBroadcastSet(cnt, "totalCnt")
.withBroadcastSet(missingCount, "missingCounts")
.groupBy(0) // 依據 列idx 分組
.reduceGroup(new RichGroupReduceFunction<Tuple2<Integer, Number>, Row>() {
@Override
public void reduce(Iterable<Tuple2<Integer, Number>> values, Collector<Row> out) {
TreeSet<Number> set = new TreeSet<>(new Comparator<Number>() {
@Override
public int compare(Number o1, Number o2) {
return SortUtils.OBJECT_COMPARATOR.compare(o1, o2);
}
});
int id = -1;
for (Tuple2<Integer, Number> val : values) {
// Tuple2<column idx, 資料>
id = val.f0;
set.add(val.f1);
}
// runtime變數
set = {TreeSet@9379} size = 5
0 = {Long@9389} 167 // 就是第 0 列的第一段 idx
1 = {Long@9392} 333 // 就是第 0 列的第二段 idx
2 = {Long@9393} 500
3 = {Long@9394} 667
4 = {Long@9382} 833
out.collect(Row.of(id, set.toArray(new Number[0])));
}
});
}
下面會對幾個重點函式做說明。
4.2 countElementsPerPartition
countElementsPerPartition 的作用是:在每一個partition中獲取該分割槽的所有元素個數。
public static <T> DataSet<Tuple2<Integer, Long>> countElementsPerPartition(DataSet<T> input) {
return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
@Override
public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
long counter = 0;
for (T value : values) {
counter++; // 在每一個partition中獲取該分割槽的所有元素個數
}
out.collect(new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(), counter));
}
});
}
4.3 MultiQuantile
MultiQuantile用來計算具體的分位點。
open函式中會從廣播中獲取變數,初步處理counts(排序),totalCnt,missingCounts(排序)等等。
mapPartition函式則做具體計算,大致步驟如下:
- 累積,得到當前 task 的起始位置,即 n 個輸入資料中從哪個資料開始計算;
- 根據 taskId 從 counts 中得到了本 task 應該處理哪些資料,即資料的start,end位置;
- 把資料插入 allRows.add(value); value 可認為是 <partition id, 真實資料>;
- 呼叫 QIndex 計算分位數後設資料;quantileNum是分成幾段,q1就是每一段的大小。如果分成6段,則每一段的大小是1/6;
- 遍歷一直到分箱數,每次迴圈 呼叫 qIndex.genIndex(j) 獲取每個分箱的index。然後依據這個分箱的index從輸入資料中獲取真實資料值,這個 真實資料值 就是 真實資料的index。比如連續區域是 1001 ~ 2000,分成 6 份,則第一份呼叫 qIndex.genIndex(j) 得到 167,則根據167,獲取真實資料是 1001 + 167 = 1168,即在 1001 ~ 2000 中,第一個分位index 是 1168;
具體程式碼是:
public static class MultiQuantile
extends RichMapPartitionFunction<PairComparable, Tuple2<Integer, Number>> {
private List<Tuple2<Integer, Long>> counts;
private List<Tuple2<Integer, Long>> missingCounts;
private long totalCnt = 0;
private int[] quantileNum;
private HasRoundMode.RoundMode roundType;
private int taskId;
@Override
public void open(Configuration parameters) throws Exception {
// 從廣播中獲取變數,初步處理counts(排序),totalCnt,missingCounts(排序)。
// 之前設定廣播變數.withBroadcastSet(sortedData.f1, "counts"),其中 f1 的格式是: dataset which has partition id and count,所以就是用 partition id來排序
this.counts = getRuntimeContext().getBroadcastVariableWithInitializer(
"counts",
new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
@Override
public List<Tuple2<Integer, Long>> initializeBroadcastVariable(
Iterable<Tuple2<Integer, Long>> data) {
ArrayList<Tuple2<Integer, Long>> sortedData = new ArrayList<>();
for (Tuple2<Integer, Long> datum : data) {
sortedData.add(datum);
}
//排序
sortedData.sort(Comparator.comparing(o -> o.f0));
// runtime的資料如下,本機有4核,所以資料分為4個 partition,每個partition的資料分別為251,250,250,250
sortedData = {ArrayList@9347} size = 4
0 = {Tuple2@9350} "(0,251)" // partition 0, 資料個數是251
1 = {Tuple2@9351} "(1,250)"
2 = {Tuple2@9352} "(2,250)"
3 = {Tuple2@9353} "(3,250)"
return sortedData;
}
});
this.totalCnt = getRuntimeContext().getBroadcastVariableWithInitializer("totalCnt",
new BroadcastVariableInitializer<Long, Long>() {
@Override
public Long initializeBroadcastVariable(Iterable<Long> data) {
return data.iterator().next();
}
});
this.missingCounts = getRuntimeContext().getBroadcastVariableWithInitializer(
"missingCounts",
new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
@Override
public List<Tuple2<Integer, Long>> initializeBroadcastVariable(
Iterable<Tuple2<Integer, Long>> data) {
return StreamSupport.stream(data.spliterator(), false)
.sorted(Comparator.comparing(o -> o.f0))
.collect(Collectors.toList());
}
}
);
taskId = getRuntimeContext().getIndexOfThisSubtask();
// runtime的資料如下
this = {QuantileDiscretizerTrainBatchOp$MultiQuantile@9348}
counts = {ArrayList@9347} size = 4
0 = {Tuple2@9350} "(0,251)"
1 = {Tuple2@9351} "(1,250)"
2 = {Tuple2@9352} "(2,250)"
3 = {Tuple2@9353} "(3,250)"
missingCounts = {ArrayList@9375} size = 1
0 = {Tuple2@9381} "(0,0)"
totalCnt = 1001
quantileNum = {int[1]@9376}
0 = 6
roundType = {HasRoundMode$RoundMode@9377} "ROUND"
taskId = 2
}
@Override
public void mapPartition(Iterable<PairComparable> values, Collector<Tuple2<Integer, Number>> out) throws Exception {
long start = 0;
long end;
int curListIndex = -1;
int size = counts.size(); // 分成4份,所以這裡是4
for (int i = 0; i < size; ++i) {
int curId = counts.get(i).f0; // 取出輸入元素中的 partition id
if (curId == taskId) {
curListIndex = i; // 當前 task 對應哪個 partition id
break; // 到了當前task,就可以跳出了
}
start += counts.get(i).f1; // 累積,得到當前 task 的起始位置,即1000個資料中從哪個資料開始計算
}
// 根據 taskId 從counts中得到了本 task 應該處理哪些資料,即資料的start,end位置
// 本 partition 是 0,其中有251個資料
end = start + counts.get(curListIndex).f1; // end = 起始位置 + 此partition的資料個數
ArrayList<PairComparable> allRows = new ArrayList<>((int) (end - start));
for (PairComparable value : values) {
allRows.add(value); // value 可認為是 <partition id, 真實資料>
}
allRows.sort(Comparator.naturalOrder());
// runtime變數
start = 0
curListIndex = 0
size = 4
end = 251
allRows = {ArrayList@9406} size = 251
0 = {PairComparable@9408}
first = {Integer@9397} 0
second = {Long@9434} 0
1 = {PairComparable@9409}
first = {Integer@9397} 0
second = {Long@9435} 1
2 = {PairComparable@9410}
first = {Integer@9397} 0
second = {Long@9439} 2
......
// size = ((251 - 1) / 1001 - 0 / 1001) + 1 = 1
size = (int) ((end - 1) / totalCnt - start / totalCnt) + 1;
int localStart = 0;
for (int i = 0; i < size; ++i) {
int fIdx = (int) (start / totalCnt + i);
int subStart = 0;
int subEnd = (int) totalCnt;
if (i == 0) {
subStart = (int) (start % totalCnt); // 0
}
if (i == size - 1) {
subEnd = (int) (end % totalCnt == 0 ? totalCnt : end % totalCnt); // 251
}
if (totalCnt - missingCounts.get(fIdx).f1 == 0) {
localStart += subEnd - subStart;
continue;
}
QIndex qIndex = new QIndex(
totalCnt - missingCounts.get(fIdx).f1, quantileNum[fIdx], roundType);
// runtime變數
qIndex = {QuantileDiscretizerTrainBatchOp$QIndex@9548}
totalCount = 1001.0
q1 = 0.16666666666666666
roundMode = {HasRoundMode$RoundMode@9377} "ROUND"
// 遍歷,一直到分箱數。
for (int j = 1; j < quantileNum[fIdx]; ++j) {
// 獲取每個分箱的index
long index = qIndex.genIndex(j); // j = 1 ---> index = 167,就是把 1001 個分為6段,第一段終點是167
//對應本 task = 0,subStart = 0,subEnd = 251。則index = 167,直接從allRows獲取第167個,數值是 1168。因為連續區域是 1001 ~ 2000,所以第167個對應數值就是1168
//如果本 task = 1,subStart = 251,subEnd = 501。則index = 333,直接從allRows獲取第 (333 + 0 - 251)= 第 82 個,獲取其中的數值。這裡因為數值區域是 1001 ~ 2000, 所以數值是1334。
if (index >= subStart && index < subEnd) { // idx剛剛好在本分割槽的資料中
PairComparable pairComparable = allRows.get(
(int) (index + localStart - subStart)); //
// runtime變數
pairComparable = {PairComparable@9581}
first = {Integer@9507} 0 // first是column idx
second = {Long@9584} 167 // 真實資料
out.collect(Tuple2.of(pairComparable.first, pairComparable.second));
}
}
localStart += subEnd - subStart;
}
}
}
4.4 QIndex
其中 QIndex 是本文關鍵所在,就是具體計算分位數。
- 建構函式中會得倒所有元素個數,每段大小;
- genIndex函式中會具體計算,比如假設還是6段,則如果取第一段,則k=1,其index為 (1/6 * (1001 - 1) * 1) = 167
public static class QIndex {
private double totalCount;
private double q1;
private HasRoundMode.RoundMode roundMode;
public QIndex(double totalCount, int quantileNum, HasRoundMode.RoundMode type) {
this.totalCount = totalCount; // 1001,所有元素的個數
this.q1 = 1.0 / (double) quantileNum; // 1.0 / 6 = 16666666666666666。quantileNum是分成幾段,q1就是每一段的大小。如果分成6段,則每一段的大小是1/6
this.roundMode = type;
}
public long genIndex(int k) {
// 假設還是6段,則如果取第一段,則k=1,其index為 (1/6 * (1001 - 1) * 1) = 167
return roundMode.calc(this.q1 * (this.totalCount - 1.0) * (double) k);
}
}
0x05 輸出模型
輸出模型是通過 reduceGroup 呼叫 SerializeModel 來完成。
具體邏輯是:
- 先構建分箱點後設資料資訊;
- 然後序列化成模型;
// 序列化模型
quantile = quantile.reduceGroup(
new SerializeModel(
getParams(),
quantileColNames,
TableUtil.findColTypesWithAssertAndHint(in.getSchema(), quantileColNames),
BinTypes.BinDivideType.QUANTILE
)
);
SerializeModel 的具體實現是:
public static class SerializeModel implements GroupReduceFunction<Row, Row> {
private Params meta;
private String[] colNames;
private TypeInformation<?>[] colTypes;
private BinTypes.BinDivideType binDivideType;
@Override
public void reduce(Iterable<Row> values, Collector<Row> out) throws Exception {
Map<String, FeatureBorder> m = new HashMap<>();
for (Row val : values) {
int index = (int) val.getField(0);
Number[] splits = (Number[]) val.getField(1);
m.put(
colNames[index],
QuantileDiscretizerModelDataConverter.arraySplit2FeatureBorder(
colNames[index],
colTypes[index],
splits,
meta.get(QuantileDiscretizerTrainParams.LEFT_OPEN),
binDivideType
)
);
}
for (int i = 0; i < colNames.length; ++i) {
if (m.containsKey(colNames[i])) {
continue;
}
m.put(
colNames[i],
QuantileDiscretizerModelDataConverter.arraySplit2FeatureBorder(
colNames[i],
colTypes[i],
null,
meta.get(QuantileDiscretizerTrainParams.LEFT_OPEN),
binDivideType
)
);
}
QuantileDiscretizerModelDataConverter model = new QuantileDiscretizerModelDataConverter(m, meta);
model.save(model, out);
}
}
這裡用到了 FeatureBorder 類。
資料分箱是按照某種規則將資料進行分類。就像可以將水果按照大小進行分類,售賣不同的價格一樣。
FeatureBorder 就是專門為了 Featureborder for binning, discrete Featureborder and continuous Featureborder。
我們能夠看出來,該分箱對應的列名,index,各個分割點。
m = {HashMap@9380} size = 1
"col0" -> {FeatureBorder@9438} "{"binDivideType":"QUANTILE","featureName":"col0","bin":{"NORM":[{"index":0},{"index":1},{"index":2},{"index":3},{"index":4},{"index":5}],"NULL":{"index":6}},"featureType":"BIGINT","splitsArray":[1168,1334,1501,1667,1834],"isLeftOpen":true,"binCount":6}"
0x06 預測
預測是在 QuantileDiscretizerModelMapper 中完成的。
6.1 載入模型
模型資料是
model = {QuantileDiscretizerModelDataConverter@9582}
meta = {Params@9670} "Params {selectedCols=["col0"], version="v2", numBuckets=6}"
data = {HashMap@9584} size = 1
"col0" -> {FeatureBorder@9676} "{"binDivideType":"QUANTILE","featureName":"col0","bin":{"NORM":[{"index":0},{"index":1},{"index":2},{"index":3},{"index":4},{"index":5}],"NULL":{"index":6}},"featureType":"BIGINT","splitsArray":[1168,1334,1501,1667,1834],"isLeftOpen":true,"binCount":6}"
loadModel會完成載入。
@Override
public void loadModel(List<Row> modelRows) {
QuantileDiscretizerModelDataConverter model = new QuantileDiscretizerModelDataConverter();
model.load(modelRows);
for (int i = 0; i < mapperBuilder.paramsBuilder.selectedCols.length; i++) {
FeatureBorder border = model.data.get(mapperBuilder.paramsBuilder.selectedCols[i]);
List<Bin.BaseBin> norm = border.bin.normBins;
int size = norm.size();
Long maxIndex = norm.get(0).getIndex();
Long lastIndex = norm.get(size - 1).getIndex();
for (int j = 0; j < norm.size(); ++j) {
if (maxIndex < norm.get(j).getIndex()) {
maxIndex = norm.get(j).getIndex();
}
}
long maxIndexWithNull = Math.max(maxIndex, border.bin.nullBin.getIndex());
switch (mapperBuilder.paramsBuilder.handleInvalidStrategy) {
case KEEP:
mapperBuilder.vectorSize.put(i, maxIndexWithNull + 1);
break;
case SKIP:
case ERROR:
mapperBuilder.vectorSize.put(i, maxIndex + 1);
break;
default:
throw new UnsupportedOperationException("Unsupported now.");
}
if (mapperBuilder.paramsBuilder.dropLast) {
mapperBuilder.dropIndex.put(i, lastIndex);
}
mapperBuilder.discretizers[i] = createQuantileDiscretizer(border, model.meta);
}
mapperBuilder.setAssembledVectorSize();
}
載入中,最後呼叫 createQuantileDiscretizer 生成 LongQuantileDiscretizer。這就是針對Long型別的離散器。
public static class LongQuantileDiscretizer implements NumericQuantileDiscretizer {
long[] bounds;
boolean isLeftOpen;
int[] boundIndex;
int nullIndex;
boolean zeroAsMissing;
@Override
public int findIndex(Object number) {
if (number == null) {
return nullIndex;
}
long lVal = ((Number) number).longValue();
if (isMissing(lVal, zeroAsMissing)) {
return nullIndex;
}
int hit = Arrays.binarySearch(bounds, lVal);
if (isLeftOpen) {
hit = hit >= 0 ? hit - 1 : -hit - 2;
} else {
hit = hit >= 0 ? hit : -hit - 2;
}
return boundIndex[hit];
}
}
其數值如下:
this = {QuantileDiscretizerModelMapper$LongQuantileDiscretizer@9768}
bounds = {long[7]@9757}
0 = -9223372036854775807
1 = 1168
2 = 1334
3 = 1501
4 = 1667
5 = 1834
6 = 9223372036854775807
isLeftOpen = true
boundIndex = {int[7]@9743}
0 = 0 // -9223372036854775807 ~ 1168 之間對應的最終分箱離散值是 0
1 = 1
2 = 2
3 = 3
4 = 4
5 = 5
6 = 5 // 1834 ~ 9223372036854775807 之間對應的最終分箱離散值是 5
nullIndex = 6
zeroAsMissing = false
6.2 預測
預測 QuantileDiscretizerModelMapper 的 DiscretizerMapperBuilder 完成。
Row map(Row row){
// 這裡的 row 舉例是: row = {Row@9743} "1003"
for (int i = 0; i < paramsBuilder.selectedCols.length; i++) {
int colIdxInData = selectedColIndicesInData[i];
Object val = row.getField(colIdxInData);
int foundIndex = discretizers[i].findIndex(val); // 找到 1003對應的index,就是呼叫Discretizer完成,這裡找到 foundIndex 是0
predictIndices[i] = (long) foundIndex;
}
return paramsBuilder.outputColsHelper.getResultRow(
row,
setResultRow(
predictIndices,
paramsBuilder.encode,
dropIndex,
vectorSize,
paramsBuilder.dropLast,
assembledVectorSize) // 最後返回離散值是0
);
}
this = {QuantileDiscretizerModelMapper$DiscretizerMapperBuilder@9744}
paramsBuilder = {QuantileDiscretizerModelMapper$DiscretizerParamsBuilder@9752}
selectedColIndicesInData = {int[1]@9754}
vectorSize = {HashMap@9758} size = 1
dropIndex = {HashMap@9759} size = 1
assembledVectorSize = {Integer@9760} 6
discretizers = {QuantileDiscretizerModelMapper$NumericQuantileDiscretizer[1]@9761}
0 = {QuantileDiscretizerModelMapper$LongQuantileDiscretizer@9768}
bounds = {long[7]@9776}
isLeftOpen = true
boundIndex = {int[7]@9777}
nullIndex = 6
zeroAsMissing = false
predictIndices = {Long[1]@9763}
0xFF 參考
Spark QuantileDiscretizer 分位數離散器