[原始碼解析] PyTorch 分散式 Autograd (4) ---- 如何切入引擎
0x00 摘要
上文我們看到了AutogradMetadata,DistAutogradContainer 和 DistAutogradContext 等一系列基礎類。我們知道了分散式autograd如何基於RPC進行傳遞,如何在節點之間互動,節點如何區分維護這些Session。本文繼續分析,主要目的是看看反向傳播如何切入到引擎之中。
PyTorch分散式其他文章如下:
[原始碼解析]深度學習利器之自動微分(3) --- 示例解讀
[原始碼解析]PyTorch如何實現前向傳播(1) --- 基礎類(上)
[原始碼解析]PyTorch如何實現前向傳播(2) --- 基礎類(下)
[原始碼解析] PyTorch如何實現前向傳播(3) --- 具體實現
[原始碼解析] Pytorch 如何實現後向傳播 (1)---- 呼叫引擎
[原始碼解析] Pytorch 如何實現後向傳播 (2)---- 引擎靜態結構
[原始碼解析] Pytorch 如何實現後向傳播 (3)---- 引擎動態邏輯
[原始碼解析] PyTorch 如何實現後向傳播 (4)---- 具體演算法
[原始碼解析] PyTorch 分散式(1)------歷史和概述
[原始碼解析] PyTorch 分散式(2) ----- DataParallel(上)
[原始碼解析] PyTorch 分散式(3) ----- DataParallel(下)
[原始碼解析] PyTorch 分散式(4)------分散式應用基礎概念
[原始碼解析] PyTorch分散式(5) ------ DistributedDataParallel 總述&如何使用
[原始碼解析] PyTorch分散式(6) ---DistributedDataParallel -- 初始化&store
[原始碼解析] PyTorch 分散式(7) ----- DistributedDataParallel 之程式組
[原始碼解析] PyTorch 分散式(8) -------- DistributedDataParallel之論文篇
[原始碼解析] PyTorch 分散式(9) ----- DistributedDataParallel 之初始化
[原始碼解析] PyTorch 分散式(10)------DistributedDataParallel 之 Reducer靜態架構
[原始碼解析] PyTorch 分散式(11) ----- DistributedDataParallel 之 構建Reducer和Join操作
[原始碼解析] PyTorch 分散式(12) ----- DistributedDataParallel 之 前向傳播
[原始碼解析] PyTorch 分散式(13) ----- DistributedDataParallel 之 反向傳播
[原始碼解析] PyTorch 分散式 Autograd (1) ---- 設計
[原始碼解析] PyTorch 分散式 Autograd (2) ---- RPC基礎
[原始碼解析] PyTorch 分散式 Autograd (3) ---- 上下文相關
為了更好的說明,本文程式碼會依據具體情況來進行相應精簡。
0x01 前文回憶
我們回憶一下前面幾篇文章的內容。
首先,對於分散式 autograd,我們需要在前向傳播期間跟蹤所有 RPC,以確保正確執行後向傳播。為此,當執行 RPC 時候,我們把 send
和recv
functions 附加到autograd圖之上。
- 該
send
函式附加到 RPC 的發起源節點之上,其輸出邊指向 RPC 輸入張量的 autograd 函式。在向後傳播期間,send
函式的輸入是從目標接收的,是對應recv
函式的輸出。 - 該
recv
函式附加到 RPC 的接受目標節點之上,其輸入從某些運算子得到,這些運算子使用輸入張量在RPC接受目標上執行。在後向傳播期間,recv
函式的輸出梯度將被髮送到源節點之上,並且作為send
方法的輸入。 - 每
send-recv
對被分配一個全域性唯一的autograd_message_id
以唯一地標識該send-recv
對。這對於在向後傳播期間查詢遠端節點上的相應函式很有用。 - 對於RRef,每當我們呼叫
torch.distributed.rpc.RRef.to_here()
時,我們都為涉及的張量新增了一個適當的send-recv
對。
其次,在前向傳播的具體程式碼之中,我們在上下文中儲存每個 autograd 傳播的send
和recv
函式。這確保我們在 autograd 圖中儲存對適當節點的引用以使其保持活動狀態。除此之外,這也使得在向後傳播期間很容易查詢到對應的send
和recv
函式。
再次,以下是 torch/csrc/distributed/rpc/message.h 之中的部分訊息定義:
// Messages with autograd info
FORWARD_AUTOGRAD_REQ = 0x0f | MessageTypeFlags::REQUEST_TYPE,
FORWARD_AUTOGRAD_RESP = 0x10 | MessageTypeFlags::RESPONSE_TYPE,
// Messages to propagate gradients on the backward pass.
BACKWARD_AUTOGRAD_REQ = 0x11 | MessageTypeFlags::REQUEST_TYPE,
BACKWARD_AUTOGRAD_RESP = 0x12 | MessageTypeFlags::RESPONSE_TYPE,
在前文,我們看到了 FORWARD_AUTOGRAD_REQ 在前向傳播之中如何呼叫,假設如下程式碼:rpc.rpc_sync("worker1", torch.add, args=(t1, t2)),其呼叫序列是:
- rpc_sync 呼叫 _invoke_rpc。
- _invoke_rpc 呼叫 _invoke_rpc_builtin。
- 然後呼叫到 pyRpcBuiltin,繼而呼叫到 sendMessageWithAutograd。
- sendMessageWithAutograd 內部會構建 FORWARD_AUTOGRAD_REQ訊息,最後使用RPC 傳送。
至此,關於整體流程,我們就有了幾個疑問:
- 在反向計算圖的起始位置,如何發起反向傳播,怎麼傳遞給反向傳播的下一個環節?
- 在反向傳播的內部環節,BACKWARD_AUTOGRAD_REQ 是何時呼叫?recv 操作是何時被呼叫? 在上下文中,recvAutogradFunctions_ 是在哪裡設定的?
- 以上兩個環節分別如何進入分散式autograd引擎?
我們接下來就圍繞這些疑問進行分析,核心就是如何進入 dist.autograd 引擎。
0x02 計算圖
我們首先從計算圖來通過幾個示例來看看。
2.1 普通示例
首先看看普通計算,這個是 dist.auto 官方圖例的本地版本。可以看到是由 AddBackward0,AccumulateGrad 和 MulBackward0 等組成了計算圖。
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
t3 = t1 + t2
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)
next_functions = t5.grad_fn.next_functions
具體對應如下圖:
2.2 分散式示例
接下來看看分散式的例子,這個例子就是官方設計中圖例大致對應的程式碼,我們把 torch.mul(t3, t4) 命名為 t5,加入了 loss。
def worker0():
# On worker 0:
# Setup the autograd context. Computations that take
# part in the distributed backward pass must be within
# the distributed autograd context manager.
with dist_autograd.context() as context_id:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", torch.add, args=(t1, t2))
# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)
# Compute some loss.
loss = t5.sum()
# Run the backward pass.
dist_autograd.backward(context_id, [loss])
# Retrieve the gradients from the context.
dist_autograd.get_gradients(context_id)
print(loss)
在分散式之下,t3 是異地執行。
- t5 對應的是 mul,t5.grad_fn 是 <MulBackward0 object at 0x7fbf18d297b8>。
- t3.grad_fn 是 <CppFunction object at 0x7fbf18d11a20>,就是說,recv 對應的就是 CppFunction 。
- loss 是 tensor(5.5680, grad_fn=
)。 - 其餘的都是 None。
我們把設計圖例再展示出來,上面示例程式碼就是下圖的左側 worker 0,t3 實際就是執行在 worker 1,大家可以看到分散式上下文中的一些特點。
2.3 分散式註釋版
為了更好的說明,我們列印了一些log作為註釋。
def _verify_send(send_function):
print(send_function.name())
next_funcs = send_function.next_functions
print(next_funcs[0][0].name())
print(next_funcs[1][0].name())
def _verify_recv(recv_function):
print(recv_function.name())
next_funcs = recv_function.next_functions
print(len(next_funcs))
def worker0():
# On worker 0:
# Setup the autograd context. Computations that take
# part in the distributed backward pass must be within
# the distributed autograd context manager.
with dist_autograd.context() as context_id:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
# Perform some computation remotely.
#t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
t3 = rpc.rpc_sync("worker1", torch.add, args=(t1, t2))
# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)
# Compute some loss.
loss = t5.sum()
print("--- send ---")
ctx = dist_autograd._retrieve_context(context_id)
send_functions = ctx._send_functions()
_verify_send(list(send_functions.values())[0])
print("--- loss ---")
print(loss)
mul_func = loss.grad_fn.next_functions[0][0]
print(mul_func.name())
next_funcs = mul_func.next_functions
print(next_funcs[0][0].name())
print(next_funcs[1][0].name())
print("---- recv ----")
recv_functions = ctx._recv_functions()
_verify_recv(list(recv_functions.values())[0])
# Run the backward pass.
dist_autograd.backward(context_id, [loss])
# Retrieve the gradients from the context.
dist_autograd.get_gradients(context_id)
列印結果是:
--- send ---
torch::distributed::autograd::SendRpcBackward
torch::autograd::AccumulateGrad
torch::autograd::AccumulateGrad
--- loss ---
tensor(3.5197, grad_fn=<SumBackward0>)
MulBackward0
torch::distributed::autograd::RecvRpcBackward
torch::autograd::AccumulateGrad
---- recv ----
torch::distributed::autograd::RecvRpcBackward
加上分散式相關運算元之後,圖例如下:
0x03 反向傳播
我們接下來要看看如何進入dist autograd 引擎,結合我們圖例,就是:
- worker 0 如何主動發起反向傳播,然後進入分散式引擎?
- woker 0 在內部如何發起對 worker 1 的反向傳播請求?
- worker 1 如何被動接受反向傳播訊息,然後進入分散式引擎?
3.1 發起反向傳播
我們找一找如何發起反向傳播,按照從下往上的順序進行。這裡也有兩種:
- 一種是主動發起,比如上圖之中 worker 0 的 loss 之上主動呼叫backward 方法。
- 一種是內部隱式發起,比如上圖的 worker 0 之中的 t3 如何通過 recv 告訴 worker 1,你應該啟動反向傳播了。
3.1.1 外部主動發起
3.1.1.1 示例
我們從上往下看分散式 autograd 的 backward 如何主動呼叫,比如在示例之中會顯示呼叫。
def worker0():
# On worker 0:
with dist_autograd.context() as context_id:
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
# Perform some computation remotely.
t3 = rpc.rpc_sync("worker1", torch.add, args=(t1, t2))
# Perform some computation locally based on remote result.
t4 = torch.rand((3, 3), requires_grad=True)
t5 = torch.mul(t3, t4)
# Compute some loss.
loss = t5.sum()
# Run the backward pass.
dist_autograd.backward(context_id, [loss]) // 這裡會呼叫
3.1.1.2 C++世界
在 torch/_C/_distributed_autograd.pyi
之中我們可以看到如下注釋:
# This module is defined in torch/csrc/distributed/autograd/init.cpp
因此我們去torch/csrc/distributed/autograd/init.cpp檔案中看看。
省略了部分程式碼,這裡能看到生成了上下文,定義了 backward,get_gradients等等。
PyObject* dist_autograd_init(PyObject* _unused, PyObject* noargs) {
auto autograd_module =
THPObjectPtr(PyImport_ImportModule("torch.distributed.autograd"));
auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
auto m = torch_C_m.def_submodule("_distributed_autograd", "distributed autograd bindings");
auto module = py::handle(m).cast<py::module>();
auto distAutogradContext =
shared_ptr_class_<DistAutogradContext>(module, "DistAutogradContext")
.def(
"_context_id",
&DistAutogradContext::contextId,
py::call_guard<py::gil_scoped_release>())
.def(
"_recv_functions",
[](const DistAutogradContext& ctx) {
std::map<int64_t, py::object> funcs;
for (const auto& map_entry : ctx.recvFunctions()) {
funcs.emplace(
map_entry.first,
py::reinterpret_steal<py::object>(
torch::autograd::functionToPyObject(
map_entry.second)));
}
return funcs;
})
.def(
"_send_functions",
[](const ContextPtr& ctx) {
std::map<int64_t, py::object> funcs;
for (const auto& map_entry : ctx->sendFunctions()) {
funcs.emplace(
map_entry.first,
py::reinterpret_steal<py::object>(
torch::autograd::functionToPyObject(
map_entry.second)));
}
return funcs;
})
.def("_known_worker_ids", &DistAutogradContext::getKnownWorkerIds);
module.def(
"_new_context",
[]() -> const ContextPtr {
return DistAutogradContainer::getInstance().newContext();
},
py::return_value_policy::reference);
py::options options;
options.disable_function_signatures();
module.def(
"backward",
backward,
py::arg("contextId"),
py::arg("roots"),
py::arg("retain_graph") = false,
py::call_guard<py::gil_scoped_release>());
module.def(
"get_gradients",
[](int64_t contextId) -> py::dict {
const auto& autogradContext =
DistAutogradContainer::getInstance().retrieveContext(contextId);
return torch::jit::toPyObject(IValue(autogradContext->getGradients()));
},
py::arg("context_id"));
Py_RETURN_TRUE;
}
} // namespace
具體 backward 定義在 torch/csrc/distributed/autograd/autograd.cpp。
void backward(
int64_t context_id,
const variable_list& roots,
bool retain_graph) {
RECORD_FUNCTION(
kDistAutogradBackwardProfilingKey, std::vector<c10::IValue>());
try {
DistEngine::getInstance().execute(context_id, roots, retain_graph);
} catch (std::exception& e) {
// FIXME: crashes if exception type is not RuntimeError
throw std::runtime_error(e.what());
}
}
可以看到,最終會呼叫到 DistEngine::getInstance().execute(context_id, roots, retain_graph) 完成反向傳播。這就進入了引擎。
3.1.2 內部隱式發起
因為是隱式發起,所以程式碼比較隱蔽,我們這次採用從下至上的方式來剝絲抽繭。我們知道,如果節點之間要求反向傳播,會傳送BACKWARD_AUTOGRAD_REQ,所以我們從 BACKWARD_AUTOGRAD_REQ 開始發起尋找。
3.1.2.1 BACKWARD_AUTOGRAD_REQ
在 torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp 之中 PropagateGradientsReq::toMessageImpl 會呼叫到 BACKWARD_AUTOGRAD_REQ。
Message PropagateGradientsReq::toMessageImpl() && {
std::vector<at::IValue> ivalues;
// Add all the grad tensors.
for (const auto& grad : grads_) {
ivalues.emplace_back(grad);
}
// Now add autograd metadata.
ivalues.emplace_back(autogradMetadata_.autogradContextId);
ivalues.emplace_back(autogradMetadata_.autogradMessageId);
// Add retain graph.
ivalues.emplace_back(retainGraph_);
// Now pickle using JIT pickler.
std::vector<torch::Tensor> tensorTable;
std::vector<char> payload =
jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensorTable);
return Message(
std::move(payload),
std::move(tensorTable),
MessageType::BACKWARD_AUTOGRAD_REQ); // 這裡會用到
}
3.1.2.2 PropagateGradientsReq
繼續找誰發出來的 BACKWARD_AUTOGRAD_REQ,就是誰呼叫到了 toMessageImpl?原來在 torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp 這裡構建了 PropagateGradientsReq,會使用 toMessage 來構建一個訊息。即,RecvRpcBackward 的呼叫會傳送 BACKWARD_AUTOGRAD_REQ。
variable_list RecvRpcBackward::apply(variable_list&& grads) { // 呼叫Node
std::vector<Variable> outputGrads;
for (size_t i = 0; i < grads.size(); i++) {
const auto& grad = grads[i];
if (grad.defined()) {
outputGrads.emplace_back(grad);
} else {
// Put in zeros for a tensor with no grad.
outputGrads.emplace_back(input_metadata(i).zeros_like());
}
}
auto sharedContext = autogradContext_.lock();
// Send the gradients over the wire and record the future in the autograd
// context.
PropagateGradientsReq gradCall( // 這裡構建了 PropagateGradientsReq
autogradMetadata_,
outputGrads,
sharedContext->retrieveGraphTask()->keep_graph_);
// Send the gradients over to the appropriate node.
auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent();
auto jitFuture = rpcAgent->send( // 傳送出去,就是給後向傳播過程的下一個節點
rpcAgent->getWorkerInfo(fromWorkerId_),
std::move(gradCall).toMessage(), // 這裡呼叫了PropagateGradientsReq::toMessageImpl
rpc::kUnsetRpcTimeout,
deviceMap_);
// Record the future in the context.
sharedContext->addOutstandingRpc(jitFuture);
// 'recv' function sends the gradients over the wire using RPC, it doesn't
// need to return anything for any downstream autograd function.
return variable_list();
}
所以我們知道,在 RecvRpcBackward 的執行時候,會傳送 BACKWARD_AUTOGRAD_REQ,傳送給下一個節點。具體哪裡呼叫 RecvRpcBackward?我們會在下一篇 DistEngine 之中介紹。
此時具體如下,對應就是 worker 0 的 t3 給 worker 1 傳送 BACKWARD_AUTOGRAD_REQ 訊息。
+
worker 0 | worker 1
|
|
RecvRpcBackward PropagateGradientsReq |
+ + |
| | |
| | |
| | |
v | |
| |
apply() | |
+ | |
| v |
| |
| +------------------------------> toMessageImpl |
| + |
| | |
| Message(BACKWARD_AUTOGRAD_REQ) | |
| <----------------------------------------+ |
| |
| |
v |
|
rpcAgent+>send(Message) +-------------------------------------------->
+ BACKWARD_AUTOGRAD_REQ |
| |
| |
v |
+
對應示例圖就是:
3.2 接受反向傳播
我們接下來看看接收方如何處理反向傳播,我們再次回到 worker 1,就是圖上的 send 節點如何接受反向傳播訊息。
3.2.1 接受訊息
在生成 TensorPipeAgent 時候,把 RequestCallbackImpl 配置為回撥函式。這是 agent 的統一響應函式。前面關於代理接收邏輯時候,我們也提到了,會進入 RequestCallbackNoPython::processRpc 函式。其中可以看到有對 BACKWARD_AUTOGRAD_REQ 的處理邏輯。
這種是 RPC 的正常流程。
void RequestCallbackNoPython::processRpc(
RpcCommandBase& rpc,
const MessageType& messageType,
const int64_t messageId,
const c10::intrusive_ptr<JitFuture>& responseFuture,
std::shared_ptr<LazyStreamContext> ctx) const {
switch (messageType) {
case MessageType::BACKWARD_AUTOGRAD_REQ: {
processBackwardAutogradReq(rpc, messageId, responseFuture); // 這裡呼叫
return;
};
3.2.2 處理訊息
在 processBackwardAutogradReq 之中會:
- 獲取 DistAutogradContainer。
- 獲取 上下文,該上下文是之前在前向傳播過程之中建立的,從前文可知,本圖例之中,worker 0 和 worker 1之中每個 autograd 傳播都共享同一個上下文 context id。
- 通過傳送方的 context id,從上下文之中獲取到對應的 SendRpcBackward。這裡我們看到了上下文是如何使用。
- 使用 sendFunction 作為引數,呼叫 executeSendFunctionAsync 進行引擎處理。
void RequestCallbackNoPython::processBackwardAutogradReq(
RpcCommandBase& rpc,
const int64_t messageId,
const c10::intrusive_ptr<JitFuture>& responseFuture) const {
auto& gradientsCall = static_cast<PropagateGradientsReq&>(rpc);
const auto& autogradMetadata = gradientsCall.getAutogradMetadata();
// Retrieve the appropriate autograd context.
auto autogradContext = DistAutogradContainer::getInstance().retrieveContext(
autogradMetadata.autogradContextId); // 得到傳送者的context id
// Lookup the appropriate 'send' function to enqueue.
std::shared_ptr<SendRpcBackward> sendFunction = // 依據傳送者context id和訊息id得到sendFunction
autogradContext->retrieveSendFunction(autogradMetadata.autogradMessageId);
// Attach the gradients to the send function.
sendFunction->setGrads(gradientsCall.getGrads()); // 設定梯度
// Now execute the autograd graph using the "distributed engine."
auto execFuture = DistEngine::getInstance().executeSendFunctionAsync( // 呼叫引擎
autogradContext, sendFunction, gradientsCall.retainGraph());
// Our response is satisfied when the rpcs come back.
execFuture->addCallback([responseFuture, messageId](JitFuture& execFuture) {
if (!execFuture.hasError()) {
Message m = std::move(PropagateGradientsResp()).toMessage();
m.setId(messageId);
responseFuture->markCompleted(
IValue(c10::make_intrusive<Message>(std::move(m))));
} else {
responseFuture->setError(execFuture.exception_ptr());
}
});
}
在 worker 1 的 DistEngine::executeSendFunctionAsync 內部,會進行輾轉處理,最終傳送 BACKWARD_AUTOGRAD_REQ 到其反向傳播的下游,所以我們繼續在示例圖之上修改擴充,增加一個 BACKWARD_AUTOGRAD_REQ。
3.3 總結
我們可以看到有兩個途徑進入 dist autograd 引擎,啟動反向傳播:
- 一個是示例程式碼顯式主動呼叫 backward,進而呼叫到 DistEngine::getInstance().execute,就是 worker 0。
- 一個是被動呼叫 DistEngine::getInstance().executeSendFunctionAsync,就是 worker 1(當然,worker 0 的 send 也對應了一個被動呼叫)。
現在從上至下/自下而上兩種查詢反向傳播的發起源頭,都歸結到了 DistEngine,所以我們下一篇就介紹 DistEngine。
0xFF 參考
[原始碼解析] Pytorch 如何實現後向傳播 (3)---- 引擎動態邏輯
[原始碼解析] Pytorch 如何實現後向傳播 (2)---- 引擎靜態結構