Alink漫談(十八) :原始碼解析 之 多列字串編碼MultiStringIndexer

羅西的思考發表於2020-08-15

Alink漫談(十八) :原始碼解析 之 多列字串編碼MultiStringIndexer

0x00 摘要

Alink 是阿里巴巴基於實時計算引擎 Flink 研發的新一代機器學習演算法平臺,是業界首個同時支援批式演算法、流式演算法的機器學習平臺。本文將帶領大家來分析Alink中 MultiStringIndexer 的實現。

因為Alink的公開資料太少,所以以下均為自行揣測,肯定會有疏漏錯誤,希望大家指出,我會隨時更新。

本文緣由是想分析GBDT,發現GBDT涉及到MultiStringIndexer的使用,所以只能先分析MultiStringIndexer 。

0x01 概念

Alink的官方介紹是:MultiStringIndexer訓練元件的作用是訓練一個模型用於將多列字串對映為整數。

具體來說,StringIndexer(字串-索引變換)將標籤的"字串列"編碼為"標籤索引的列"。

  • 標籤索引序列的取值範圍是[0,numLabels(字串中所有出現的單詞去掉重複的詞後的總和)],按照標籤出現頻率排序,出現最多的標籤索引為0(具體為升序降序是可以配置的)。
  • 如果輸入是數值型,我們先將數值對映到字串,再對字串進行索引化。
  • 如果下游的pipeline(例如:Estimator或者Transformer)需要用到索引化後的標籤序列,則需要將這個pipeline的輸入列名字指定為索引化序列的名字。大部分情況下,通過setSelectedCols設定輸入的列名。

以這些輸入為例:

("football", "can"),
("football", "hhh"),
("football", "zzz"),
("basketball", "zzz"),
("basketball", "can"),
("tennis", "can")

對於第一列,MultiStringIndexer 對資料集的label進行重新編號。按label出現的頻次,轉換成0 ~ numOfLabels - 1(分類個數)。如果是按照從高到低排序,則頻次最高的轉換為0,以此類推,比如:

  • football,出現次數最多,出現了3次,轉換(編號)為0
  • 其次是basketball,出現了2次,編號為1,以此類推。

在應用StringIndexer對labels進行重新編號後,帶著這些編號後的label對資料進行了訓練,並接著對其他資料進行了預測,得到預測結果,預測結果的label也是重新編號過的,因此需要轉換回來。

0x02 示例程式碼

示例程式碼如下,本示例程式碼中,是按照升序排列,即football總數為3,則其idx為3,tennis個數為1,其idx為0:

public class MultiStringIndexerExample {
    static AlgoOperator getData(boolean isBatch) {
        Row[] array = new Row[] {
                Row.of("football", "can"),
                Row.of("football", "hhh"),
                Row.of("football", "zzz"),
                Row.of("basketball", "zzz"),
                Row.of("basketball", "can"),
                Row.of("tennis", "can")
        };

        if (isBatch) {
            return new MemSourceBatchOp(
                    Arrays.asList(array), new String[] {"a", "b"});
        } else {
            return new MemSourceStreamOp(
                    Arrays.asList(array), new String[] {"a", "b"});
        }
    }

    public static void main(String[] args) throws Exception {
        BatchOperator data = (BatchOperator)getData(true);
        MultiStringIndexer stringindexer = new MultiStringIndexer()
                .setSelectedCols("a", "b")
                .setOutputCols("a_indexed", "b_indexed")
                .setStringOrderType("frequency_asc");
        stringindexer.fit(data).transform(data).print();
    }
}

輸出如下:

a|b|a_indexed|b_indexed
-|-|---------|---------
football|can|2|2
football|hhh|2|0
football|zzz|2|1
basketball|zzz|1|1
basketball|can|1|2
tennis|can|0|2

轉換成表格看的更清楚。

a b a_indexed b_indexed
football can 2 2
football hhh 2 0
football zzz 2 1
basketball zzz 1 1
basketball can 1 2
tennis can 0 2

0x03 總體邏輯

我們先給出一個流程圖

老套路,我們從 MultiStringIndexerTrainBatchOp.linkFrom開始挖掘。

@Override
public MultiStringIndexerTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
    BatchOperator<?> in = checkAndGetFirst(inputs);

    // 示例中有 .setSelectedCols("a", "b"),這裡是取出具體列名字
    final String[] selectedColNames = getSelectedCols();
    // 獲取列的型別
    final String[] selectedColSqlType = new String[selectedColNames.length];
    for (int i = 0; i < selectedColNames.length; i++) {
        selectedColSqlType[i] = FlinkTypeConverter.getTypeString(
            TableUtil.findColTypeWithAssertAndHint(in.getSchema(), selectedColNames[i]));
    }

// runtime列印資料
selectedColNames = {String[2]@2536} 
 0 = "a"
 1 = "b"
selectedColSqlType = {String[2]@2537} 
 0 = "VARCHAR"
 1 = "VARCHAR"
  
    // 獲取選取列對應的資料
    DataSet<Row> inputRows = in.select(selectedColNames).getDataSet();
    // 
    DataSet<Tuple3<Integer, String, Long>> indexedToken =
        StringIndexerUtil.indexTokens(inputRows, getStringOrderType(), 0L, true);

    DataSet<Row> values = indexedToken
        .mapPartition(new RichMapPartitionFunction<Tuple3<Integer, String, Long>, Row>() {
            @Override
            public void mapPartition(Iterable<Tuple3<Integer, String, Long>> values, Collector<Row> out)
                throws Exception {
                Params meta = null;
                if (getRuntimeContext().getIndexOfThisSubtask() == 0) {           
                    // 第一個task會做這個計算,就是把列名,列型別作為後設資料傳送
                    meta = new Params().set(HasSelectedCols.SELECTED_COLS, selectedColNames)
                        .set(HasSelectedColTypes.SELECTED_COL_TYPES, selectedColSqlType);
                }
 
// runtime列印資料              
meta = {Params@9311} "Params {selectedCols=["a","b"], selectedColTypes=["VARCHAR","VARCHAR"]}"
 params = {HashMap@9316}  size = 2              
              
                new MultiStringIndexerModelDataConverter().save(Tuple2.of(meta, values), out);
            }
        })
        .name("build_model");

    this.setOutput(values, new MultiStringIndexerModelDataConverter().getModelSchema());
    return this;
}

訓練過程總體邏輯總結如下:

  • 取出具體列名字,列的型別;
  • 獲取"選取列"對應的資料;
  • 把列名,列型別作為後設資料傳送;
  • StringIndexerUtil.indexTokens 給各個列的不同字串賦予連續的indices。每列的 indices 彼此不相關;
    • 呼叫到 indexSortedByFreq(data, startIndex, ignoreNull, true),作用是給各個列的不同字串賦予連續的indices,indices是按照字串出現的頻率排序;
      • 呼叫到 countTokens的作用是按照 "列idx","word" 來合併計算單詞個數,得到<"列idx","word",單詞個數>,比如第一列中,football這個單詞的個數是3,則返回三元組是 <0,football,3>,其中列的idx從0開始計算。
        • 呼叫 flattenTokens 把輸入資料 Row 給打散,返回 A DataSet of tuples of column index and token,即<"列idx","word">。比如對於 Row.of("football", "can") 這個輸入,flattenTokens 輸出兩個Tuple2 ,<0, "football"> 和 <1, "can">。
        • 對上面結構進行map操作,輸出<column idx, word, 1L>,比如 <0, "football", 1L> ;
        • 按照 "列idx","word" 來分組;
        • 按照 "列idx","word" 來合併計算單詞個數;
      • indexSortedByFreq會對countTokens返回的結果<"列idx","word",詞頻>處理;
        • 首先按照 列idx 做分組;
        • 然後在上面結果基礎上,按照單詞個數排序;
        • 排序的index是以輸入引數startIndex開始,startIndex在這裡是0;
        • 最後得到 第一列的 (0,football,0),(0,basketball,1),(0,football,2);第二列的資料 (1,hhh,0),(1,zzz,1),(1,can,2);
  • 把indexTokens的結果儲存為模型,其中使用之前提到的 "把列名,列型別作為後設資料"。

下面具體剖析後兩個階段。

0x04 Add Index to Token

這部分就是給各個列的不同字串賦予連續的indices。每列的 indices 彼此不相關。

具體是由StringIndexerUtil.indexTokens 做到的。

public static DataSet<Tuple3<Integer, String, Long>> indexTokens(
    DataSet<Row> data, HasStringOrderTypeDefaultAsRandom.StringOrderType orderType,
    final long startIndex, final boolean ignoreNull) {
    		case FREQUENCY_ASC:
                return indexSortedByFreq(data, startIndex, ignoreNull, true);
}

4.1 合併計算單詞個數

indexSortedByFreq會呼叫countTokens來計算單詞個數,所以我們先看countTokens。

countTokens的作用是按照 "列idx","word" 來合併計算單詞個數,比如第一列中,football這個單詞的個數是3,則返回三元組是 <0,football,3>,其中列的idx從0開始計算。

具體邏輯如下:

  • 呼叫 flattenTokens 把輸入資料 Row 給打散,返回 A DataSet of tuples of column index and token,即<"列idx","word">。比如對於 Row.of("football", "can") 這個輸入,flattenTokens 輸出兩個Tuple2 ,<0, "football"> 和 <1, "can">。
  • 對上面結果進行map操作,輸出<column idx, word, 1L>,比如 <0, "football", 1L> ,這個是計數的常規操作。
  • 按照 "列idx","word" 來分組;
  • 按照 "列idx","word" 來合併計算單詞個數,就是不停歸併上面的 1L。

4.1.1 打散輸入資料

其中 flattenTokens 的作用是把輸入資料 Row 給打散,返回 A DataSet of tuples of column index and token.。

比如對於 Row.of("football", "can") 這個輸入,flattenTokens 使用 out.collect(Tuple2.of(i, String.valueOf(o))); 輸出兩個Tuple2。

value = {Row@9212} "football,can"
 fields = {Object[2]@9215} 
  0 = "football"
  1 = "can"
  
輸出 <0, "football"> 和 <1, "can">

4.1.2 分組計算個數

這是通過flattenTokens的結果進行 map,groupBy,reduce的一系列操作完成的。

具體程式碼如下:

public static DataSet<Tuple3<Integer, String, Long>> countTokens(DataSet<Row> data, final boolean ignoreNull) {
    return flattenTokens(data, ignoreNull) // 把輸入資料 Row 給打散
        .map(new MapFunction<Tuple2<Integer, String>, Tuple3<Integer, String, Long>>() {
            @Override
            public Tuple3<Integer, String, Long> map(Tuple2<Integer, String> value) throws Exception {
                return Tuple3.of(value.f0, value.f1, 1L); // 輸出<column idx, word, 1L>,比如 <0, "football", 1L> 
            }
        })
        .groupBy(0, 1) // 按照 "列idx","word" 來分組
        .reduce(new ReduceFunction<Tuple3<Integer, String, Long>>() {
            @Override
            public Tuple3<Integer, String, Long> reduce(Tuple3<Integer, String, Long> value1, Tuple3<Integer, String, Long> value2) throws Exception {
                value1.f2 += value2.f2;
                return value1; // 按照 "列idx","word" 來合併計算單詞個數
            }
        })
        .name("count_tokens");
}

// reduce之後發出
value1 = {Tuple3@9284} "(0,football,3)"
 f0 = {Integer@9226} 0
 f1 = "football"
 f2 = {Long@9295} 3

4.2 合併計算單詞個數

前面 countTokens的 返回三元組是 <列idx","word" ,詞頻>,其中列的idx從0開始計算。

indexSortedByFreq會對countTokens返回的結果<"列idx","word",詞頻>處理;

  • 首先按照 列idx 做分組;
  • 然後在上面結果基礎上,按照單詞個數排序;
  • 排序的index是以輸入引數startIndex開始,startIndex在這裡是0;
  • 最後得到 第一列的 (0,tennis,0),(0,basketball,1),(0,football,2);第二列的資料 (1,hhh,0),(1,zzz,1),(1,can,2);

具體程式碼如下:

public static DataSet<Tuple3<Integer, String, Long>> indexSortedByFreq(
    DataSet<Row> data, final long startIndex, final boolean ignoreNull, final boolean isAscending) {
    return countTokens(data, ignoreNull)
        .groupBy(0) //按照 列idx 做分組
        .sortGroup(2, isAscending ? Order.ASCENDING : Order.DESCENDING) //按照單詞個數排序
        .reduceGroup(new GroupReduceFunction<Tuple3<Integer, String, Long>, Tuple3<Integer, String, Long>>() {
            @Override
            public void reduce(Iterable<Tuple3<Integer, String, Long>> values,
                               Collector<Tuple3<Integer, String, Long>> out) {
                long id = startIndex;
                for (Tuple3<Integer, String, Long> value : values) {
                    out.collect(Tuple3.of(value.f0, value.f1, id++)); // 歸併
                }
            }
        });
}

0x05 輸出模型

這部分分為兩部分:

  • 輸出後設資料,就是之前得到的 "把列名,列型別作為後設資料"。
  • 輸出具體每一列的每一個單詞資訊,比如 第一列的 (0,tennis,0),(0,basketball,1),(0,football,2);第二列的資料 (1,hhh,0),(1,zzz,1),(1,can,2);
public class MultiStringIndexerModelDataConverter implements
    ModelDataConverter<Tuple2<Params, Iterable<Tuple3<Integer, String, Long>>>, MultiStringIndexerModelData> {
    @Override
    public void save(Tuple2<Params, Iterable<Tuple3<Integer, String, Long>>> modelData, Collector<Row> collector) {
        if (modelData.f0 != null) {
            collector.collect(Row.of(-1L, modelData.f0.toJson(), null));
        }
        modelData.f1.forEach(tuple -> {
            collector.collect(Row.of(tuple.f0.longValue(), tuple.f1, tuple.f2));
        });
    }  
}

tuple = {Tuple3@9405} "(0,tennis,0)"
 f0 = {Integer@9406} 0
 f1 = "tennis"
 f2 = {Long@9408} 0

0x06 預測

預測功能是在 ModelMapperAdapter 完成的。

public class ModelMapperAdapter extends RichMapFunction<Row, Row> implements Serializable {
    private final ModelMapper mapper;
    private final ModelSource modelSource;

    @Override
    public void open(Configuration parameters) throws Exception {
        List<Row> modelRows = this.modelSource.getModelRows(getRuntimeContext());
        this.mapper.loadModel(modelRows); //載入模型
    }

    @Override
    public Row map(Row row) throws Exception {
        return this.mapper.map(row); //預測
    }
}

6.1 載入模型

MultiStringIndexerModelDataConverter中我們會進行模型載入。

  • 首先會載入元資訊
  • 其次會逐條載入模型資訊
public MultiStringIndexerModelData load(List<Row> rows) {
    MultiStringIndexerModelData modelData = new MultiStringIndexerModelData();
    modelData.tokenAndIndex = new ArrayList<>();
    modelData.tokenNumber = new HashMap<>();
    for (Row row : rows) {
        long colIndex = (Long) row.getField(0);
        if (colIndex < 0L) { // 後設資料
            modelData.meta = Params.fromJson((String) row.getField(1));
        } else { // 具體模型資訊
            int columnIndex = ((Long) row.getField(0)).intValue();
            Long tokenIndex = Long.valueOf(String.valueOf(row.getField(2)));
            modelData.tokenAndIndex.add(Tuple3.of(columnIndex, (String) row.getField(1), tokenIndex));
            modelData.tokenNumber.merge(columnIndex, 1L, Long::sum); // 合併列資料個數
        }
    }

    // To ensure that every columns has token number.
    int numFields = 0;
    if (modelData.meta != null) {
        numFields = modelData.meta.get(HasSelectedCols.SELECTED_COLS).length;
    }
    for (int i = 0; i < numFields; i++) {
        modelData.tokenNumber.merge(i, 0L, Long::sum);
    }
    return modelData;
}

最後模型內容如下,其中 tokenNumber 表示每列的資料有幾個,tokenAndIndex表示具體資訊,比如(0,tennis,0),(0,basketball,1),(0,football,2) 就表示他們都是第一列的,basketball轉換後的資料是 1:

modelData = {MultiStringIndexerModelData@9348} 
 meta = {Params@9440} "Params {selectedCols=["a","b"], selectedColTypes=["VARCHAR","VARCHAR"]}"
 tokenAndIndex = {ArrayList@9360}  size = 6
  0 = {Tuple3@9472} "(0,football,2)"
  1 = {Tuple3@9511} "(0,tennis,0)"
  2 = {Tuple3@9512} "(1,zzz,1)"
  3 = {Tuple3@9513} "(1,hhh,0)"
  4 = {Tuple3@9514} "(0,basketball,1)"
  5 = {Tuple3@9515} "(1,can,2)"
 tokenNumber = {HashMap@9385}  size = 2
  {Integer@9507} 0 -> {Long@9508} 3
  {Integer@9509} 1 -> {Long@9508} 3
numFields = 2

6.2 預測

預測是在 MultiStringIndexerModelMapper 完成的。

// 假設輸入是:row = {Row@9309} "football,can"
// 選擇的列是:selectedColNames = {String[2]@9314}  0 = "a" 1 = "b"
// 模型對映器是:
this = {MultiStringIndexerModelMapper@9309} 
 indexMapper = {HashMap@9318}  size = 2
  {Integer@9357} 0 -> {HashMap@9314}  size = 3
   key = {Integer@9357} 0
    value = 0
   value = {HashMap@9314}  size = 3
    "basketball" -> {Long@9386} 1
    "football" -> {Long@9332} 2
    "tennis" -> {Long@9384} 0
  {Integer@9352} 1 -> {HashMap@9358}  size = 3
   key = {Integer@9352} 1
    value = 1
   value = {HashMap@9358}  size = 3
    "can" -> {Long@9332} 2
    "hhh" -> {Long@9384} 0
    "zzz" -> {Long@9386} 1

則經歷過下列程式碼,最後就可以進行預測

public Row map(Row row) throws Exception {
    Row result = new Row(selectedColNames.length);
    for (int i = 0; i < selectedColNames.length; i++) {
        Map<String, Long> mapper = indexMapper.get(i);
        int colIdxInData = selectedColIndicesInData[i];
        Object val = row.getField(colIdxInData);
        String key = val == null ? null : String.valueOf(val);
        Long index = mapper.get(key);
        if (index != null) {
            result.setField(i, index); // 我們主要執行在這裡
        } else {
        }
    }
  
// 最後預測結果是:
row = {Row@9308} "football,can"
result = {Row@9313} "2,2"
    
    return outputColsHelper.getResultRow(row, result);
}

0xFF 參考

Spark之特徵預處理

相關文章