[Alink漫談之三] AllReduce通訊模型

羅西的思考發表於2020-05-16

[Alink漫談之三] AllReduce通訊模型

0x00 摘要

Alink 是阿里巴巴基於實時計算引擎 Flink 研發的新一代機器學習演算法平臺,是業界首個同時支援批式演算法、流式演算法的機器學習平臺。本文將帶領大家來分析Alink中通訊模型AllReduce的實現。AllReduce在Alink中應用較多,比如KMeans,LDA,Word2Vec,GD,lbfgs,Newton method,owlqn,SGD,Gbdt, random forest都用到了這個通訊模型。

因為Alink的公開資料太少,所以以下均為自行揣測,肯定會有疏漏錯誤,希望大家指出,我會隨時更新。

0x01 MPI是什麼

MPI(Message-Passing Interface)是一個跨語言的通訊協議,用於編寫平行計算,支援點對點和廣播。

MPI的目標是高效能、大規模性和可移植性。MPI在今天仍為高效能運算的主要模型。

其特點是

  • A partitioned address space 每個執行緒只能通過呼叫api去讀取非本地資料。所有的互動(Non-local Memory)都需要協同進行(握手)。

  • Supports only explicit parallelization 只支援顯性的並行化,使用者必須明確的規定訊息傳遞的方式。

AllReduce是MPI提供的一個基本原語,我們需要先了解reduce才能更好理解AllReduce。

  • 規約函式 MPI_Reduce :規約是來自函數語言程式設計的一個經典概念。其將通訊子內各程式的同一個變數參與規約計算,並向指定的程式輸出計算結果。比如通過一個函式將一批資料分成較小的一批資料。或者將一個陣列的元素通過加法函式規約為一個數字。
  • 規約並廣播函式 MPI_Allreduce :在計算規約的基礎上,將計算結果分發到每一個程式中。比如函式在得到歸約結果值之後,將結果值分發給每一個程式,這樣的話,並行中的所有程式值都能知道結果值了。

MPI_Allreduce和MPI_Reduce的一個區別就是,MPI_Reduce函式將最後的結果只傳給了指定的dest_process 號程式,而MPI_Allreduce函式可以將結果傳遞給所有的程式,因此所有的程式都能接收到結果。MPI_Allreduce函式的原型也因此不需要指定目標程式號。

AllReduce在Alink中應用較多,比如KMeans,LDA,Word2Vec,GD,lbfgs,Newton method,owlqn,SGD,Gbdt, random forest都用到了這個通訊模型。

AllReduce在演算法實現中起到了承上啟下的關鍵作用,即把原來序列跑的並行task強制打斷,把計算結果進行彙總再分發,讓序列繼續執行。有一點類似大家熟悉的併發中的Barrier。

對比Flink原生KMeans演算法,我們能看到AllReduce對應的是 groupBy(0).reduce。只有所有資料都產生之後,才能做groupBy操作。

	DataSet<Centroid> newCentroids = points
		// compute closest centroid for each point
		.map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")
		// count and sum point coordinates for each centroid
		.map(new CountAppender())
        // 這裡如果是Alink,就對應了AllReduce
		.groupBy(0).reduce(new CentroidAccumulator())
		// compute new centroids from point counts and coordinate sums
		.map(new CentroidAverager());

從AllReduce的註解中我們可以清晰的看出Alink實現MPI的思想。

 * An implement of {@link CommunicateFunction} that do the AllReduce.
 *
 * AllReduce is a communication primitive widely used in MPI. In our implementation, all workers do reduce on a partition of the whole data and they all get the final reduce result.
 *
 * There're mainly three stages:
 *   1. All workers send the there partial data to other workers for reduce.
 *   2. All workers do reduce on all data it received and then send partial results to others.
 *   3. All workers merge partial results into final result and put it into session context with pre-defined object name.
 */

翻譯如下:

所有的workers都在部分資料上做reduce操作,所有的workers都可以獲取到reduce最終結果
    
主要有三個階段:
1. 所有workers給其他workers傳送需要reduce的部分資料
2. 所有workers在它收到的資料上做reduce,然後把這個部分reduce的結果傳送給其他workers
3. 所有workers把部分reduce的結果合併成為最終結果,然後放入預定義的session 上下文變數中

"紙上得來終覺淺,絕知此事要躬行。"

Alink為了實現AllReduce,在背後做了大量的工作,下面我們一一剖析。

0x03 如何實現共享

共享是實現AllReduce的第一要務,因為在歸併/廣播過程中需要後設資料和輸入輸出,如果有共享變數就可以極大簡化實現。我們下面就看看Alink如何通過task manager實現共享

1. Task相關概念

  • Task(任務) : Task 是一個階段多個功能相同 subTask 的集合,類似於 Spark 中的 TaskSet。
  • subTask(子任務) :subTask 是 Flink 中任務最小執行單元,是一個 Java 類的例項,這個 Java 類中有屬性和方法,完成具體的計算邏輯。
  • 鏈式優化 : 按理說應該是每個運算元的一個並行度例項就是一個subtask。那麼,帶來很多問題,由於flink的taskmanager執行task的時候是每個task採用一個單獨的執行緒,這就會帶來很多執行緒切換開銷,進而影響吞吐量。為了減輕這種情況,flink進行了優化,也即對subtask進行鏈式操作,鏈式操作結束之後得到的task,再作為一個排程執行單元,放到一個執行緒裡執行。
  • Operator Chains(運算元鏈) :Flink 將多個 subTask 合併成一個 Task(任務),這個過程叫做 Operator Chains,每個任務由一個執行緒執行。使用 Operator Chains(運算元鏈) 可以將多個分開的 subTask 拼接成一個任務。類似於 Spark 中的 Pipeline。
  • Slot(插槽) :Flink 中計算資源進行隔離的單元,一個 Slot 中可以執行多個 subTask,但是這些 subTask 必須是來自同一個 application 的不同階段的 subTask。結果就是,每個slot可以執行job的一整個pipeline。

Flink 中的程式本質上是並行的。在執行期間,每一個運算元(Transformation)都有一個或多個運算元subTask(Operator SubTask),每個運算元的 subTask 之間都是彼此獨立,並在不同的執行緒中執行,並且可能在不同的機器或容器上執行。

同一個application,多個不同 task的 subTask,可以執行在同一個 slot 資源槽中。同一個 task 中的多個的 subTask,不能執行在一個 slot 資源槽中,他們可以分散到其他的資源槽中。對應到後面就是:AllReduceSend的多個並行度例項都不能執行在同一個slot中。

2. TaskManager

Flink 中每一個 TaskManager 都是一個JVM程式,它可能會在獨立的執行緒上執行一個或多個 subtask。TaskManager 相當於整個叢集的 Slave 節點,負責具體的任務執行和對應任務在每個節點上的資源申請和管理。

TaskManager為了對資源進行隔離和增加允許的task數,引入了slot的概念,這個slot對資源的隔離僅僅是對記憶體進行隔離,策略是均分。一個 TaskManager 至少有一個 slot。如果一個TM有N個Slot,則每個Slot分配到的Memory大小為整個TM Memory的1/N,同一個TM內的Slots只有Memory隔離,CPU是共享的。

客戶端通過將編寫好的 Flink 應用編譯打包,提交到 JobManager,然後 JobManager 會根據已註冊在 JobManager 中 TaskManager 的資源情況,將任務分配給有資源的 TaskManager節點,然後啟動並執行任務。

TaskManager 從 JobManager 接收需要部署的任務,然後使用 Slot 資源啟動 Task,建立資料接入的網路連線,接收資料並開始資料處理。同時 TaskManager 之間的資料互動都是通過資料流的方式進行的。

Flink 的任務執行其實是採用多執行緒的方式,一個TaskManager(TM)在多執行緒中併發執行多個task。這和 MapReduce 多 JVM 進行的方式有很大的區別,Flink 能夠極大提高 CPU 使用效率,在多個任務和 Task 之間通過 TaskSlot 方式共享系統資源,每個 TaskManager 中通過管理多個 TaskSlot 資源池進行對資源進行有效管理。

對應到後面就是:在一個TaskManager中間執行的多個並行的AllReduceSend例項都會共享這個TaskManager中所有靜態變數

3. 狀態共享

Alink就是利用task manager的靜態變數實現了變數共享。其中有幾個主要類和概念比較複雜。我們從上到下進行講解,能看到隨著從上到下,需要的標示和狀態逐漸增加。

3.1 概念剖析

從上往下呼叫層次如下:

演算法角度:ComContext

使用者程式碼呼叫 : context.getObj(bufferName); 這樣對使用者是最理想的,因為對於使用者來說知道變數名字就可以經過上下文來存取。

但是ComContext則需要知道更多,比如還需要知道 自己對應的sessioin和taskID,具體下面會說明。

ComContext如此向下呼叫 : SessionSharedObjs.put(objName, sessionId, taskId, obj);

框架角度:IterativeComQueue

IterativeComQueue 是一個框架概念。以Kmeans為例,就是Kmeans演算法對應了若干IterativeComQueue。

IterativeComQueue上擁有眾多compute/communicate function,每個function都應該知道自己屬於哪一個IterativeComQueue,如何和本Queue上其他function進行通訊,不能和其他Queue上搞混了。這樣就需要有一個概念來表標示這個Queue。於是就有了下面Session概念。

Session角度:SessionSharedObjs

為了區分每個IterativeComQueue,就產生了session這個概念。這樣IterativeComQueue上所有compute/communicate function都會繫結同一個session id,同一個IterativeComQueue上的所有function之間可以通訊。

一個 IterativeComQueue 對應一個session,所以<"變數名" + sessionId>就對應了這個 session 能訪問的某個變數。

SessionSharedObjs 包含靜態成員變數 :

  • int sessionId = 0; 遞增的標示,用來區分session。
  • HashMap<Tuple2<String, Integer>, Long> key2Handle。對映,表示一個session中 某個變數名 對應某個變數handle。

正常來說 "某個名字的變數" 對應 "某個變數handle" 即可。即一個session中某個變數名 對應某個變數handle。但是Flink中,會有多個subtask並行操作的狀態,這樣就需要有一個新的概念來標示subtask對應的變數,這個變數應該和taskId有所關聯。於是就有了下面的state概念。

SessionSharedObjs向下呼叫 : IterTaskObjKeeper.put(handle, taskId, obj);

Subtask角度:IterTaskObjKeeper

這裡就是用靜態變數來實現共享。是task manager中所有的 tasks (threads)都可以訪問的共享變數例項。

IterTaskObjKeeper 包含靜態成員變數 :

  • long handle = 0L; 遞增的標示,用來區分state。
  • Map <Tuple2.of(handle, taskId), state> states; 是一個對映。即handle代表哪一種變數state,<handle, taskId>表示這種變數中 "哪個task" 對應的state例項,是針對subtask的一種細分。

在Flink中,一個演算法會被多個subtask並行操作。如果只有一個handle,那麼多個subtask共同訪問,就會有大家都熟知的各種多執行緒操作問題。所以Alink這裡將handle拆分為多個state。從subtask角度看,每個state用<handle, taskId>來唯一標示。

總結一下,就是對於同樣一個變數名字,每個subtask對應的共享state其實都是獨立的,大家互不干擾。共享其實就是在這個subtask上跑的各個operator之間共享

3.2 變數例項分析

從實際執行的變數中,我們可以有一個更加清楚的認識。

// 能看出來 session 0 中,centroidAllReduce這個變數 對應的handle是 7
SessionSharedObjs.key2Handle = {HashMap@10480}  size = 9
 {Tuple2@10492} "(initCentroid,0)" -> {Long@10493} 1
 {Tuple2@10494} "(statistics,0)" -> {Long@10495} 2
 {Tuple2@10496} "(362158a2-588b-429f-b848-c901a1e15e17,0)" -> {Long@10497} 8
 {Tuple2@10498} "(k,0)" -> {Long@10499} 6
 {Tuple2@10500} "(centroidAllReduce,0)" -> {Long@10501} 7 // 這裡就是所說的
 {Tuple2@10502} "(trainData,0)" -> {Long@10503} 0
 {Tuple2@10504} "(vectorSize,0)" -> {Long@10505} 3
 {Tuple2@10506} "(centroid2,0)" -> {Long@10507} 5
 {Tuple2@10508} "(centroid1,0)" -> {Long@10509} 4

// 下面能看出來,handle 7 這一種變數,因為有 4 個subtask,所以細分為4個state。 
 com.alibaba.alink.common.comqueue.IterTaskObjKeeper.states = {HashMap@10520}  size = 36
 {Tuple2@10571} "(7,0)" -> {double[15]@10572} 
 {Tuple2@10573} "(7,1)" -> {double[15]@10574} 
 {Tuple2@10577} "(7,2)" -> {double[15]@10578} 
 {Tuple2@10581} "(7,3)" -> {double[15]@10582} 

 {Tuple2@10575} "(5,0)" -> {Tuple2@10576} "(10,com.alibaba.alink.operator.common.distance.FastDistanceMatrixData@29a72fbb)"
 {Tuple2@10579} "(5,1)" -> {Tuple2@10580} "(10,com.alibaba.alink.operator.common.distance.FastDistanceMatrixData@26c52354)"
 {Tuple2@10585} "(5,2)" -> {Tuple2@10586} "(10,com.alibaba.alink.operator.common.distance.FastDistanceMatrixData@7c6ed779)"
 {Tuple2@10588} "(5,3)" -> {Tuple2@10589} "(10,com.alibaba.alink.operator.common.distance.FastDistanceMatrixData@154b8a4d)"

下面讓我們結合程式碼,一一解析涉及的類。

3.3 ComContext

ComContext 是最上層類,用來獲取runtime資訊和共享變數。IterativeComQueue(BaseComQueue )上所有的compute/communicate function都通過 ComContext 來訪問共享變數。比如:

public class BaseComQueue<Q extends BaseComQueue<Q>> implements Serializable {

    // 每一個BaseComQueue都會得到唯一一個sessionId。
    private final int sessionId = SessionSharedObjs.getNewSessionId();

    int taskId = getRuntimeContext().getIndexOfThisSubtask();

    public void mapPartition(Iterable<byte[]> values, Collector<byte[]> out) {
        // 獲取到了一個ComContext
        ComContext context = new ComContext(sessionId, getIterationRuntimeContext());
        if (getIterationRuntimeContext().getSuperstepNumber() == maxIter || criterion) {
            // 利用ComContext繼續訪問共享變數
            List<Row> model = completeResult.calc(context);
        }
    }
}

// 使用者類似這麼呼叫

double[] sendBuf = context.getObj(bufferName);

可以看出來,ComContext 就是使用者應該看到的最頂層上下文概念。 taskId, sessionId 是使用關鍵

  • sessionId 是在 SessionSharedObjs中定義的靜態類成員變數,其會自動遞增。每一個BaseComQueue都會得到唯一一個sessionId,即該Queue保持了唯一session。這樣BaseComQueue中生成的ComContext都有相同的sessionId。
  • taskId是從runtime中獲得。
/**
 * Encapsulates task-specific information: name, index of subtask, parallelism and attempt number.
 */
@Internal
public class TaskInfo {
	/**
	 * Gets the number of this parallel subtask. The numbering starts from 0 and goes up to parallelism-1 (parallelism as returned by {@link #getNumberOfParallelSubtasks()}).
	 *
	 * @return The index of the parallel subtask.
	 */
	public int getIndexOfThisSubtask() {
		return this.indexOfSubtask; // 這裡獲取taskId
	}
}

ComContext 具體類定義如下

/**
 * Context used in BaseComQueue to access basic runtime information and shared objects.
 */
public class ComContext {
	private final int taskId;
	private final int numTask;
	private final int stepNo;
	private final int sessionId;

	public ComContext(int sessionId, IterationRuntimeContext runtimeContext) {
		this.sessionId = sessionId;
		this.numTask = runtimeContext.getNumberOfParallelSubtasks();
		this.taskId = runtimeContext.getIndexOfThisSubtask();
		this.stepNo = runtimeContext.getSuperstepNumber();
	}
    
	/**
	 * Put an object into shared objects for access of other QueueItem of the same taskId.
	 *
	 * @param objName object name
	 * @param obj     object itself.
	 */
	public void putObj(String objName, Object obj) {
		SessionSharedObjs.put(objName, sessionId, taskId, obj);
	}
}

// 比如具體舉例如下
this = {ComContext@10578} 
 taskId = 4
 numTask = 8
 stepNo = 1
 sessionId = 0

3.4 SessionSharedObjs

SessionSharedObjs是再下一層的類,維護shared session objects, 這個session 共享是通過 sessionId 做到的。

SessionSharedObjs 維護了一個靜態類變數 sessionId,由此區分各個Session。

SessionSharedObjs核心是 HashMap<Tuple2<String, Integer>, Long> key2Handle 。即 <"變數名" + sessionId> ---> <真實變數 handle> 的一個對映。

一個 IterativeComQueue 對應一個session,所以<"變數名" + sessionId>就對應了這個 IterativeComQueue 能訪問的某個變數,正常來說有一個變數handle即可。

但是因為一個 IterativeComQueue會被若干subtask並行執行,所以為了互斥和區分,所以每個handle又細分為若干state,每個state用<handle, taskId>來唯一標示。在下面會提到。

/**
 * An static class that manage shared objects for {@link BaseComQueue}s.
 */
class SessionSharedObjs implements Serializable {
	private static HashMap<Tuple2<String, Integer>, Long> key2Handle = new HashMap<>();
	private static int sessionId = 0;
	private static ReadWriteLock rwlock = new ReentrantReadWriteLock();
    
	/**
	 * Get a new session id.
	 * All access operation should bind with a session id. This id is usually shared among compute/communicate function of an {@link IterativeComQueue}.
	 *
	 * @return new session id.
	 */
	synchronized static int getNewSessionId() {
		return sessionId++;
	}    
    
	static void put(String objName, int session, int taskId, Object obj) {
		rwlock.writeLock().lock();
		try {
			Long handle = key2Handle.get(Tuple2.of(objName, session));
			if (handle == null) {
				handle = IterTaskObjKeeper.getNewHandle();
				key2Handle.put(Tuple2.of(objName, session), handle);
			}
      // 這裡進行呼叫。taskId也是辨識關鍵。
			IterTaskObjKeeper.put(handle, taskId, obj);
		} finally {
			rwlock.writeLock().unlock();
		}
	}    
}

3.5 IterTaskObjKeeper

這是最底層的共享類,是在task manager程式的堆記憶體上的一個靜態例項。task manager的所有task (threads) 都可以分享。

看原始碼可知,IterTaskObjKeeper 是通過一個靜態變數states實現了在整個JVM內共享。而具體內容是由 'handle' and 'taskId' 來共同決定。

IterTaskObjKeeper維持了 handle 遞增來作為 “變數state” 的唯一種類標識

用<handle, taskId>來作為“變數state”的唯一標識。這個就是在 task manager process 堆記憶體中被大家共享的變數。

即handle代表哪一種變數state,<handle, taskId>表示這種變數中,對應哪一個task的哪一個變數。 這是針對task的一種細分。

/**
 * A 'state' is an object in the heap memory of task manager process,
 * shared across all tasks (threads) in the task manager.

 * Note that the 'state' is shared by all tasks on the same task manager,
 * users should guarantee that no two tasks modify a 'state' at the same time.

 * A 'state' is identified by 'handle' and 'taskId'.
 */
public class IterTaskObjKeeper implements Serializable {
	private static Map <Tuple2 <Long, Integer>, Object> states;

	/**
	 * A 'handle' is a unique identifier of a state.
	 */
	private static long handle = 0L;

	private static ReadWriteLock rwlock = new ReentrantReadWriteLock();

	static {
		states = new HashMap <>();
	}

	/**
	 * @note Should get a new handle on the client side and pass it to transformers.
	 */
	synchronized public static long getNewHandle() {
		return handle++;
	}

	public static void put(long handle, int taskId, Object state) {
		rwlock.writeLock().lock();
		try {
			states.put(Tuple2.of(handle, taskId), state); 
		} finally {
			rwlock.writeLock().unlock();
		}
	}
}

0x04. 示例程式碼

我們示例程式碼依然如下。

KMeansTrainBatchOp呼叫

	static DataSet <Row> iterateICQ(...省略...) {
		return new IterativeComQueue()
			.initWithPartitionedData(TRAIN_DATA, data)
			.initWithBroadcastData(INIT_CENTROID, initCentroid)
			.initWithBroadcastData(KMEANS_STATISTICS, statistics)
			.add(new KMeansPreallocateCentroid())
			.add(new KMeansAssignCluster(distance))
			.add(new AllReduce(CENTROID_ALL_REDUCE))
			.add(new KMeansUpdateCentroids(distance))
			.setCompareCriterionOfNode0(new KMeansIterTermination(distance, tol))
			.closeWith(new KMeansOutputModel(distanceType, vectorColName, latitudeColName, longitudeColName))
			.setMaxIter(maxIter)
			.exec();
	}

AllReduce實現

Alink的AllReduce主要程式碼摘取如下:

public static <T> DataSet <T> allReduce(
    return input
		.mapPartition(new AllReduceSend <T>(bufferName, lengthName, transferBufferName, sessionId))
		.withBroadcastSet(input, "barrier")
		.returns(
			new TupleTypeInfo <>(Types.INT, Types.INT, PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO))
		.name("AllReduceSend")
		.partitionCustom(new Partitioner <Integer>() {
			@Override
			public int partition(Integer key, int numPartitions) {
				return key;
			}
		}, 0)
		.name("AllReduceBroadcastRaw")
		.mapPartition(new AllReduceSum(bufferName, lengthName, sessionId, op))
		.returns(
			new TupleTypeInfo <>(Types.INT, Types.INT, PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO))
		.name("AllReduceSum")
		.partitionCustom(new Partitioner <Integer>() {
			@Override
			public int partition(Integer key, int numPartitions) {
				return key;
			}
		}, 0)
		.name("AllReduceBroadcastSum")
		.mapPartition(new AllReduceRecv <T>(bufferName, lengthName, sessionId))
		.returns(input.getType())
		.name("AllReduceRecv");
}

0x05 AllReduce實現

結合上面具體程式碼,我們先總結AllReduce使用流程如下

  • KMeansAssignCluster :Find the closest cluster for every point and calculate the sums of the points belonging to the same cluster。然後把自己計算出來的cluster 寫入到自己 task manager 的 CENTROID_ALL_REDUCE。

  • 每個AllReduceSend 從自己task manager的CENTROID_ALL_REDUCE中取出之前存入的 cluster(每個AllReduceSend獲取的cluster都是隻有自己能看到的),然後傳送給下游task。傳送時根據 "下游task index 和 資料量" 來決定往哪些task傳送。這裡要注意的是:具體給哪一個task傳送變數的哪一部分,是依據那個task 的 task index 和資料量 來計算出來的。這個計算機制(如何計算在程式碼中,也有部分作為元資訊隨著資料一起傳送)被後面的AllReduceRecv複用。

  • 每個 AllReduceSum 接收到 AllReduceSend 傳送過來的 cluster,計算求和,然後把計算結果再傳送出去。每一個AllReduceSum 都是把自己計算求和出來的資料統一發給每一個下游task。

  • 每個 AllReduceRecv 接收到 所有 AllReduceSum 傳送過來的(求和之後的)cluster。存入到共享變數CENTROID_ALL_REDUCE。具體如何存就複用AllReduceSend計算機制,這樣存到共享變數的什麼地方就互相不會衝突。可以理解為merge操作:比如有5個AllReduce,每個AllReduce的資料都發給了5個AllReduceRecv,每個AllReduceRecv接到這5份資料之後,會根據自己的subtask index寫到自己對應的state中,但是這5份資料分別寫在state什麼地方都是在資料元資訊中指定的,彼此不會有寫的衝突,這樣每個AllReduceRecv就擁有了全部5份資料。

  • KMeansUpdateCentroids :取出CENTROID_ALL_REDUCE變數,然後Update the centroids based on the sum of points and point number belonging to the same cluster

1. KMeansAssignCluster

該類的作用是:為每個點(point)計算最近的聚類中心,為每個聚類中心的點座標的計數和求和。

我們可以看出,KMeansAssignCluster 通過ComContext儲存了CENTROID_ALL_REDUCE,為後續AllReduce使用。假如有5個KMeansAssignCluster,則他們計算出來的結果一般來說各不相同。雖然儲存同一個變數名CENTROID_ALL_REDUCE,但是其state各不相同。

因為這5個KMeansAssignCluster勢必對應了5個subtask,則其在共享變數中的<handle, taskId>必不相同,則對應不同的state,所以分開儲存。

// Find the closest cluster for every point and calculate the sums of the points belonging to the same cluster.
public class KMeansAssignCluster extends ComputeFunction {
        // 存取共享變數
        double[] sumMatrixData = context.getObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE);
        if (sumMatrixData == null) {
            sumMatrixData = new double[k * (vectorSize + 1)];
            context.putObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE, sumMatrixData);
        }  
    
        for (FastDistanceVectorData sample : trainData) {
            // Find the closest centroid from centroids for sample, and add the sample to sumMatrix.
            KMeansUtil.updateSumMatrix(sample, 1, stepNumCentroids.f1, vectorSize, sumMatrixData, k, fastDistance, distanceMatrix);
        }    
}

// 程式中各個變數如下

sample = {FastDistanceVectorData@13274} 
 vector = {DenseVector@13281} "6.3 2.5 4.9 1.5"
 label = {DenseVector@13282} "72.2"
 rows = {Row[1]@13283} 

// 這個就是共享變數。4維向量 + 1 weight ---> 都是"sample和"。
sumMatrixData = {double[15]@10574} 
 0 = 23.6
 1 = 14.9
 2 = 8.7
 3 = 1.7000000000000002
 4 = 5.0
 5 = 52.400000000000006
 6 = 25.1
 7 = 39.699999999999996
 8 = 13.299999999999999
 9 = 9.0
 10 = 33.0
 11 = 16.9
 12 = 28.900000000000002
 13 = 11.4
 14 = 5.0
     
trainData = {ArrayList@10580}  size = 19
 0 = {FastDistanceVectorData@10590} 
  vector = {DenseVector@10595} "7.7 3.8 6.7 2.2"
   data = {double[4]@10601} 
    0 = 7.7
    1 = 3.8
    2 = 6.7
    3 = 2.2
  label = {DenseVector@10596} "123.46000000000001"
  rows = {Row[1]@10597} 
 1 = {FastDistanceVectorData@10603} 
  vector = {DenseVector@10623} "5.7 2.8 4.1 1.3"
  label = {DenseVector@10624} "58.83"
  rows = {Row[1]@10625} 
 2 = {FastDistanceVectorData@10604} 
 3 = {FastDistanceVectorData@10605} 
......
 17 = {FastDistanceVectorData@10619} 
 18 = {FastDistanceVectorData@10620} 
  vector = {DenseVector@10654} "6.5 3.0 5.2 2.0"
  label = {DenseVector@10655} "82.29"
  rows = {Row[1]@10656}      

2. AllReduceSend

這裡需要再把程式碼摘錄一遍,主要是因為有withBroadcastSet。其作用是:

  • 可以理解為是一個公共的共享變數,我們可以把一個dataset 資料集廣播出去,然後不同的task在節點上都能夠獲取到,這個資料在每個節點上只會存在一份。
  • 如果不使用broadcast,則在每個節點中的每個task中都需要拷貝一份dataset資料集,比較浪費記憶體(也就是一個節點中可能會存在多份dataset資料)。
		return input
			.mapPartition(new AllReduceSend <T>(bufferName, lengthName, transferBufferName, sessionId))
			.withBroadcastSet(input, "barrier")

KMeansAssignCluster 會往上下文的變數centroidAllReduce中新增資料。所以 AllReduce 其實就是在等待這個變數。

AllReduce的第一步就是從上下文中取出共享變數,然後傳送。這部分程式碼由AllReduceSend完成。

對於AllReduceSend的每個task來說,bufferName都是 centroidAllReduce。

因為每個AllReduceSend也對應不同的task,所以每個AllReduceSend讀取的centroidAllReduce必然不一樣,所以每個task獲取的sendBuf都不一樣。他們分別把自己<handle, taskId>對應的 "centroidAllReduce" state取出,傳送給下游。

AllReduceSend 發給其下游時候,是以subtask的序號為基準傳送給每一個task,即本task中獲取的共享變數會傳送給每一個task,但是具體給哪一個task傳送變數的那一部分,是依據那個task 的 task index 和資料量 來計算出來的。如果資料量少,可能只給某一個或者幾個task傳送。

後續中的 taskId ,都是subtask id。

其中,如何計算給哪個task傳送多少,是在DefaultDistributedInfo完成的。這裡需要結合 pieces 函式進行分析。需要注意的是:AllReduceSend這麼傳送,AllReduceRecv後面也按照這個套路接受。這樣AllReduceRecv就可以merge了

AllReduceSend這麼傳送,AllReduceRecv後面也按照這個套路接受

int pieces = pieces(sendLen);//表示本人這次send的資料分成幾片,比如分成50片。每片大小是TRANSFER_BUFFER_SIZE

// 將要發給 8 個 subtask
for (int i = 0; i < numOfSubTasks; ++i) {
      // 假如第5個subtask,那麼它傳送的起始位置就是50/8 * 4
      int startPos = (int) distributedInfo.startPos(i, numOfSubTasks, pieces);
      // 給第5個subtask傳送多少片
      int cnt = (int) distributedInfo.localRowCnt(i, numOfSubTasks, pieces);

具體程式碼如下:

	private static int pieces(int len) {
		int div = len / TRANSFER_BUFFER_SIZE; //本人這次send的資料分成幾片,每片大小是TRANSFER_BUFFER_SIZE
		int mod = len % TRANSFER_BUFFER_SIZE;

		return mod == 0 ? div : div + 1;
	}

public class DefaultDistributedInfo implements DistributedInfo {

	public long startPos(long taskId, long parallelism, long globalRowCnt) {
		long div = globalRowCnt / parallelism;
		long mod = globalRowCnt % parallelism;

		if (mod == 0) {
			return div * taskId;
		} else if (taskId >= mod) {
			return div * taskId + mod;
		} else {
			return div * taskId + taskId;
		}
	}
    
	public long localRowCnt(long taskId, long parallelism, long globalRowCnt) {
		long div = globalRowCnt / parallelism;
		long mod = globalRowCnt % parallelism;

		if (mod == 0) {
			return div;
		} else if (taskId >= mod) {
			return div;
		} else {
			return div + 1;
		}
	}     
}

具體AllReduceSend程式碼如下,註解中有詳細說明。

// 這裡是變數名字定義。	
public static final String CENTROID_ALL_REDUCE = "centroidAllReduce";

private static class AllReduceSend<T> extends RichMapPartitionFunction <T, Tuple3 <Integer, Integer, double[]>> {
        
    	int numOfSubTasks = getRuntimeContext().getNumberOfParallelSubtasks();
		// 與並行度相關,每個task都會執行相同操作
		// bufferName都是 centroidAllReduce,每個task獲取的sendBuf都不一樣
    
        // 計算怎麼傳送所需要的資料結構
    	int pieces = pieces(sendLen);
    	DistributedInfo distributedInfo = new DefaultDistributedInfo();

        // 從上下文中獲取需要傳送的資料
		double[] sendBuf = context.getObj(bufferName);
        
			int agg = 0;
    		// 可以看出來,是把需要傳送的資料給每個task都傳送。當然這個傳送是根據傳送資料的大小來確定的,如果資料量小,可能就只給一個或者幾個task傳送。
			for (int i = 0; i < numOfSubTasks; ++i) {
                // startPos : 具體傳送變數的那一部分,是依據task index來決定的。
                // cnt : 具體哪一個下游 task i 傳送多少資料由此決定,如果是0,就不給task i傳送資料。
				int startPos = (int) distributedInfo.startPos(i, numOfSubTasks, pieces);
				int cnt = (int) distributedInfo.localRowCnt(i, numOfSubTasks, pieces);

				for (int j = 0; j < cnt; ++j) {
                    // 傳送哪一個部分
					int bufStart = (startPos + j) * TRANSFER_BUFFER_SIZE;
					// the last
					if (startPos + j == pieces - 1) {
						System.arraycopy(sendBuf, bufStart, transBuf, 0, lastLen(sendLen));
					} else {
						System.arraycopy(sendBuf, bufStart, transBuf, 0, TRANSFER_BUFFER_SIZE);
					}
					agg++;
                    
          // i 是subTasks的index,startPos + j是buffer內的位置,後續分割槽實際就是按照這個 i 來分割槽的。本AllReduceSend就是傳送到numOfSubTasks這些task中。
					out.collect(Tuple3.of(i, startPos + j, transBuf));
				}
			}
}

	private static int pieces(int len) {
		int div = len / TRANSFER_BUFFER_SIZE; // 4096
		int mod = len % TRANSFER_BUFFER_SIZE;
		return mod == 0 ? div : div + 1;
	}

sendBuf = {double[15]@10602} 
 0 = 40.3
 1 = 18.200000000000003
 2 = 33.6
 3 = 12.5
 4 = 6.0
 5 = 45.3
 6 = 30.599999999999998
 7 = 12.4
 8 = 2.0
 9 = 9.0
 10 = 24.0
 11 = 10.4
 12 = 17.1
 13 = 5.199999999999999
 14 = 4.0

this = {AllReduce$AllReduceSend@10598} 
 bufferName = "centroidAllReduce"
 lengthName = null
 transferBufferName = "3dfb2aae-683d-4497-91fc-30b8d6853bce"
 sessionId = 0
 runtimeContext = {AbstractIterativeTask$IterativeRuntimeUdfContext@10606}       

3. AllReduceBroadcastRaw

AllReduceSend傳送變數給下游時候,使用了自定義的partition(partitionCustom )。其是用 index of subtask 來作為key分割槽。這樣就和AllReduceSend那個out.collect對應了。

			.partitionCustom(new Partitioner <Integer>() {
				@Override
				public int partition(Integer key, int numPartitions) {
					return key;
				}
			}, 0)
			.name("AllReduceBroadcastRaw")
               
// 呼叫到這個partition函式的呼叫棧
                
partition:102, AllReduce$2 (com.alibaba.alink.common.comqueue.communication)
partition:99, AllReduce$2 (com.alibaba.alink.common.comqueue.communication)
customPartition:235, OutputEmitter (org.apache.flink.runtime.operators.shipping)
selectChannel:149, OutputEmitter (org.apache.flink.runtime.operators.shipping)
selectChannel:36, OutputEmitter (org.apache.flink.runtime.operators.shipping)
emit:120, RecordWriter (org.apache.flink.runtime.io.network.api.writer)
collect:65, OutputCollector (org.apache.flink.runtime.operators.shipping)
collect:35, CountingCollector (org.apache.flink.runtime.operators.util.metrics)
mapPartition:257, AllReduce$AllReduceSend (com.alibaba.alink.common.comqueue.communication)
run:103, MapPartitionDriver (org.apache.flink.runtime.operators)
run:504, BatchTask (org.apache.flink.runtime.operators)
run:157, AbstractIterativeTask (org.apache.flink.runtime.iterative.task)
run:107, IterationIntermediateTask (org.apache.flink.runtime.iterative.task)
invoke:369, BatchTask (org.apache.flink.runtime.operators)
doRun:705, Task (org.apache.flink.runtime.taskmanager)
run:530, Task (org.apache.flink.runtime.taskmanager)
run:745, Thread (java.lang)                  
                
                 
 // @AllReduceSend.mapPartition 這裡開始呼叫   
 for (int i = 0; i < numOfSubTasks; ++i) {   
     // i 是subTasks的index,後續分割槽實際就是按照這個 i 來分割槽的。本AllReduceSend就是傳送到numOfSubTasks這些task中。
	 out.collect(Tuple3.of(i, startPos + j, transBuf));     
 }
                
 // 從後續呼叫序列可以看出來,最終是用 index of subtask 來作為key分割槽。    

// 這裡傳送record

 public class CountingCollector<OUT> implements Collector<OUT> {
	public void collect(OUT record) {
		this.numRecordsOut.inc();
		this.collector.collect(record);
	}     
 }
             
 record = {Tuple3@10586} "(0,0,[40.50000000000001, 18.7, 33.300000000000004, 12.8, 6.0, 29.7, 21.0, 8.4, 1.7, 6.0, 48.1, 22.199999999999996, 36.0, 12.200000000000001, 8.0, 0.0,"
 f0 = {Integer@10583} 0
 f1 = {Integer@10583} 0
 f2 = {double[4096]@10598}                
       
// 這裡開始分割槽

public class OutputEmitter<T> implements ChannelSelector<SerializationDelegate<T>> {
	private int customPartition(T record, int numberOfChannels) {
		if (extractedKeys == null) {
			extractedKeys = new Object[1];
		}

		if (comparator.extractKeys(record, extractedKeys, 0) == 1) {
            // 所以 key 是 0
			final Object key = extractedKeys[0];
			return partitioner.partition(key, numberOfChannels);
		}            
	}    
}

public final class TupleComparator<T extends Tuple> extends TupleComparatorBase<T> {
	public int extractKeys(Object record, Object[] target, int index) {
		int localIndex = index;
		for(int i = 0; i < comparators.length; i++) {
			localIndex += comparators[i].extractKeys(((Tuple) record).getField(keyPositions[i]), target, localIndex);
		}
		return localIndex - index;
	}    
}

// 就是取出第一個field的數值

key = {Integer@10583} 0
 value = 0
    
extractedKeys = {Object[1]@10587} 
 0 = {Integer@10583} 0
  value = 0

4. AllReduceSum

所有workers在它收到的資料上做reduce,然後把這個部分reduce的結果(partial results)傳送給其他workers。

partial results是因為每個task接受的資料不同,是上游根據task index計算位置並且傳送過來的。

但是AllReduceSum的計算結果會給每一個下游 task index 傳送

private static class AllReduceSum extends RichMapPartitionFunction <Tuple3 <Integer, Integer, double[]>, Tuple3 <Integer, Integer, double[]>> {
    
    	public void mapPartition(Iterable <Tuple3 <Integer, Integer, double[]>> values,Collector <Tuple3 <Integer, Integer, double[]>> out) {
            
            // 這時候雖然也用到了context取出了sendBuf,但是隻是用來獲取其長度而已。
    		int taskId = getRuntimeContext().getIndexOfThisSubtask();
			int numOfSubTasks = getRuntimeContext().getNumberOfParallelSubtasks();

			double[] sendBuf = context.getObj(bufferName);
			int sendLen = lengthName != null ? context.getObj(lengthName) : sendBuf.length;
			int pieces = pieces(sendLen);
			DistributedInfo distributedInfo = new DefaultDistributedInfo();

            // startPos : 本task接受的資料,startPos 是應該從原始資料的哪個位置開始。是依據task index來決定的。
            // cnt : 具體哪一個下游 task i 傳送多少資料由此決定。   
			int startPos = (int) distributedInfo.startPos(taskId, numOfSubTasks, pieces);
			int cnt = (int) distributedInfo.localRowCnt(taskId, numOfSubTasks, pieces);
    
    		// 這裡進行了reduce SUM工作
			double[][] sum = new double[cnt][];
			double[] agg = new double[cnt];
			do {
				Tuple3 <Integer, Integer, double[]> val = it.next();
				int localPos = val.f1 - startPos;
				if (sum[localPos] == null) {
					sum[localPos] = val.f2;
					agg[localPos]++;
				} else {
					op.accept(sum[localPos], val.f2);
				}
			} while (it.hasNext());    
    
    		// 依然傳送給下游,依然是用subtask index來作為partition key。
            // 注意,這裡是把結果傳送給所有的下游task。
			for (int i = 0; i < numOfSubTasks; ++i) {
				for (int j = 0; j < cnt; ++j) {
          // startPos是本task傳送的資料應該從原始資料的哪個位置開始。
          // 但是給每一個 task i 發的都是同樣的資料。但是 startPos + j 很重要,下游task i 會根據這個知道它應該把接收到的資料儲存在預定義變數的什麼地方。
					out.collect(Tuple3.of(i, startPos + j, sum[j]));
				}
			}   
        }
}

sum = {double[1][]@10605} 
 0 = {double[4096]@10613} 
  0 = 118.50000000000001
  1 = 77.7
  2 = 37.2
  3 = 5.9
  4 = 25.0
  5 = 621.1000000000001
  6 = 284.7
  7 = 487.59999999999997
  8 = 166.5
  9 = 99.0
  10 = 136.9
  11 = 95.7
  12 = 39.0
  13 = 7.4
  14 = 26.0

5. AllReduceBroadcastSum

AllReduceSum 傳送變數給下游時候,使用了自定義的partition(partitionCustom )。其是用 index of subtask 來作為key分割槽。

其意義和之前的 partitionCustom 相同。

6. AllReduceRecv

All workers merge partial results into final result and put it into session context with pre-defined object name.

每一個下游 AllReduceRecv 都接收到 每一個上游 AllReduceSum 傳送過來的 cluster(求和之後的),然後把每份資料存入到自己task manager對應的預定義變數state的不同部分(這個不同部分是根據接受到的資料val.f1計算出來的)。

結合前面可知,AllReduceSend傳送和AllReduceRecv接受,都是按照同樣的套路計算在共享變數中的資料位置。這樣AllReduceRecv就可以merge了。

這樣就完成了所有workers把部分reduce sum的結果合併成為最終結果,然後放入預定義的上下文變數中。

	private static class AllReduceRecv<T> extends RichMapPartitionFunction <Tuple3 <Integer, Integer, double[]>, T> {
		private final String bufferName;
		private final String lengthName;
		private final int sessionId;

		@Override
		public void mapPartition(Iterable <Tuple3 <Integer, Integer, double[]>> values, Collector <T> out) throws Exception {
			ComContext context = new ComContext(sessionId, getIterationRuntimeContext());
			Iterator <Tuple3 <Integer, Integer, double[]>> it = values.iterator();
			if (!it.hasNext()) {
				return;
			}
			double[] recvBuf = context.getObj(bufferName);
			int recvLen = lengthName != null ? context.getObj(lengthName) : recvBuf.length;
			int pieces = pieces(recvLen); // 和之前AllReduceSend一樣的套路計算應該儲存在共享變數什麼位置。
			do {
				Tuple3 <Integer, Integer, double[]> val = it.next();
				if (val.f1 == pieces - 1) {
					System.arraycopy(val.f2, 0, recvBuf, val.f1 * TRANSFER_BUFFER_SIZE, lastLen(recvLen));
				} else {
           // 拷貝到共享變數的相應部位。val.f1 是上游傳送過來的。作為merge功能的起始位置。
					System.arraycopy(val.f2, 0, recvBuf, val.f1 * TRANSFER_BUFFER_SIZE, TRANSFER_BUFFER_SIZE);
				}
			} while (it.hasNext());
		}
	}

val = {Tuple3@10672} "(3,0,[335.3, 150.89999999999998, 277.5, 99.79999999999998, 50.0, 290.9, 136.3, 213.1, 67.8, 50.0, 250.3, 170.89999999999998, 73.2, 12.2, 50.0, 0.0....."
 f0 = {Integer@10682} 3
  value = 3
 f1 = {Integer@10638} 0
  value = 0
 f2 = {double[4096]@10674} 
  0 = 335.3
  1 = 150.89999999999998
  2 = 277.5
  3 = 99.79999999999998
  4 = 50.0
  5 = 290.9
  6 = 136.3
  7 = 213.1
  8 = 67.8
  9 = 50.0
  10 = 250.3
  11 = 170.89999999999998
  12 = 73.2
  13 = 12.2
  14 = 50.0
  15 = 0.0
  ......
      
// 每個task都收到了reduce sum結果。      
recvBuf = {double[15]@10666} 
 0 = 404.3
 1 = 183.1
 2 = 329.3
 3 = 117.2
 4 = 61.0
 5 = 250.3
 6 = 170.89999999999998
 7 = 73.20000000000002
 8 = 12.2
 9 = 50.0
 10 = 221.89999999999998
 11 = 104.1
 12 = 161.29999999999998
 13 = 50.4
 14 = 39.0      
      

7. KMeansUpdateCentroids

基於點計數和座標,計算新的聚類中心。這裡就是從task manager中取出了AllReduce儲存的共享變數CENTROID_ALL_REDUCE。

/**
 * Update the centroids based on the sum of points and point number belonging to the same cluster.
 */
public class KMeansUpdateCentroids extends ComputeFunction {
    public void calc(ComContext context) {

        Integer vectorSize = context.getObj(KMeansTrainBatchOp.VECTOR_SIZE);
        Integer k = context.getObj(KMeansTrainBatchOp.K);

        // 這裡取出AllReduce儲存的共享變數
        double[] sumMatrixData = context.getObj(KMeansTrainBatchOp.CENTROID_ALL_REDUCE);

        Tuple2<Integer, FastDistanceMatrixData> stepNumCentroids;
        if (context.getStepNo() % 2 == 0) {
            stepNumCentroids = context.getObj(KMeansTrainBatchOp.CENTROID2);
        } else {
            stepNumCentroids = context.getObj(KMeansTrainBatchOp.CENTROID1);
        }

        stepNumCentroids.f0 = context.getStepNo();

        context.putObj(KMeansTrainBatchOp.K,
            updateCentroids(stepNumCentroids.f1, k, vectorSize, sumMatrixData, distance));
    }
}

0x06 參考

我的平行計算之路(四)MPI集合通訊之Reduce和Allreduce

Message Passing Interface(MPI)

Flink 之 Dataflow、Task、subTask、Operator Chains、Slot 介紹

Flink執行時之TaskManager執行Task

相關文章