[原始碼解析] PyTorch 分散式 Autograd (1) ---- 設計
0x00 摘要
本文以幾篇PyTorch官方文件為基礎來了解分散式 autograd 的設計和內部結構,在翻譯時並沒有逐字翻譯,其中加入了自己的部分理解。分散式 autograd 後續文章的分析也會基於本文進行。
PyTorch分散式其他文章如下:
[原始碼解析]深度學習利器之自動微分(3) --- 示例解讀
[原始碼解析]PyTorch如何實現前向傳播(1) --- 基礎類(上)
[原始碼解析]PyTorch如何實現前向傳播(2) --- 基礎類(下)
[原始碼解析] PyTorch如何實現前向傳播(3) --- 具體實現
[原始碼解析] Pytorch 如何實現後向傳播 (1)---- 呼叫引擎
[原始碼解析] Pytorch 如何實現後向傳播 (2)---- 引擎靜態結構
[原始碼解析] Pytorch 如何實現後向傳播 (3)---- 引擎動態邏輯
[原始碼解析] PyTorch 如何實現後向傳播 (4)---- 具體演算法
[原始碼解析] PyTorch 分散式(1)------歷史和概述
[原始碼解析] PyTorch 分散式(2) ----- DataParallel(上)
[原始碼解析] PyTorch 分散式(3) ----- DataParallel(下)
[原始碼解析] PyTorch 分散式(4)------分散式應用基礎概念
[原始碼解析] PyTorch分散式(5) ------ DistributedDataParallel 總述&如何使用
[原始碼解析] PyTorch分散式(6) ---DistributedDataParallel -- 初始化&store
[原始碼解析] PyTorch 分散式(7) ----- DistributedDataParallel 之程式組
[原始碼解析] PyTorch 分散式(8) -------- DistributedDataParallel之論文篇
[原始碼解析] PyTorch 分散式(9) ----- DistributedDataParallel 之初始化
[原始碼解析] PyTorch 分散式(10)------DistributedDataParallel 之 Reducer靜態架構
[原始碼解析] PyTorch 分散式(11) ----- DistributedDataParallel 之 構建Reducer和Join操作
[原始碼解析] PyTorch 分散式(12) ----- DistributedDataParallel 之 前向傳播
[原始碼解析] PyTorch 分散式(13) ----- DistributedDataParallel 之 反向傳播
0x01 分散式RPC框架
本文主要以 https://pytorch.org/docs/master/rpc/distributed_autograd.html 為基準,但是原文件要求使用者熟悉 Autograd 機制和分散式 RPC 框架,因為我們已經分析過 Autograd 機制,所以我們先研究一下 分散式 RPC 框架。
1.1 RPC 框架
RPC(Remote Procedure Call)是一種設計或者技術思想,而不是協議或者規範。
對於 RPC 最簡單的理解就是一個節點請求另外一個節點所提供的服務,但是對於使用者程式碼來說需要維護一個"本地呼叫"的感覺,即,對於遠端函式呼叫需要像呼叫本地的函式一樣,遠端服務或者程式碼看起來像執行在本地。
RPC 需要解決幾個問題:
- 如何通訊:即如何在呼叫者和服務提供者之間建立連線。
- 如何定址:即呼叫者如何找到服務提供者,怎麼知道其中有什麼服務。
- 如何傳送引數:呼叫者發起遠端呼叫時候,方法的引數需要通過 TCP 等協議傳輸到伺服器,引數如何序列化?
- 如何接受引數:服務提供者收到引數之後如何反序列化,如何呼叫。
- 如何返回:服務提供者呼叫本地提供的服務之後,如何把返回值傳送給呼叫者。
1.2 PyTorch RPC 四大支柱
以下翻譯自官方文件 https://pytorch.org/docs/master/rpc.html。
分散式 RPC 框架通過一組原語提供了多機模型訓練機制以允許遠端通訊,以及一個更高階別的 API 來自動區分拆分到多臺機器上的模型。分散式 RPC 框架使遠端執行函式變得容易,支援引用遠端物件而無需複製真實資料,並提供 autograd 和優化器 API 以透明地向後執行和跨 RPC 邊界更新引數。這些功能可以分為四組 API。
- **遠端過程呼叫 (RPC) ** 支援使用給定的引數在指定的worker上執行函式並獲取返回值或建立對返回值的引用。有三個主要的 RPC API:
rpc_sync()
(同步)、rpc_async()
(非同步)和remote()
(非同步並返回對遠端返回值的引用)。如果使用者程式碼在沒有返回值的情況下無法繼續,請使用同步 API。否則,使用非同步 API 獲取 Future,並在呼叫者需要返回值時等待 Future。remote()
API 在需要遠端建立某些內容但從不需要將其獲取給呼叫者時很有用。想象一下driver程式設定引數伺服器和訓練器的情況。Driver 可以在引數伺服器上建立嵌入表,然後與訓練器共享嵌入表的引用,但其本身永遠不會在本地使用嵌入表。在這種情況下,rpc_sync()
和rpc_async()
已不再適用,因為他們總是意味著立即或在將來把返回值發給呼叫者。 - 遠端引用 (RRef)用作指向本地或遠端物件的分散式共享指標。它可以與其他 worker 共享,並且引用計數將被透明處理。每個 RRef 只有一個所有者,並且物件只存在於該所有者之中。持有 RRef 的非所有者worker 可以通過明確請求從所有者那裡獲取物件的副本。當 worker 需要訪問某個資料物件,但它本身既不是物件的建立者
remote()
函式的呼叫者也不是物件的所有者時,這很有用。分散式優化器就是此類用例的一個示例。 - Distributed Autograd將所有參與前向傳播 worker的本地 autograd 引擎縫合在一起,並在後向傳播期間自動聯絡他們以計算梯度。在進行前向傳遞如果需要跨越多臺機器時,這尤其有用,例如分散式模型並行訓練、引數伺服器訓練等。 有了這個特性,使用者程式碼不再需要擔心如何跨 RPC 邊界傳送梯度和應該以什麼順序啟動本地 autograd 引擎,如果前向傳遞中有巢狀和相互依賴的 RPC 呼叫,這可能會變得非常複雜。
- 分佈優化器的構造需要一個
Optimizer()
(例如,SGD()
,Adagrad()
等)和一個RRefs的引數列表。即,在每個不同的Ref所有者之上建立一個Optimizer()
例項,然後執行step()
相應更新引數。當使用者進行分散式前向和後向傳播時,引數和梯度將分散在多個 worker 中,因此需要對每個相關 worker 進行優化。Distributed Optimizer 將所有這些本地優化器合而為一,並提供了簡潔的建構函式和step()
API。
1.3 RRef
下面我們以 https://pytorch.org/docs/master/rpc/rref.html 為基準來學習遠端引用協議的基本概念和部分設計細節。
RRef 是遠端參考(Remote REFerence)的縮寫。 它是位於本地或遠端工作worker上物件的引用,並且透明地在內部進行引用計數。 從概念上講,它可以被視為一個分散式共享指標。 應用程式可以呼叫 remote()
建立 一個RRef。 每個 RRef 都被 remote()
的呼叫者(即所有者)所擁有,並且可以由多個使用者使用。 所有者儲存實際資料,並跟蹤全域性參考計數。 每個 RRef 可以由全域性RRefId
唯一標識,該全域性RRefId
在建立時由 remote()
呼叫者分配。
在所有者worker中,只有一個OwnerRRef
例項包含真實資料,而在使用者worker之中,可以根據需要包含任意數量的UserRRefs
,UserRRef
不儲存資料。當使用 RRP 時,所有者將使用全域性唯一的RRefId來獲取唯一的OwnerRRef例項。 在 rpc_sync()
, rpc_async()
或 remote()
呼叫中,所有者建立一個UserRRef
,並將其用作引數或返回值。所有者將被通知並且相應更新參考計數。 如果全域性沒有UserRRef
例項,並且所有者上也沒有對OwnerRRef
的引用,則OwnerRRef
及其資料將被刪除。
1.3.1 假設條件
RRef 協議的設計基於以下假設。
- 瞬態網路故障(Transient Network Failures):RRef 設計旨在通過重試訊息來應對瞬態網路故障。 RRef不能處理節點崩潰或永久性網路分割槽,當這些事件發生時,應用程式應該關閉所有worker,還原到先前的checkpoint,然後恢復訓練。
- 非冪等 UDF (Non-idempotent UDFs):我們假設提供給
rpc_sync()
,rpc_async()
或remote()
的使用者函式(UDF)不是冪等的,因此無法重試。 但是,內部 RRef 控制訊息是冪等且訊息失敗時可重試。 - 訊息傳遞無序(Out of Order Message Delivery):我們不會對一對節點之間的訊息傳遞順序做假設,因為傳送者和接收者都使用多個執行緒,所以無法保證首先處理哪個訊息。
接下來我們只是大致講解如何使用,具體大家可以參閱 https://pytorch.org/docs/master/rpc.html#distributed-rpc-framework。
1.3.2 同步呼叫
如下是同步呼叫API,該方法在 worker to
之上執行一個阻塞 RPC 呼叫來執行func
。RPC 訊息的傳送和接收與 Python 程式碼的執行並行。此方法是執行緒安全的。
torch.distributed.rpc.rpc_sync( to , func , args = None , kwargs = None , timeout = - 1.0 )
具體引數如下:
- to – 目標worker的name/rank/WorkerInfo。
- func (callable) – 一個可呼叫函式,例如 Python callables、內建運算子(例如add())和帶註釋的 TorchScript 函式。
- args –
func
呼叫的引數元組。 - kwargs –
func
呼叫關鍵字引數的字典。 - timeout – 用於此 RPC 的超時時間(以秒為單位)
返回值就是使用args
and kwargs
執行 func
的結果。
樣例:
確保 MASTER_ADDR
and MASTER_PORT
已經在兩個worker之上設定。
export MASTER_ADDR=localhost
export MASTER_PORT=5678
然後在兩個不同的程式中執行以下程式碼
>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
>>> rpc.shutdown()
>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
1.3.2 非同步呼叫
如下是非同步呼叫API,該方法在 worker to
之上執行一個非阻塞 RPC 呼叫來執行func
。RPC 訊息的傳送和接收與 Python 程式碼的執行並行。此方法是執行緒安全的。該方法立刻返回一個可以被等待的Future
。
torch.distributed.rpc.rpc_async(to, func, args=None, kwargs=None, timeout=- 1.0)
具體引數如下:
- to – 目標worker的name/rank/
WorkerInfo
。 - func (callable) – 一個可呼叫函式,例如 Python callables、內建運算子(例如add())和帶註釋的 TorchScript 函式。
- args –
func
呼叫的引數元組。 - kwargs – 是
func
呼叫關鍵字引數的字典。 - timeout – 用於此 RPC 的超時時間(以秒為單位)
返回一個可等待的Future
物件。完成後,可以從 物件中檢索出func
的返回值。
樣例:
確保 MASTER_ADDR
and MASTER_PORT
已經在兩個worker之上設定。
>>> export MASTER_ADDR=localhost
>>> export MASTER_PORT=5678
然後在兩個不同的程式中執行以下程式碼
>>> # On worker 0:
>>> import torch
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
>>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
>>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
>>> result = fut1.wait() + fut2.wait()
>>> rpc.shutdown()
>>> # On worker 1:
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
0x02 示例
我們接下來以 https://pytorch.org/docs/master/rpc/distributed_autograd.html 為基礎進行學習。
假設您有兩個節點和一個跨兩個節點分割槽的非常簡單的模型。這可以使用torch.distributed.rpc
如下實現。
分散式 autograd 背後的主要動機是在這種分散式模型上執行反向傳播loss
,我們已經計算並記錄了所有需要梯度的張量的梯度。
import torch
import torch.distributed.rpc as rpc
def my_add(t1, t2):
return torch.add(t1, t2)
# On worker 0:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)
# Compute some loss.
loss = t5.sum()
0x03 前向傳播期間的 Autograd 記錄
PyTorch 在前向傳播期間構建 autograd 圖,該圖用於執行後向傳播。有關更多詳細資訊,請參閱 autograd 如何編碼歷史記錄。
對於分散式 autograd,我們需要在前向傳播期間跟蹤所有 RPC,以確保正確執行後向傳播。為此,當執行 RPC 時候,我們把 send
和recv
functions 附加到autograd圖之上。
- 該
send
函式附加到 RPC 的發起源節點之上,其輸出邊指向 RPC 輸入張量的 autograd 函式。在向後傳播期間,send
函式的輸入是從目標接收的,是對應recv
函式的輸出。 - 該
recv
函式附加到 RPC 的接受目標節點之上,其輸入從某些運算子得到,這些運算子使用輸入張量在RPC接受目標上執行。在後向傳播期間,recv
函式的輸出梯度將被髮送到源節點之上,並且作為send
方法的輸入。 - 每
send-recv
對被分配一個全域性唯一的autograd_message_id
以唯一地標識該send-recv
對。這對於在向後傳播期間查詢遠端節點上的相應函式很有用。 - 對於RRef,每當我們呼叫
torch.distributed.rpc.RRef.to_here()
時,我們都為涉及的張量新增了一個適當的send-recv
對。
例如,這就是我們上面示例的 autograd 圖的樣子(為簡單起見,t5.sum() 被排除在外)。
我們可以看到,send方法在前向傳播中是傳送者,但是在反向傳播之中就是接受者。
0x04 分散式 Autograd 上下文
每個使用分散式 autograd 的前向和後向傳播都被分配了一個唯一的torch.distributed.autograd.context
,並且這個上下文具有一個全域性唯一的autograd_context_id
。如果有需要,在每個節點上都會建立上下文。
上下文的作用如下:
- 執行分散式反向傳播的多個節點可能會在同一個張量上累積梯度並且儲存在張量的
.grad
之上。在我們執行優化器之前,張量的.grad
可能累積了來自各種分散式反向傳播的梯度。這類似於把torch.autograd.backward()
在本地進行多次呼叫。為了提供一種把每個反向傳播梯度分離開的方法,在每個反向傳播過程裡,梯度將被累積在torch.distributed.autograd.context
之中。 - 在前向傳播期間,我們在上下文中儲存每個 autograd 傳播的
send
和recv
函式。這確保我們在 autograd 圖中儲存對適當節點的引用以使其保持活動狀態。除此之外,這也使得在向後傳播期間很容易查詢到對應的send
和recv
函式。 - 一般來說,我們也使用這個上下文來儲存每個分散式 autograd 傳播的一些後設資料。
從使用者的角度來看,autograd 上下文設定如下:
import torch.distributed.autograd as dist_autograd
with dist_autograd.context() as context_id:
loss = model.forward()
dist_autograd.backward(context_id, loss)
需要注意的是,模型的前向傳播必須在分散式autograd上下文管理器中呼叫,因為需要一個有效的上下文來確保:所有的send
和recv
方法被儲存起來,並且在所有參與節點之上執行後向傳播。
0x05 分散式反向傳播
在本節中,我們將概述在分散式反向傳播期間準確計算依賴關係所遇到的挑戰,並且也講述幾種如何執行分散式反向傳播的演算法(演算法內部有權衡)。
5.1 計算依賴關係
首先,考慮在單臺機器上執行以下程式碼
import torch
a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)
d = a + b
e = b * c
d.sum.().backward()
下圖就是上面程式碼對應的 autograd 圖。
作為反向傳播的一部分,autograd 引擎執行的第一步是計算 autograd 圖中每個節點的依賴項數量。這有助於 autograd 引擎知道圖中的節點何時準備好了可以執行。括號內為數字add(1)
和mul(0)
表示依賴關係的數量。如您所見,這意味著在向後傳播期間,add
節點需要 1 個輸入,mul
節點不需要任何輸入(換句話說,不需要執行)。本地 autograd 引擎通過從根節點(在本例中是d
)遍歷圖來計算這些依賴關係。
實際上,Autograd 圖中的某些節點可能不會在向後傳播中執行。這一事實對分散式 autograd 提出了挑戰。考慮這段使用 RPC 的程式碼。
import torch
import torch.distributed.rpc as rpc
a = torch.rand((3, 3), requires_grad=True)
b = torch.rand((3, 3), requires_grad=True)
c = torch.rand((3, 3), requires_grad=True)
d = rpc.rpc_sync("worker1", torch.add, args=(a, b))
e = rpc.rpc_sync("worker1", torch.mul, args=(b, c))
loss = d.sum()
上面程式碼的關聯 autograd 圖將是:
計算此分散式 autograd 圖的依賴項更具挑戰性,並且需要一些開銷(在計算或網路通訊方面)。
對於效能敏感的應用,我們可以通過假設每個send
和recv
函式都是反向傳播的有效成分來避免大量開銷(大多數應用不會執行未使用的 RPC)。這簡化了分散式 autograd 演算法並且效率更高,但代價是應用程式需要了解這些限制。這種演算法稱為FAST模式演算法,下面詳細介紹。
在一般情況下, 作為向後傳播的一部分,可能不需要每個send
和recv
函式都是有效的。為了解決這個問題,我們提出了一種SMART 模式演算法,此演算法將在後面的部分中描述。請注意,目前僅實現了FAST模式演算法。
5.2 FAST模式演算法
該演算法的關鍵假設是:當我們執行反向傳播時,每個send
函式的依賴為 1。換句話說,我們假設我們會從另一個節點通過 RPC 接收梯度。
演算法如下:
- 我們從具有反向傳播根的worker開始(所有根都必須是本地的)。
- 查詢當前Distributed Autograd Context 的所有
send
函式 。 - 從提供的根和我們檢索到的所有
send
函式開始,我們在本地計算依賴項 。 - 計算依賴項後,使用提供的根來啟動本地 autograd 引擎。
- 當 autograd 引擎執行該
recv
函式時,該recv
函式通過 RPC 將輸入梯度傳送到適當的worker。每個recv
函式都知道目標 worker id,因為它被記錄為前向傳播的一部分。通過autograd_context_id
和autograd_message_id
該recv
函式被髮送到遠端主機。 - 當遠端主機收到這個請求時,我們使用
autograd_context_id
和autograd_message_id
來查詢適當的send
函式。 - 如果這是worker第一次收到對給定
autograd_context_id
的請求,它將按照上面的第 1-3 點所述在本地計算依賴項。 - 然後將在第6點接受到的
send
方法插入佇列,以便在該worker的本地 autograd 引擎上執行。 - 最後,我們不是在 Tensor的
.grad
之上累積梯度,而是在每個Distributed Autograd Context之上分別累積梯度 。梯度儲存在Dict[Tensor, Tensor]
之中 ,Dict[Tensor, Tensor]
基本上是從 Tensor 到其關聯梯度的對映,並且可以使用 get_gradients() API檢索該對映 。
例如,分散式 autograd 的完整程式碼如下:
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
def my_add(t1, t2):
return torch.add(t1, t2)
# On worker 0:
# Setup the autograd context. Computations that take
# part in the distributed backward pass must be within
# the distributed autograd context manager.
with dist_autograd.context() as context_id:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)
# Compute some loss.
loss = t5.sum()
# Run the backward pass.
dist_autograd.backward(context_id, [loss])
# Retrieve the gradients from the context.
dist_autograd.get_gradients(context_id)
具有依賴關係的分散式 autograd 圖如下(為簡單起見,t5.sum() 被排除在外):
應用於上述示例的FAST 模式演算法如下:
- 在
Worker 0
上,我們從根loss
和send1
開始計算依賴關係。 結果,send1
對Worker 0
的依賴數為 1,mul
對Worker 0
的依賴數為 1。 - 現在,我們在
Worker 0
上啟動本地 autograd 引擎。 我們首先執行mul
函式,將其輸出作為t4
的梯度,累積儲存在 autograd 上下文中。 然後,我們執行recv2
,它將這些梯度傳送到Worker 1
。 - 由於這是
Worker 1
第一次知道有關此反向傳播的資訊,因此它將進行依賴關係計算,並且相應地標記send2
,add
和recv1
的依賴性。 - 接下來,在
Worker 1
的本地autograd
引擎上將send2
插入佇列,該引擎將依次執行add
和recv1
。 - 當執行
recv1
時,它將梯度傳送到Worker 0
。 - 由於
Worker 0
已經計算了此向後傳播的依賴性,因此它僅僅在本地將send1
插入佇列並且執行。 - 最後,
t1
,t2
和t4
的梯度會累積在分散式 Autograd 上下文中。
5.3 SMART模式演算法
該演算法的全部細節仍在研究中,但對於總體思路,您可以參考RFC中的分散式 Autograd 演算法智慧模式部分 。
0x06 分散式優化器
該DistributedOptimizer
操作如下:
- 獲取要優化的遠端引數(
RRef
)列表。這些引數也可以是包含在本地RRef
的本地引數。 - 將一個
Optimizer
類作為本地優化器,該優化器將在所有不同的RRef
擁有者之上執行。 - 分散式優化器在每個工作節點上建立一個本地
Optimizer
例項,並且對於每一個Optimizer
儲存一個RRef
。 - 當呼叫
torch.distributed.optim.DistributedOptimizer.step()
時,分散式優化器使用 RPC 在適當的遠端工作者上遠端執行所有本地優化器。必須為torch.distributed.optim.DistributedOptimizer.step()
提供一個分散式autogradcontext_id
。 本地優化器使用context_id
在相應上下文中儲存梯度。 - 如果多個併發分散式優化器正在更新一個 worker 上的同一批引數,這些更新將通過鎖來進行序列操作。
0x07 簡單的端到端示例
綜上所述,以下是一個使用分散式 autograd 和分散式優化器的簡單端到端示例。如果將程式碼放入名為“dist_autograd_simple.py”的檔案中,則可以使用以下命令執行 :MASTER_ADDR="localhost" MASTER_PORT=29500 python dist_autograd_simple.py
import torch
import torch.multiprocessing as mp
import torch.distributed.autograd as dist_autograd
from torch.distributed import rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer
def random_tensor():
return torch.rand((3, 3), requires_grad=True)
def _run_process(rank, dst_rank, world_size):
name = "worker{}".format(rank)
dst_name = "worker{}".format(dst_rank)
# Initialize RPC.
rpc.init_rpc(
name=name,
rank=rank,
world_size=world_size
)
# Use a distributed autograd context.
with dist_autograd.context() as context_id:
# Forward pass (create references on remote nodes).
rref1 = rpc.remote(dst_name, random_tensor)
rref2 = rpc.remote(dst_name, random_tensor)
loss = rref1.to_here() + rref2.to_here()
# Backward pass (run distributed autograd).
dist_autograd.backward(context_id, [loss.sum()])
# Build DistributedOptimizer.
dist_optim = DistributedOptimizer(
optim.SGD,
[rref1, rref2],
lr=0.05,
)
# Run the distributed optimizer step.
dist_optim.step(context_id)
def run_process(rank, world_size):
dst_rank = (rank + 1) % world_size
_run_process(rank, dst_rank, world_size)
rpc.shutdown()
if __name__ == '__main__':
# Run world_size workers
world_size = 2
mp.spawn(run_process, args=(world_size,), nprocs=world_size)
0xFF 參考
https://pytorch.org/docs/master/rpc/distributed_autograd.html#distributed-autograd-design
https://pytorch.org/docs/master/rpc.html#distributed-autograd-framework
https://pytorch.org/docs/master/rpc/rref.html
https://pytorch.org/docs/master/rpc.html#distributed-rpc-framework