Alink漫談(七) : 如何劃分訓練資料集和測試資料集
0x00 摘要
Alink 是阿里巴巴基於實時計算引擎 Flink 研發的新一代機器學習演算法平臺,是業界首個同時支援批式演算法、流式演算法的機器學習平臺。本文將為大家展現Alink如何劃分訓練資料集和測試資料集。
0x01 訓練資料集和測試資料集
兩分法
一般做預測分析時,會將資料分為兩大部分。一部分是訓練資料,用於構建模型,一部分是測試資料,用於檢驗模型。
三分法
但有時候模型的構建過程中也需要檢驗模型/輔助模型構建,這時會將訓練資料再分為兩個部分:1)訓練資料;2)驗證資料(Validation Data)。所以這種情況下會把資料分為三部分。
- 訓練資料(Train Data):用於模型構建。
- 驗證資料(Validation Data):可選,用於輔助模型構建,可以重複使用。
- 測試資料(Test Data):用於檢測模型構建,此資料只在模型檢驗時使用,用於評估模型的準確率。絕對不允許用於模型構建過程,否則會導致過渡擬合。
Training set是用來訓練模型或確定模型引數的,如ANN中權值等;
Validation set是用來做模型選擇(model selection),即做模型的最終優化及確定,如ANN的結構;
Test set則純粹是為了測試已經訓練好的模型的推廣能力。當然test set並不能保證模型的正確性,他只是說相似的資料用此模型會得出相似的結果。
實際應用
實際應用中,一般只將資料集分成兩類,即training set 和test set,大多數文章並不涉及validation set。我們這裡也不涉及。大家常用的sklearn的train_test_split函式就是將矩陣隨機劃分為訓練子集和測試子集,並返回劃分好的訓練集測試集樣本和訓練集測試集標籤。
0x02 Alink示例程式碼
首先我們給出示例程式碼,然後會深入剖析:
public class SplitExample {
public static void main(String[] args) throws Exception {
String url = "iris.csv";
String schema = "sepal_length double, sepal_width double, petal_length double, petal_width double, category string";
//這裡是批處理
BatchOperator data = new CsvSourceBatchOp().setFilePath(url).setSchemaStr(schema);
SplitBatchOp spliter = new SplitBatchOp().setFraction(0.8);
spliter.linkFrom(data);
BatchOperator trainData = spliter;
BatchOperator testData = spliter.getSideOutput(0);
// 這裡是流處理
CsvSourceStreamOp dataS = new CsvSourceStreamOp().setFilePath(url).setSchemaStr(schema);
SplitStreamOp spliterS = new SplitStreamOp().setFraction(0.4);
spliterS.linkFrom(dataS);
StreamOperator train_data = spliterS;
StreamOperator test_data = spliterS.getSideOutput(0);
}
}
0x03 批處理
SplitBatchOp是分割批處理的主要類,具體構建DAG的工作是在其linkFrom完成的。
總體思路比較簡單:
- 假定有一個取樣比例 fraction
- 將資料集分割槽,平行計算每個分割槽上的記錄數
- 把每個分割槽上的記錄數累積,得到所有記錄總數 totCount
- 從上而下計算出一個取樣總數:
numTarget = totCount * fraction
- 因為具體選擇元素是在每個分割槽上做的,所以在每個分割槽上,分別計算出來這個分割槽應該取樣的記錄數,比如第n個分割槽上應取樣記錄數:
task_n_count * fraction
- 把這些分割槽 "應該取樣的記錄數" 累積,得出來從下而上計算出的取樣總數:
totSelect = task_1_count * fraction + task_2_count * fraction + ... task_n_count * fraction
- numTarget 和 totSelect 可能不相等,所以隨機決定把多出來的
numTarget - totSelect
加入到某一個task中。 - 在每個task上取樣得到具體的記錄。
3.1 得到記錄數
如果要分割資料,首先必須知道資料集的記錄數。比如這個DataSet的記錄是1萬個?還是十萬個?因為資料集可能會很大,所以這一步操作也使用了並行處理,即把資料分割槽,然後通過mapPartition操作得到每一個分割槽上元素的數目。
DataSet<Tuple2<Integer, Long>> countsPerPartition = DataSetUtils.countElementsPerPartition(rows); //返回哪個task有哪些記錄數
DataSet<long[]> numPickedPerPartition = countsPerPartition
.mapPartition(new CountInPartition(fraction)) //計算總數
.setParallelism(1)
.name("decide_count_of_each_partition");
因為每個分割槽就對應了一個task,所以我們也可以認為,這是獲取了每個task的記錄數。
具體工作是在 DataSetUtils.countElementsPerPartition 中完成的。返回型別是<index of this subtask, record count in this subtask>,比如3號task擁有30個記錄。
public static <T> DataSet<Tuple2<Integer, Long>> countElementsPerPartition(DataSet<T> input) {
return input.mapPartition(new RichMapPartitionFunction<T, Tuple2<Integer, Long>>() {
@Override
public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, Long>> out) throws Exception {
long counter = 0;
for (T value : values) {
counter++; //計算本task的記錄總數
}
out.collect(new Tuple2<>(getRuntimeContext().getIndexOfThisSubtask(), counter));
}
});
}
計算總數的工作其實是在下一階段運算元中完成的。
3.2 隨機選取記錄
接下來的工作主要是在 CountInPartition.mapPartition 完成的,其作用是隨機決定每個task選擇多少個記錄。
這時候就不需要並行了,所以 .setParallelism(1)
3.2.1 得到總記錄數
得到了每個分割槽記錄數之後,我們遍歷每個task的記錄數,然後累積得到總記錄數 totCount(就是從上而下計算出來的總數)。
public void mapPartition(Iterable<Tuple2<Integer, Long>> values, Collector<long[]> out) throws Exception {
long totCount = 0L;
List<Tuple2<Integer, Long>> buffer = new ArrayList<>();
for (Tuple2<Integer, Long> value : values) { //遍歷輸入的所有分割槽記錄
totCount += value.f1; //f1是Long型別的記錄數
buffer.add(value);
}
...
//後續程式碼在下面分析。
}
3.2.2 決定每個task選擇記錄數
然後CountInPartition.mapPartition函式中會隨機決定每個task會選擇的記錄數。mapPartition的引數 Iterable<Tuple2<Integer, Long>> values 就是前一階段的結果 :一個元祖<task id, 每個task的記錄數目>。
把這些元祖結合在一起,記錄在buffer這個列表中。
buffer = {ArrayList@8972} size = 4
0 = {Tuple2@8975} "(3,38)" // 3號task,其對應的partition記錄數是38個。
1 = {Tuple2@8976} "(2,0)"
2 = {Tuple2@8977} "(0,38)"
3 = {Tuple2@8978} "(1,74)"
系統的task數目就是buffer大小。
int npart = buffer.size(); // num tasks
然後,根據”記錄總數“計算出來 “隨機訓練資料的個數numTarget”。比如總數1萬,應該隨機分配20%,於是numTarget就應該是2千。這個數字以後會用到。
long numTarget = Math.round((totCount * fraction));
得到每個task的記錄數目,比如是上面buffer中的 38,0,38,還是74,記錄在 eachCount 中。
for (Tuple2<Integer, Long> value : buffer) {
eachCount[value.f0] = value.f1;
}
得到每個task中隨機選中的訓練記錄數,記錄在 eachSelect 中。就是每個task目前 “記錄數字 * fraction”。比如3號task記錄數是38個,應該選20%,則38*20%=8個。
然後把這些task自己的“隨機訓練記錄數”再累加起來得到 totSelect(就是從下而上計算出來的總數)。
long totSelect = 0L;
for (int i = 0; i < npart; i++) {
eachSelect[i] = Math.round(Math.floor(eachCount[i] * fraction));
totSelect += eachSelect[i];
}
請注意,這時候 totSelect 和 之前計算的numTarget就有具體細微出入了,就是理論上的一個數字,但是我們 從上而下 計算 和 從下而上 計算,其結果可能不一樣。通過下面我們可以看出來。
numTarget = all count * fraction
totSelect = task_1_count * fraction + task_2_count * fraction + ...
所以我們下一步要處理這個細微出入,就得到remain,這是"總體算出來的隨機數目" numTarget 和 "從所有task選中的隨機訓練記錄數累積" totSelect 的差。
if (totSelect < numTarget) {
long remain = numTarget - totSelect;
remain = Math.min(remain, totCount - totSelect);
如果剛好個數相等,則就正常分配。
if (remain == totCount - totSelect) {
如果數目不等,隨機決定把"多出來的remain"加入到eachSelect陣列中的隨便一個記錄上。
for (int i = 0; i < Math.min(remain, npart); i++) {
int taskId = shuffle.get(i);
while (eachSelect[taskId] >= eachCount[taskId]) {
taskId = (taskId + 1) % npart;
}
eachSelect[taskId]++;
}
最後給出所有資訊
long[] statistics = new long[npart * 2];
for (int i = 0; i < npart; i++) {
statistics[i] = eachCount[i];
statistics[i + npart] = eachSelect[i];
}
out.collect(statistics);
// 我們這裡是4核,所以前面四項是eachCount,後面是eachSelect
statistics = {long[8]@9003}
0 = 38 //eachCount
1 = 38
2 = 36
3 = 38
4 = 31 //eachSelect
5 = 31
6 = 28
7 = 30
這些資訊是作為廣播變數儲存起來的,馬上下面就會用到。
.withBroadcastSet(numPickedPerPartition, "counts")
3.2.3 每個task選擇記錄
CountInPartition.PickInPartition函式中會隨機在每個task選擇記錄。
首先得到task數目 和 之前儲存的廣播變數(就是之前剛剛儲存的)。
int npart = getRuntimeContext().getNumberOfParallelSubtasks();
List<long[]> bc = getRuntimeContext().getBroadcastVariable("counts");
分離count和select。
long[] eachCount = Arrays.copyOfRange(bc.get(0), 0, npart);
long[] eachSelect = Arrays.copyOfRange(bc.get(0), npart, npart * 2);
得到總task數目
int taskId = getRuntimeContext().getIndexOfThisSubtask();
得到自己 task 對應的 count, select
long count = eachCount[taskId];
long select = eachSelect[taskId];
新增本task對應的記錄,隨機洗牌打亂順序
for (int i = 0; i < count; i++) {
shuffle.add(i); //就是把count內的數字加到陣列
}
Collections.shuffle(shuffle, new Random(taskId)); //洗牌打亂順序
// suffle舉例
shuffle = {ArrayList@8987} size = 38
0 = {Integer@8994} 17
1 = {Integer@8995} 8
2 = {Integer@8996} 33
3 = {Integer@8997} 34
4 = {Integer@8998} 20
5 = {Integer@8999} 0
6 = {Integer@9000} 26
7 = {Integer@9001} 27
8 = {Integer@9002} 23
9 = {Integer@9003} 28
10 = {Integer@9004} 9
11 = {Integer@9005} 16
12 = {Integer@9006} 13
13 = {Integer@9007} 2
14 = {Integer@9008} 5
15 = {Integer@9009} 31
16 = {Integer@9010} 15
17 = {Integer@9011} 22
18 = {Integer@9012} 18
19 = {Integer@9013} 35
20 = {Integer@9014} 36
21 = {Integer@9015} 12
22 = {Integer@9016} 7
23 = {Integer@9017} 21
24 = {Integer@9018} 14
25 = {Integer@9019} 1
26 = {Integer@9020} 10
27 = {Integer@9021} 30
28 = {Integer@9022} 29
29 = {Integer@9023} 19
30 = {Integer@9024} 25
31 = {Integer@9025} 32
32 = {Integer@9026} 37
33 = {Integer@9027} 4
34 = {Integer@9028} 11
35 = {Integer@9029} 6
36 = {Integer@9030} 3
37 = {Integer@9031} 24
隨機選擇,把選擇後的再排序回來
for (int i = 0; i < select; i++) {
selected[i] = shuffle.get(i); //這時候select看起來是按照順序選擇,但是實際上suffle裡面已經是亂序
}
Arrays.sort(selected); //這次再排序
// selected舉例,一共30個
selected = {int[30]@8991}
0 = 0
1 = 1
2 = 2
3 = 5
4 = 7
5 = 8
6 = 9
7 = 10
8 = 12
9 = 13
10 = 14
11 = 15
12 = 16
13 = 17
14 = 18
15 = 19
16 = 20
17 = 21
18 = 22
19 = 23
20 = 26
21 = 27
22 = 28
23 = 29
24 = 30
25 = 31
26 = 33
27 = 34
28 = 35
29 = 36
傳送選擇的資料
if (numEmits < selected.length && iRow == selected[numEmits]) {
out.collect(row);
numEmits++;
}
3.3 設定訓練資料集和測試資料集
output是訓練資料集,SideOutput是測試資料集。因為這兩個資料集在Alink內部都是Table型別,所以直接使用了SQL運算元 minusAll
來完成分割。
this.setOutput(out, in.getSchema());
this.setSideOutputTables(new Table[]{in.getOutputTable().minusAll(this.getOutputTable())});
0x04 流處理
訓練是在SplitStreamOp類完成的,其通過linkFrom完成了模型的構建。
流處理依賴SplitStream 和 SelectTransformation 這兩個類來完成分割流。具體並沒有建立一個物理操作,而只是影響了上游運算元如何與下游運算元聯絡,如何選擇記錄。
SplitStream <Row> splited = in.getDataStream().split(new RandomSelectorOp(getFraction()));
首先,用RandomSelectorOp來隨機決定輸出時候選擇哪個流。我們可以看到,這裡就是隨便起了"a", "b" 這兩個名字而已。
class RandomSelectorOp implements OutputSelector <Row> {
private double fraction;
private Random random = null;
@Override
public Iterable <String> select(Row value) {
if (null == random) {
random = new Random(System.currentTimeMillis());
}
List <String> output = new ArrayList <String>(1);
output.add((random.nextDouble() < fraction ? "a" : "b")); //隨機選取數字分配,隨意起的名字
return output;
}
}
其次,得到那兩個隨機生成的流。
DataStream <Row> partA = splited.select("a");
DataStream <Row> partB = splited.select("b");
最後把這兩個流分別設定為output和sideOutput。
this.setOutput(partA, in.getSchema()); //訓練集
this.setSideOutputTables(new Table[]{
DataStreamConversionUtil.toTable(getMLEnvironmentId(), partB, in.getSchema())}); //驗證集
最後返回本身,這時候SplitStreamOp擁有兩個成員變數:
this.output就是訓練集。
this.sideOutPut就是驗證集。
return this;