Alink漫談(十八) :原始碼解析 之 多列字串編碼MultiStringIndexer
0x00 摘要
Alink 是阿里巴巴基於實時計算引擎 Flink 研發的新一代機器學習演算法平臺,是業界首個同時支援批式演算法、流式演算法的機器學習平臺。本文將帶領大家來分析Alink中 MultiStringIndexer 的實現。
因為Alink的公開資料太少,所以以下均為自行揣測,肯定會有疏漏錯誤,希望大家指出,我會隨時更新。
本文緣由是想分析GBDT,發現GBDT涉及到MultiStringIndexer的使用,所以只能先分析MultiStringIndexer 。
0x01 概念
Alink的官方介紹是:MultiStringIndexer訓練元件的作用是訓練一個模型用於將多列字串對映為整數。
具體來說,StringIndexer(字串-索引變換)將標籤的"字串列"編碼為"標籤索引的列"。
- 標籤索引序列的取值範圍是[0,numLabels(字串中所有出現的單詞去掉重複的詞後的總和)],按照標籤出現頻率排序,出現最多的標籤索引為0(具體為升序降序是可以配置的)。
- 如果輸入是數值型,我們先將數值對映到字串,再對字串進行索引化。
- 如果下游的pipeline(例如:Estimator或者Transformer)需要用到索引化後的標籤序列,則需要將這個pipeline的輸入列名字指定為索引化序列的名字。大部分情況下,通過setSelectedCols設定輸入的列名。
以這些輸入為例:
("football", "can"),
("football", "hhh"),
("football", "zzz"),
("basketball", "zzz"),
("basketball", "can"),
("tennis", "can")
對於第一列,MultiStringIndexer 對資料集的label進行重新編號。按label出現的頻次,轉換成0 ~ numOfLabels - 1(分類個數)。如果是按照從高到低排序,則頻次最高的轉換為0,以此類推,比如:
- football,出現次數最多,出現了3次,轉換(編號)為0
- 其次是basketball,出現了2次,編號為1,以此類推。
在應用StringIndexer對labels進行重新編號後,帶著這些編號後的label對資料進行了訓練,並接著對其他資料進行了預測,得到預測結果,預測結果的label也是重新編號過的,因此需要轉換回來。
0x02 示例程式碼
示例程式碼如下,本示例程式碼中,是按照升序排列,即football總數為3,則其idx為3,tennis個數為1,其idx為0:
public class MultiStringIndexerExample {
static AlgoOperator getData(boolean isBatch) {
Row[] array = new Row[] {
Row.of("football", "can"),
Row.of("football", "hhh"),
Row.of("football", "zzz"),
Row.of("basketball", "zzz"),
Row.of("basketball", "can"),
Row.of("tennis", "can")
};
if (isBatch) {
return new MemSourceBatchOp(
Arrays.asList(array), new String[] {"a", "b"});
} else {
return new MemSourceStreamOp(
Arrays.asList(array), new String[] {"a", "b"});
}
}
public static void main(String[] args) throws Exception {
BatchOperator data = (BatchOperator)getData(true);
MultiStringIndexer stringindexer = new MultiStringIndexer()
.setSelectedCols("a", "b")
.setOutputCols("a_indexed", "b_indexed")
.setStringOrderType("frequency_asc");
stringindexer.fit(data).transform(data).print();
}
}
輸出如下:
a|b|a_indexed|b_indexed
-|-|---------|---------
football|can|2|2
football|hhh|2|0
football|zzz|2|1
basketball|zzz|1|1
basketball|can|1|2
tennis|can|0|2
轉換成表格看的更清楚。
a | b | a_indexed | b_indexed |
---|---|---|---|
football | can | 2 | 2 |
football | hhh | 2 | 0 |
football | zzz | 2 | 1 |
basketball | zzz | 1 | 1 |
basketball | can | 1 | 2 |
tennis | can | 0 | 2 |
0x03 總體邏輯
我們先給出一個流程圖
老套路,我們從 MultiStringIndexerTrainBatchOp.linkFrom開始挖掘。
@Override
public MultiStringIndexerTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
BatchOperator<?> in = checkAndGetFirst(inputs);
// 示例中有 .setSelectedCols("a", "b"),這裡是取出具體列名字
final String[] selectedColNames = getSelectedCols();
// 獲取列的型別
final String[] selectedColSqlType = new String[selectedColNames.length];
for (int i = 0; i < selectedColNames.length; i++) {
selectedColSqlType[i] = FlinkTypeConverter.getTypeString(
TableUtil.findColTypeWithAssertAndHint(in.getSchema(), selectedColNames[i]));
}
// runtime列印資料
selectedColNames = {String[2]@2536}
0 = "a"
1 = "b"
selectedColSqlType = {String[2]@2537}
0 = "VARCHAR"
1 = "VARCHAR"
// 獲取選取列對應的資料
DataSet<Row> inputRows = in.select(selectedColNames).getDataSet();
//
DataSet<Tuple3<Integer, String, Long>> indexedToken =
StringIndexerUtil.indexTokens(inputRows, getStringOrderType(), 0L, true);
DataSet<Row> values = indexedToken
.mapPartition(new RichMapPartitionFunction<Tuple3<Integer, String, Long>, Row>() {
@Override
public void mapPartition(Iterable<Tuple3<Integer, String, Long>> values, Collector<Row> out)
throws Exception {
Params meta = null;
if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
// 第一個task會做這個計算,就是把列名,列型別作為後設資料傳送
meta = new Params().set(HasSelectedCols.SELECTED_COLS, selectedColNames)
.set(HasSelectedColTypes.SELECTED_COL_TYPES, selectedColSqlType);
}
// runtime列印資料
meta = {Params@9311} "Params {selectedCols=["a","b"], selectedColTypes=["VARCHAR","VARCHAR"]}"
params = {HashMap@9316} size = 2
new MultiStringIndexerModelDataConverter().save(Tuple2.of(meta, values), out);
}
})
.name("build_model");
this.setOutput(values, new MultiStringIndexerModelDataConverter().getModelSchema());
return this;
}
訓練過程總體邏輯總結如下:
- 取出具體列名字,列的型別;
- 獲取"選取列"對應的資料;
- 把列名,列型別作為後設資料傳送;
StringIndexerUtil.indexTokens
給各個列的不同字串賦予連續的indices。每列的 indices 彼此不相關;- 呼叫到
indexSortedByFreq(data, startIndex, ignoreNull, true)
,作用是給各個列的不同字串賦予連續的indices,indices是按照字串出現的頻率排序;- 呼叫到 countTokens的作用是按照 "列idx","word" 來合併計算單詞個數,得到<"列idx","word",單詞個數>,比如第一列中,football這個單詞的個數是3,則返回三元組是 <0,football,3>,其中列的idx從0開始計算。
- 呼叫 flattenTokens 把輸入資料 Row 給打散,返回 A DataSet of tuples of column index and token,即<"列idx","word">。比如對於 Row.of("football", "can") 這個輸入,flattenTokens 輸出兩個Tuple2 ,<0, "football"> 和 <1, "can">。
- 對上面結構進行map操作,輸出<column idx, word, 1L>,比如 <0, "football", 1L> ;
- 按照 "列idx","word" 來分組;
- 按照 "列idx","word" 來合併計算單詞個數;
- indexSortedByFreq會對countTokens返回的結果<"列idx","word",詞頻>處理;
- 首先按照 列idx 做分組;
- 然後在上面結果基礎上,按照單詞個數排序;
- 排序的index是以輸入引數startIndex開始,startIndex在這裡是0;
- 最後得到 第一列的 (0,football,0),(0,basketball,1),(0,football,2);第二列的資料 (1,hhh,0),(1,zzz,1),(1,can,2);
- 呼叫到 countTokens的作用是按照 "列idx","word" 來合併計算單詞個數,得到<"列idx","word",單詞個數>,比如第一列中,football這個單詞的個數是3,則返回三元組是 <0,football,3>,其中列的idx從0開始計算。
- 呼叫到
- 把indexTokens的結果儲存為模型,其中使用之前提到的 "把列名,列型別作為後設資料"。
下面具體剖析後兩個階段。
0x04 Add Index to Token
這部分就是給各個列的不同字串賦予連續的indices。每列的 indices 彼此不相關。
具體是由StringIndexerUtil.indexTokens 做到的。
public static DataSet<Tuple3<Integer, String, Long>> indexTokens(
DataSet<Row> data, HasStringOrderTypeDefaultAsRandom.StringOrderType orderType,
final long startIndex, final boolean ignoreNull) {
case FREQUENCY_ASC:
return indexSortedByFreq(data, startIndex, ignoreNull, true);
}
4.1 合併計算單詞個數
indexSortedByFreq會呼叫countTokens來計算單詞個數,所以我們先看countTokens。
countTokens的作用是按照 "列idx","word" 來合併計算單詞個數,比如第一列中,football這個單詞的個數是3,則返回三元組是 <0,football,3>,其中列的idx從0開始計算。
具體邏輯如下:
- 呼叫 flattenTokens 把輸入資料 Row 給打散,返回 A DataSet of tuples of column index and token,即<"列idx","word">。比如對於 Row.of("football", "can") 這個輸入,flattenTokens 輸出兩個Tuple2 ,<0, "football"> 和 <1, "can">。
- 對上面結果進行map操作,輸出<column idx, word, 1L>,比如 <0, "football", 1L> ,這個是計數的常規操作。
- 按照 "列idx","word" 來分組;
- 按照 "列idx","word" 來合併計算單詞個數,就是不停歸併上面的 1L。
4.1.1 打散輸入資料
其中 flattenTokens 的作用是把輸入資料 Row 給打散,返回 A DataSet of tuples of column index and token.。
比如對於 Row.of("football", "can") 這個輸入,flattenTokens 使用 out.collect(Tuple2.of(i, String.valueOf(o)));
輸出兩個Tuple2。
value = {Row@9212} "football,can"
fields = {Object[2]@9215}
0 = "football"
1 = "can"
輸出 <0, "football"> 和 <1, "can">
4.1.2 分組計算個數
這是通過flattenTokens的結果進行 map,groupBy,reduce的一系列操作完成的。
具體程式碼如下:
public static DataSet<Tuple3<Integer, String, Long>> countTokens(DataSet<Row> data, final boolean ignoreNull) {
return flattenTokens(data, ignoreNull) // 把輸入資料 Row 給打散
.map(new MapFunction<Tuple2<Integer, String>, Tuple3<Integer, String, Long>>() {
@Override
public Tuple3<Integer, String, Long> map(Tuple2<Integer, String> value) throws Exception {
return Tuple3.of(value.f0, value.f1, 1L); // 輸出<column idx, word, 1L>,比如 <0, "football", 1L>
}
})
.groupBy(0, 1) // 按照 "列idx","word" 來分組
.reduce(new ReduceFunction<Tuple3<Integer, String, Long>>() {
@Override
public Tuple3<Integer, String, Long> reduce(Tuple3<Integer, String, Long> value1, Tuple3<Integer, String, Long> value2) throws Exception {
value1.f2 += value2.f2;
return value1; // 按照 "列idx","word" 來合併計算單詞個數
}
})
.name("count_tokens");
}
// reduce之後發出
value1 = {Tuple3@9284} "(0,football,3)"
f0 = {Integer@9226} 0
f1 = "football"
f2 = {Long@9295} 3
4.2 合併計算單詞個數
前面 countTokens的 返回三元組是 <列idx","word" ,詞頻>,其中列的idx從0開始計算。
indexSortedByFreq會對countTokens返回的結果<"列idx","word",詞頻>處理;
- 首先按照 列idx 做分組;
- 然後在上面結果基礎上,按照單詞個數排序;
- 排序的index是以輸入引數startIndex開始,startIndex在這裡是0;
- 最後得到 第一列的 (0,tennis,0),(0,basketball,1),(0,football,2);第二列的資料 (1,hhh,0),(1,zzz,1),(1,can,2);
具體程式碼如下:
public static DataSet<Tuple3<Integer, String, Long>> indexSortedByFreq(
DataSet<Row> data, final long startIndex, final boolean ignoreNull, final boolean isAscending) {
return countTokens(data, ignoreNull)
.groupBy(0) //按照 列idx 做分組
.sortGroup(2, isAscending ? Order.ASCENDING : Order.DESCENDING) //按照單詞個數排序
.reduceGroup(new GroupReduceFunction<Tuple3<Integer, String, Long>, Tuple3<Integer, String, Long>>() {
@Override
public void reduce(Iterable<Tuple3<Integer, String, Long>> values,
Collector<Tuple3<Integer, String, Long>> out) {
long id = startIndex;
for (Tuple3<Integer, String, Long> value : values) {
out.collect(Tuple3.of(value.f0, value.f1, id++)); // 歸併
}
}
});
}
0x05 輸出模型
這部分分為兩部分:
- 輸出後設資料,就是之前得到的 "把列名,列型別作為後設資料"。
- 輸出具體每一列的每一個單詞資訊,比如 第一列的 (0,tennis,0),(0,basketball,1),(0,football,2);第二列的資料 (1,hhh,0),(1,zzz,1),(1,can,2);
public class MultiStringIndexerModelDataConverter implements
ModelDataConverter<Tuple2<Params, Iterable<Tuple3<Integer, String, Long>>>, MultiStringIndexerModelData> {
@Override
public void save(Tuple2<Params, Iterable<Tuple3<Integer, String, Long>>> modelData, Collector<Row> collector) {
if (modelData.f0 != null) {
collector.collect(Row.of(-1L, modelData.f0.toJson(), null));
}
modelData.f1.forEach(tuple -> {
collector.collect(Row.of(tuple.f0.longValue(), tuple.f1, tuple.f2));
});
}
}
tuple = {Tuple3@9405} "(0,tennis,0)"
f0 = {Integer@9406} 0
f1 = "tennis"
f2 = {Long@9408} 0
0x06 預測
預測功能是在 ModelMapperAdapter 完成的。
public class ModelMapperAdapter extends RichMapFunction<Row, Row> implements Serializable {
private final ModelMapper mapper;
private final ModelSource modelSource;
@Override
public void open(Configuration parameters) throws Exception {
List<Row> modelRows = this.modelSource.getModelRows(getRuntimeContext());
this.mapper.loadModel(modelRows); //載入模型
}
@Override
public Row map(Row row) throws Exception {
return this.mapper.map(row); //預測
}
}
6.1 載入模型
MultiStringIndexerModelDataConverter中我們會進行模型載入。
- 首先會載入元資訊
- 其次會逐條載入模型資訊
public MultiStringIndexerModelData load(List<Row> rows) {
MultiStringIndexerModelData modelData = new MultiStringIndexerModelData();
modelData.tokenAndIndex = new ArrayList<>();
modelData.tokenNumber = new HashMap<>();
for (Row row : rows) {
long colIndex = (Long) row.getField(0);
if (colIndex < 0L) { // 後設資料
modelData.meta = Params.fromJson((String) row.getField(1));
} else { // 具體模型資訊
int columnIndex = ((Long) row.getField(0)).intValue();
Long tokenIndex = Long.valueOf(String.valueOf(row.getField(2)));
modelData.tokenAndIndex.add(Tuple3.of(columnIndex, (String) row.getField(1), tokenIndex));
modelData.tokenNumber.merge(columnIndex, 1L, Long::sum); // 合併列資料個數
}
}
// To ensure that every columns has token number.
int numFields = 0;
if (modelData.meta != null) {
numFields = modelData.meta.get(HasSelectedCols.SELECTED_COLS).length;
}
for (int i = 0; i < numFields; i++) {
modelData.tokenNumber.merge(i, 0L, Long::sum);
}
return modelData;
}
最後模型內容如下,其中 tokenNumber 表示每列的資料有幾個,tokenAndIndex表示具體資訊,比如(0,tennis,0),(0,basketball,1),(0,football,2) 就表示他們都是第一列的,basketball轉換後的資料是 1:
modelData = {MultiStringIndexerModelData@9348}
meta = {Params@9440} "Params {selectedCols=["a","b"], selectedColTypes=["VARCHAR","VARCHAR"]}"
tokenAndIndex = {ArrayList@9360} size = 6
0 = {Tuple3@9472} "(0,football,2)"
1 = {Tuple3@9511} "(0,tennis,0)"
2 = {Tuple3@9512} "(1,zzz,1)"
3 = {Tuple3@9513} "(1,hhh,0)"
4 = {Tuple3@9514} "(0,basketball,1)"
5 = {Tuple3@9515} "(1,can,2)"
tokenNumber = {HashMap@9385} size = 2
{Integer@9507} 0 -> {Long@9508} 3
{Integer@9509} 1 -> {Long@9508} 3
numFields = 2
6.2 預測
預測是在 MultiStringIndexerModelMapper 完成的。
// 假設輸入是:row = {Row@9309} "football,can"
// 選擇的列是:selectedColNames = {String[2]@9314} 0 = "a" 1 = "b"
// 模型對映器是:
this = {MultiStringIndexerModelMapper@9309}
indexMapper = {HashMap@9318} size = 2
{Integer@9357} 0 -> {HashMap@9314} size = 3
key = {Integer@9357} 0
value = 0
value = {HashMap@9314} size = 3
"basketball" -> {Long@9386} 1
"football" -> {Long@9332} 2
"tennis" -> {Long@9384} 0
{Integer@9352} 1 -> {HashMap@9358} size = 3
key = {Integer@9352} 1
value = 1
value = {HashMap@9358} size = 3
"can" -> {Long@9332} 2
"hhh" -> {Long@9384} 0
"zzz" -> {Long@9386} 1
則經歷過下列程式碼,最後就可以進行預測
public Row map(Row row) throws Exception {
Row result = new Row(selectedColNames.length);
for (int i = 0; i < selectedColNames.length; i++) {
Map<String, Long> mapper = indexMapper.get(i);
int colIdxInData = selectedColIndicesInData[i];
Object val = row.getField(colIdxInData);
String key = val == null ? null : String.valueOf(val);
Long index = mapper.get(key);
if (index != null) {
result.setField(i, index); // 我們主要執行在這裡
} else {
}
}
// 最後預測結果是:
row = {Row@9308} "football,can"
result = {Row@9313} "2,2"
return outputColsHelper.getResultRow(row, result);
}