[原始碼解析] PyTorch 如何實現後向傳播 (4)---- 具體演算法
0x00 摘要
前文中我們介紹了反向傳播引擎的動態邏輯,因為具體反向傳播演算法是在裝置執行緒中完成的,所以我們單獨用一章來講解。
本系列其他文章如下:
[原始碼解析]深度學習利器之自動微分(3) --- 示例解讀
[原始碼解析]PyTorch如何實現前向傳播(1) --- 基礎類(上)
[原始碼解析]PyTorch如何實現前向傳播(2) --- 基礎類(下)
[原始碼解析] PyTorch如何實現前向傳播(3) --- 具體實現
[原始碼解析] Pytorch 如何實現後向傳播 (1)---- 呼叫引擎
[原始碼解析] Pytorch 如何實現後向傳播 (2)---- 引擎靜態結構
[原始碼解析] Pytorch 如何實現後向傳播 (3)---- 引擎動態邏輯
0x01 工作執行緒主體
thread_main是工作執行緒的主體函式,主要邏輯就是圍繞著 ReadyQueue 執行一個 while 迴圈,工作執行緒阻塞在 ReadyQueue -> pop 這裡,如果主執行緒或者其他執行緒插入了一個 NodeTask,則 pop 會返回取出一個 NodeTask,工作執行緒處理這個 NodeTask,完成後向計算的一個環節,如果有需要就繼續往某一ReadyQueue插入新的 NodeTask,驅動引擎繼續執行後向計算其他環節。
thread_main 從如下途徑被呼叫:
- CUDA, XLA 裝置的 autograd threads 會呼叫。
- CPU 之上的反向傳播主執行緒會呼叫。
- 前兩個case 進行可重入反向傳播,也會呼叫。
1.1 執行緒主體程式碼
工作執行緒的計算始於動態圖的GraphRoot函式,反向傳播就以 Node 的edge為紐帶,層層從前向後計算,直到來到了leaf節點,最終完成了反向計算,具體如下:
- local_graph_task表示我們從佇列中檢索的graph_task。外部graph_ 任務表示我們需要執行的可重入執行的總體 graph_任務。
- 從自己的ReadyQueue之中取出NodeTask例項,使用 local_graph_task 為引數來執行evaluate_function(反向傳播函式)。
- outstanding_tasks 自減 1。
- 如果本 local_graph_task 已經結束(可重入反向傳播會執行多個 GraphTask),即:
- 執行後續操作 exec_post_processing,然後使用 future_result_->markCompleted。
- 如果這個task是來自其它worker thread,即 worker_device != base_owner,則向那個worker thread的queue傳送一個dummy function task,讓那個工作執行緒也執行起來。
具體程式碼如下:
// thread_main is used by:
// 1). autograd threads for devices (i.e. CUDA, XLA)
// 2). the caller/owning thread of the backward call on CPU (sync mode)
// 3). Renetrant backward that invoked by either 1) or 2)
// The exit conditions are different for the above three cases.
// For 1), we are spinning on running the thread_main on device autograd
// threads throughout the Engine lifetime, thread_main will get
// terminated during Engine destruction by pushing shutdown tasks
// For 2), the owning thread of the backward call drives the thread_main
// synchronously until the graph_task of that owning thread is
// completed and exit the thread_main to continue executing the
// result of caller's code.
// For 3), the reentrant backward that invokes
// thread_main, either from 1) or 2), will not spin and will exit as
// long as graph_task is completed and notify the owning thread as
// needed.
auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
// When graph_task is nullptr, this is a long running thread that processes
// tasks (ex: device threads). When graph_task is non-null (ex: reentrant
// backwards, user thread), this function is expected to exit once that
// graph_task complete.
// local_ready_queue should already been initialized when we get into thread_main
while (graph_task == nullptr || !graph_task->future_result_->completed()) {
// local_graph_task represents the graph_task we retrieve from the queue.
// The outer graph_task represents the overall graph_task we need to execute
// for reentrant execution.
std::shared_ptr<GraphTask> local_graph_task;
{
// Scope this block of execution since NodeTask is not needed after this
// block and can be deallocated (release any references to grad tensors
// as part of inputs_).
NodeTask task = local_ready_queue->pop(); // 阻塞等待
// This will only work if the worker is running a non backward task
// TODO Needs to be fixed this to work in all cases
if (task.isShutdownTask_) {
break;
}
if (!(local_graph_task = task.base_.lock())) {
// GraphTask for function is no longer valid, skipping further
// execution.
continue;
}
if (task.fn_ && !local_graph_task->has_error_.load()) {
// 利用grad_mode_來配置AutoGradMode,整個反向計算期間的程式碼都靠GradMode::is_enabled()來判斷當前是否是要計算grad
AutoGradMode grad_mode(local_graph_task->grad_mode_);
try {
// The guard sets the thread_local current_graph_task on construction
// and restores it on exit. The current_graph_task variable helps
// queue_callback() to find the target GraphTask to append final
// callbacks.
GraphTaskGuard guard(local_graph_task);
NodeGuard ndguard(task.fn_);
// 執行後向計算
evaluate_function(local_graph_task, task.fn_.get(), task.inputs_, local_graph_task->cpu_ready_queue_);
} catch (std::exception& e) {
thread_on_exception(local_graph_task, task.fn_, e);
}
}
}
// Decrement the outstanding tasks.
--local_graph_task->outstanding_tasks_;
// Check if we've completed execution.
if (local_graph_task->completed()) { // 已經結束了,進行後續處理
local_graph_task->mark_as_completed_and_run_post_processing();
auto base_owner = local_graph_task->owner_; // 後續是需要在 GraphTask 的 owner_ 處理
// The current worker thread finish the graph_task, but the owning thread
// of the graph_task might be sleeping on pop() if it does not have work.
// So we need to send a dummy function task to the owning thread just to
// ensure that it's not sleeping, so that we can exit the thread_main.
// If it has work, it might see that graph_task->outstanding_tasks_ == 0
// before it gets to the task, but it's a no-op anyway.
//
// NB: This is not necessary if the current thread is the owning thread.
if (worker_device != base_owner) {
// Synchronize outstanding_tasks_ with queue mutex
std::atomic_thread_fence(std::memory_order_release);
// 獲取後續工作的queue
ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
}
}
}
}
1.2 使用 Ready Queue
上述程式碼之中,最後使用 ready_queue_by_index 獲取到後續工作對應的queue。
ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
如何獲取Ready Queue?具體策略是:
- 如果下一個 需要執行的裝置是 CPU,則選用cpu_ready_queue。
- 否則從device_ready_queues_選取一個GPU對應的 ReadyQueue。
程式碼如下:
auto Engine::ready_queue_by_index(std::shared_ptr<ReadyQueue> cpu_ready_queue, int device_index) -> std::shared_ptr<ReadyQueue> {
if (device_index == CPU_DEVICE) {
// return the cpu ready queue passed in
TORCH_INTERNAL_ASSERT(cpu_ready_queue);
return cpu_ready_queue;
} else {
// Static cast is ok here as the number of device should never overflow an int.
TORCH_INTERNAL_ASSERT(0 <= device_index && device_index < static_cast<int>(device_ready_queues_.size()));
// See Note [Allocating GPUs to autograd threads]
// NB: This function would become obsolete if we truly allocated a CPU thread
// per device, rather than colocate.
return device_ready_queues_.at(device_index);
}
}
邏輯如下:
+---------------------------------------------------------------------+
| Main Thread |
| |
| push(NodeTask)+--------------+ |
| | |
+---------------------------------------------------------------------+
|
|
v
+------+-----+
| |
| ReadyQueue |
| |
+------+-----+
|
|
|
+---------------------------------------------------------------------+
| Worker Thread 1 | |
| | |
| thread_main{ | |
| v |
| NodeTask task = local_ready_queue->pop() |
| |
| evaluate_function(task.fn_.get(),task.inputs_) |
| } |
+---------------------------------------------------------------------+
0x02 反向計算總體邏輯
evaluate_function 方法完成了反向計算的邏輯,總體邏輯如下:
- 準備工作:如果exec_info需要處理,則處理 captured_vars_。
- 反向計算:呼叫 call_function(graph_task, func, inputs),這是反向傳播中計算相關的核心邏輯:
- 呼叫pre hooks。
- 呼叫fn進行計算。
- 呼叫post hooks。
- 掃尾工作:
- 如果不需要keep graph,則fn.release_variables();
- 依據 call_function的輸出 outputs,進行計算 num_outputs = outputs.size(),得到 num_outputs的元素數量(該數量等同於當前fn的next_edge()返回的list中的元素數量)。
- 準備下一步工作,具體就是查詢後續需要計算的NodeTask,num_outputs 就是在這裡被用到。這部分比較複雜。
總體程式碼如下:
void Engine::evaluate_function(
std::shared_ptr<GraphTask>& graph_task,
Node* func, // 導數計算方法
InputBuffer& inputs, // 當前Node的輸入梯度
const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
// 進行準備工作
// If exec_info_ is not empty, we have to instrument the execution
auto& exec_info_ = graph_task->exec_info_;
if (!exec_info_.empty()) {
auto& fn_info = exec_info_.at(func); // 取出當前的進行處理
if (auto* capture_vec = fn_info.captures_.get()) {
// Lock mutex for writing to graph_task->captured_vars_.
std::lock_guard<std::mutex> lock(graph_task->mutex_);
for (const auto& capture : *capture_vec) {
// captured_grad 就是臨時儲存下,每次node計算都會更新,最終輸出給呼叫者,相當於引用
// 1. captured_grad 引用了captured_vars_[capture.output_idx_],
auto& captured_grad = graph_task->captured_vars_[capture.output_idx_];
// 2. 給 captured_vars_[capture.output_idx_] 賦值 inputs[capture.input_idx_]
captured_grad = inputs[capture.input_idx_];
// 遍歷hooks,鏈式呼叫hook進行計算,captured_grad 不停的作為輸入和輸出在流水線中流淌
// 就是針對 captured_vars_[capture.output_idx_]不停的計算,最終結果還是在 captured_vars_[capture.output_idx_] 之中。
for (auto& hook : capture.hooks_) {
captured_grad = (*hook)(captured_grad);
}
}
}
if (!fn_info.needed_) {
// Skip execution if we don't need to execute the function.
return;
}
}
// Set the ThreadLocalState before calling the function.
// NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask
// always saves ThreadLocalState without grad_mode.
at::ThreadLocalStateGuard tls_guard(graph_task->thread_locals_);
// Switches to a function's CUDA stream (if applicable) before calling it
const auto opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};
// 進行反向計算
auto outputs = call_function(graph_task, func, inputs);
// 如果不需要保持計算圖,則本節點釋放變數
auto& fn = *func;
if (!graph_task->keep_graph_) {
fn.release_variables();
}
// 得到 num_outputs的元素數量(該數量等同於當前fn的next_edge()返回的list中的元素數量),後續遍歷本節點輸出時候會用到
int num_outputs = outputs.size();
if (num_outputs == 0) { // Note: doesn't acquire the mutex
// Records leaf stream (if applicable)
// See note "Streaming backwards"
if (opt_parent_stream) {
std::lock_guard<std::mutex> lock(graph_task->mutex_);
graph_task->leaf_streams.emplace(*opt_parent_stream);
}
return;
}
if (AnomalyMode::is_enabled()) {
AutoGradMode grad_mode(false);
for (int i = 0; i < num_outputs; ++i) {
auto& output = outputs[i];
at::OptionalDeviceGuard guard(device_of(output));
if (output.defined() && isnan(output).any().item<uint8_t>()) {
std::stringstream ss;
}
}
}
// 準備下一步工作
// Lock mutex for the accesses to GraphTask dependencies_, not_ready_ and cpu_ready_queue_ below
std::lock_guard<std::mutex> lock(graph_task->mutex_);
for (int i = 0; i < num_outputs; ++i) {
auto& output = outputs[i];
const auto& next = fn.next_edge(i); // next_edge是該node在前向傳播圖中的輸入,在反向傳播時候就是本節點的輸出,所以next就是下一個可能運算的節點
if (!next.is_valid()) continue;
// Check if the next function is ready to be computed
bool is_ready = false;
auto& dependencies = graph_task->dependencies_;
auto it = dependencies.find(next.function.get()); // 找到下一個節點的依賴
if (it == dependencies.end()) {
auto name = next.function->name();
throw std::runtime_error(std::string("dependency not found for ") + name);
} else if (--it->second == 0) {
dependencies.erase(it);
is_ready = true; // 下一個節點沒有入度了,那麼說明計算該節點梯度依賴的其他節點梯度都已經計算完成
}
// 要去 not_ready裡面看看,是否已經儲存了
auto& not_ready = graph_task->not_ready_;
auto not_ready_it = not_ready.find(next.function.get());
if (not_ready_it == not_ready.end()) {
// 下一個節點的梯度還沒有進行計算
// Skip functions that aren't supposed to be executed
// 跳過不需要計算的節點
if (!exec_info_.empty()) {
auto it = exec_info_.find(next.function.get());
if (it == exec_info_.end() || !it->second.should_execute()) {
continue;
}
}
// No buffers have been allocated for the function
InputBuffer input_buffer(next.function->num_inputs()); // 下一個節點前置梯度的buffer,就是下一個節點的輸入梯度
// Accumulates into buffer
// 下一個節點的輸入梯度就是當前節點的輸出,所以要拷貝過去
const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
input_buffer.add(next.input_nr,
std::move(output),
opt_parent_stream,
opt_next_stream);
if (is_ready) {
auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
// 既然依賴全部完成,就插入到ReadyQueue 之中
queue->push(
NodeTask(graph_task, next.function, std::move(input_buffer)));
} else {
// 下一個節點的輸入依賴還沒有完成,就放到not_ready之中。
not_ready.emplace(next.function.get(), std::move(input_buffer));
}
} else {
// 如果下一個節點已經開始計算,但是沒有完成(就是依賴梯度還有),此時應該在not_ready之中
// The function already has a buffer
auto &input_buffer = not_ready_it->second;
// Accumulates into buffer
const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
input_buffer.add(next.input_nr,
std::move(output),
opt_parent_stream,
opt_next_stream);
// Graph中每一個node(fn)的輸出是下一個node(fn)的輸入,下面4句程式碼來將前一個fn的輸出轉化為下一個fn的輸入
if (is_ready) {
// 如果此時已經沒有輸入依賴,就放入新的NodeTask,就是下一個需要計算梯度的NodeTask
auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
queue->push(
NodeTask(graph_task, next.function, std::move(input_buffer)));
//已經完成下一個節點前置梯度計算,從not_ready中移除相應的buffer
not_ready.erase(not_ready_it);
}
}
}
}
因為這部分程式碼十分複雜,我們逐一進行分析。
0x03 準備工作
首先我們看看準備工作,具體如下:
- 取出當前 Node 的 ExecInfo。
- 取出其 captures_,遍歷其中每一個 Capture。
- 遍歷Capture 的 hooks,鏈式呼叫hook進行計算。
- captured_grad 不停的作為輸入和輸出在流水線中流淌,針對
captured_vars_[capture.output_idx_]
陸續計算。 - 最終結果儲存在
captured_vars_[capture.output_idx_]
之中。
- captured_grad 不停的作為輸入和輸出在流水線中流淌,針對
程式碼中有一個細節,就是captured_grad 只是臨時儲存,每次node計算都會更新,最終輸出給呼叫者,相當於引用。
void Engine::evaluate_function(
std::shared_ptr<GraphTask>& graph_task,
Node* func, // 導數計算方法
InputBuffer& inputs, // 當前Node的輸入梯度
const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
// 進行準備工作
// If exec_info_ is not empty, we have to instrument the execution
auto& exec_info_ = graph_task->exec_info_;
if (!exec_info_.empty()) {
auto& fn_info = exec_info_.at(func); // 取出當前的進行處理
if (auto* capture_vec = fn_info.captures_.get()) {
// Lock mutex for writing to graph_task->captured_vars_.
std::lock_guard<std::mutex> lock(graph_task->mutex_);
for (const auto& capture : *capture_vec) {
// captured_grad 就是臨時儲存下,每次node計算都會更新,最終輸出給呼叫者,相當於引用
// 1. captured_grad 引用了captured_vars_[capture.output_idx_],
auto& captured_grad = graph_task->captured_vars_[capture.output_idx_];
// 2. 給 captured_vars_[capture.output_idx_] 賦值 inputs[capture.input_idx_]
captured_grad = inputs[capture.input_idx_];
// 遍歷hooks,鏈式呼叫hook進行計算,captured_grad 不停的作為輸入和輸出在流水線中流淌
// 就是針對 captured_vars_[capture.output_idx_]不停的計算,最終結果還是在 captured_vars_[capture.output_idx_] 之中。
for (auto& hook : capture.hooks_) {
captured_grad = (*hook)(captured_grad);
}
}
}
if (!fn_info.needed_) {
// Skip execution if we don't need to execute the function.
return;
}
}
0x04 核心邏輯
call_function是反向傳播中計算相關的核心邏輯。
- 呼叫註冊在本 node上的pre_hooks;
- 呼叫node本身,比如MeanBackward0、MulBackward0等。
- 輸入是
InputBuffer::variables(std::move(inputBuffer))
,一組Variable的例項。當動態圖剛開始進行反向計算時,引擎首先執行的是圖的根節點——graph_root,它的輸入是task.inputs——InputBuffer(0)。 - 呼叫的是fn的apply(),apply是多型實現,針對不同的operation會dispatch到operation對應的apply實現上。
- 輸出也是一組Variable的例項 outputs = fn(std::move(inputs_copy)),outputs 要作為下一個fn的輸入。
- 輸入是
- 呼叫註冊在node上的post hooks。
- 返回當前節點對應的導數,這是一個variable_list。
具體程式碼如下:
static variable_list call_function(
std::shared_ptr<GraphTask>& graph_task,
Node* func,
InputBuffer& inputBuffer) {
CheckpointValidGuard cpvguard(graph_task);
auto& fn = *func;
auto inputs =
call_pre_hooks(fn, InputBuffer::variables(std::move(inputBuffer)));
if (!graph_task->keep_graph_) {
fn.will_release_variables();
}
const auto has_post_hooks = !fn.post_hooks().empty();
variable_list outputs;
if (has_post_hooks) {
// In functions/accumulate_grad.cpp, there is some logic to check the
// conditions under which the incoming gradient can be stolen directly
// (which elides a deep copy) instead of cloned. One of these conditions
// is that the incoming gradient's refcount must be 1 (nothing else is
// referencing the same data). Stashing inputs_copy here bumps the
// refcount, so if post hooks are employed, it's actually still ok for
// accumulate_grad.cpp to steal the gradient if the refcount is 2.
//
// "new_grad.use_count() <= 1 + !post_hooks().empty()" in
// accumulate_grad.cpp accounts for this, but also creates a silent
// dependency between engine.cpp (ie, this particular engine
// implementation) and accumulate_grad.cpp.
//
// If you change the logic here, make sure it's compatible with
// accumulate_grad.cpp.
auto inputs_copy = inputs;
outputs = fn(std::move(inputs_copy));
} else {
outputs = fn(std::move(inputs));
}
validate_outputs(fn.next_edges(), outputs, [&](const std::string& msg) {
std::ostringstream ss;
return ss.str();
});
if(has_post_hooks){
return call_post_hooks(fn, std::move(outputs), inputs);
}
return outputs;
}
0x05 準備下一步工作
這部分是反向傳播的複雜之處。
現在呼叫 call_function,得到了後向傳播的輸出,記錄到了 outputs 之中。
auto outputs = call_function(graph_task, func, inputs);
所以,後半部分就是從 outputs 之中尋找後續可以計算的Node。
總體思路就是:遍歷後向傳播的輸出節點(就是該節點在前向計算圖中的入邊連線的節點),逐一衡量輸出節點。遍歷迴圈中分為兩段程式碼,對於每一個輸出節點做如下操作:
- 第一段是依據依賴排查這個節點,得到這個節點是否就緒。核心就是看看這個輸出節點在GraphTask的dependencies的計數是否降為0。
- 如果是0,就說明這個節點就緒了,說明這個node不會被未來的計算所依賴了。
- 如果非0,就說明這個節點有多個輸入,即,被多個node連線,而且有的輸入還沒有計算完成梯度。
- 第二段是依據是否就緒來處理這個節點,比如放入哪一個queue。
5.1 依據依賴排查節點
第一段程式碼功能是依據依賴關係來 排查節點,得到這個節點是否就緒,具體如下:
-
假定某一個節點是 output,我們得到對應的邊,遍歷輸出邊。
-
每次把一個輸出邊記錄為 next,func 是 NodeTask 之中的函式。
-
利用 dependencies_ 的資訊,next 是否可以計算。dependencies_ 裡面記錄的是圖中所有節點的依賴。
-
從 dependencies_ 之中找到 next 對應的依賴數目,把依賴數目減一(通常因為有多個 input)。
- 如果
--it->second == 0
,說明該前置節點計算梯度所依賴的其他節點梯度都已經完成計算。則- 把該前置節點對應的資訊GraphTask中移除,即從GraphTask的dependencies中移除(後續也會從GraphTask的 not_ready 成員變數之中移除)。
- 將is_ready 置為true,後續會依據這個 is_ready 的數值進行操作。
- 如果
-
從 not_ready_ 之中得到 next 對應的輸入buffer(後續程式碼就是對此進行操作);
-
std::unordered_map<Node*, InputBuffer> not_ready_;
-
程式碼如下:
-
for (int i = 0; i < num_outputs; ++i) { // 遍歷輸出節點,逐一衡量
auto& output = outputs[i];
const auto& next = fn.next_edge(i); // 獲得一個輸出節點
if (!next.is_valid()) continue;
// Check if the next function is ready to be computed
bool is_ready = false;
auto& dependencies = graph_task->dependencies_; // 拿到GraphTask的依賴關係
auto it = dependencies.find(next.function.get()); // 找到輸出節點的依賴項
if (it == dependencies.end()) {
auto name = next.function->name(); // 沒找到
throw std::runtime_error(std::string("dependency not found for ") + name);
} else if (--it->second == 0) {
dependencies.erase(it); // 找到了,並且已經計算完畢
is_ready = true;
}
auto& not_ready = graph_task->not_ready_;
auto not_ready_it = not_ready.find(next.function.get()); // 找到輸入buffer
現在已經找到了某一個輸出節點,也知道其是否計算完畢(依據有沒有依賴項),也拿到了其存在"未就緒佇列"的輸入buffer(如果存在的話)。
5.2 處理這個節點
第二段是依據是否就緒來處理這個節點,比如放入哪一個queue,是就緒佇列?還是未就緒佇列?核心是:
- 如果就緒,就放到該節點對應的 ReadyQueue 去處理。
- 如果沒有就緒,就新建立一個NodeTask放到 GraphTask的 not_ready 等待後續處理。需要注意的是,這個新的NodeTask 是在 worker thread 之中建立的。
- 如何找到 ReadyQueue?需要看這個 Node 節點的 input_buffer.device() ,即,這個新 NodeTask 應該傳送到 input_buffer.device() 那個 device 對應的 ReadyQueue。
我們具體看看如何依據 is_ready 的數值來對 not_ready 進行操作。
- 如果在 未就緒佇列 not_ready 之中 沒有找到 next_edge 對應的元素,則:
- 如果 exec_info_ 不為空,則在 exec_info_ 之中查詢 next_edge 對應的元素,如果有元素且註明了不需要執行,就跳到for迴圈的下一個。
- 用 next_edge 的流,inut_nr 等資訊構建一個 input_buffer。
- 如果 is_ready 是 True,就用 本 GraphTask,next.function,input_buffer構建一個NodeTask,放入 ReadyQueue(利用 input_buffer.device() 來得到對應的 queue)。這就要喚醒下一個 worker 執行緒。
- 如果 is_ready 是 False,這通常表明這個node有多個輸入(被更多的node連線,使用num_inputs()可以獲得數量),也說明此次處理的是這個node的第一個輸入,後續還需要使用這個 next_edge,所以這個 next_edge 需要被放到 not_ready 之中。則把 next.function,input_buffer 放入到 not_ready 之中,這個input_buffer 就是 next_edge 後續執行時候需要的各種輸入。
- 如果在 未就緒佇列 not_ready 之中找到了 next_edge 對應的元素,則:
- 拿出來該元素對應的 input_buffer,把資訊累積到 input_buffer 之中。此次累積的是該節點的其他輸入。 input_buffer.add(next.input_nr, std::move(output), opt_parent_stream, opt_next_stream) 完成了累積操作,next.input_nr 就表明當前的node是反向傳播中要流向的node(next)的第幾個輸入。
- 如果is_ready 是 True,就用 本 GraphTask,next.function,input_buffer構建一個NodeTask,放入 ReadyQueue。這就要喚醒下一個 worker 執行緒。
- 從 not_ready 之中移除此元素,就是從 GraphTask 的依賴關係之中去除。
程式碼如下:
if (not_ready_it == not_ready.end()) {
// Skip functions that aren't supposed to be executed
if (!exec_info_.empty()) {
auto it = exec_info_.find(next.function.get());
if (it == exec_info_.end() || !it->second.should_execute()) {
continue;
}
}
// No buffers have been allocated for the function
InputBuffer input_buffer(next.function->num_inputs());
// Accumulates into buffer
const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
input_buffer.add(next.input_nr,
std::move(output),
opt_parent_stream,
opt_next_stream);
if (is_ready) {
// 找出了下一個Node的queue
auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
queue->push( //
NodeTask(graph_task, next.function, std::move(input_buffer)));
} else {
not_ready.emplace(next.function.get(), std::move(input_buffer));
}
} else {
// The function already has a buffer
auto &input_buffer = not_ready_it->second;
// Accumulates into buffer
const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
input_buffer.add(next.input_nr,
std::move(output),
opt_parent_stream,
opt_next_stream);
if (is_ready) {
// 找出了下一個Node的queue
auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
queue->push(
NodeTask(graph_task, next.function, std::move(input_buffer)));
not_ready.erase(not_ready_it);
}
}
具體邏輯圖如下:
- func 指向了目前正在進行反向計算的 Node。
- func 呼叫自己的 apply 方法進行計算,得出了 outputs,假設有3個輸出,遍歷,我們選擇第三個為 output。
- func 的邊是 next_edges_ 成員變數,遍歷,我們選擇第三個邊為next。
- 用 next 和 GraphTask 的 dependencies_ 來判斷 next 是不是就緒。
- 如果就緒,把 output 構建一個 input_buffer,然後生成一個 NodeTask,插入到對應的 ReadyQuieue。
- 如果沒就緒,把 output 構建一個 input_buffer,和 next 一起放入 GraphTask 的 not_ready_,後續會使用。
1 +---------------+
func +--> | Node | +---> ...
| | |
| | |
| apply() +------> outputs +------> ... 2
| | |
| | |
| | | +--------------+
| | +---> output +--> | input_buffer +--+
| | +--------------+ |
| | |
| | |
| | | 5
| | |
| | |
| | +----> ... |
| | | +---------+
| | | | |
| next_edges_+---> +----> ... 3 | |
| | | | |
| | | | |
| | | 5 v |
| | +----> next +------>+ YES | +------------+
+---------------+ | +---> push(NodeTask) +-----> | ReadyQueue |
| 4 | | +------------+
| | |
+---------------+ +--> Ready? +-+ |
| GraphTask | | | 6 |
| | | | NO | 6
| | | +----> next.function |
| dependencies_+--> map<Node*, int> +-->+ + |
| | | |
| | | |
| | 6 v v
| not_ready_ +---------------------------------------------> map<Node*, InputBuffer>
| |
+---------------+
手機如下:
0x06 掃尾操作
在 thread_main 之中,如果本task已經結束,即做後續操作,具體程式碼如下。
auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
// 忽略前面程式碼
// Check if we've completed execution.
if (local_graph_task->completed()) { // 判斷是否結束
// 如果結束了,就進行後續操作
local_graph_task->mark_as_completed_and_run_post_processing();
auto base_owner = local_graph_task->owner_;
// The current worker thread finish the graph_task, but the owning thread
// of the graph_task might be sleeping on pop() if it does not have work.
// So we need to send a dummy function task to the owning thread just to
// ensure that it's not sleeping, so that we can exit the thread_main.
// If it has work, it might see that graph_task->outstanding_tasks_ == 0
// before it gets to the task, but it's a no-op anyway.
//
// NB: This is not necessary if the current thread is the owning thread.
if (worker_device != base_owner) {
// Synchronize outstanding_tasks_ with queue mutex
std::atomic_thread_fence(std::memory_order_release);
ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
}
}
我們接下來分析這些掃尾工作。注意,這裡是 thread_main 之中的掃尾工作。
6.1 判斷結束
以下程式碼用來判斷本 GraphTask是否結束,其實就是 ReadyQueue 之中是否還有待執行的 NodeTask。
outstanding_tasks_ 是待處理 NodeTask的數量,用來判斷該GrapTask是否還需要執行,其數值總是先加再減,如果數目為0,則說明任務結束了。
- 當 GraphTask 被建立出來時候,此數值為0。
- 如果有一個NodeTask被送入到 ReadyQueue,則outstanding_tasks_ 增加 1。
- 如果在工作執行緒作執行一次 evaluate_function(task)後,outstanding_tasks的值減 1。
- 如果這個數量不為0,則此GraphTask依然需要執行。
bool GraphTask::completed() {
// outstanding_tasks在evaluate_function中可能會被改變
return outstanding_tasks_.load() == 0 ||
(exit_on_error_ && has_error_.load());
}
6.2 後續&通知
mark_as_completed_and_run_post_processing 就是進行後續處理。
執行後續操作 exec_post_processing,然後使用 future_result_->markCompleted 通知主執行緒。
void GraphTask::mark_as_completed_and_run_post_processing() {
// Allow only one thread one attempt to process this logic.
if (future_completed_.exchange(true)) {
// Future is already marked complete, or being marked as such.
// In case the marking complete is only in progress, we add a
// wait() to guarantee the future is marked complete on exit.
future_result_->wait();
return;
}
try {
// Run post processing, before marking the future as complete.
// Drop lock prior to completing, to avoid holding across callbacks.
std::unique_lock<std::mutex> lock(mutex_);
exec_post_processing(); // 進行後續操作
std::vector<Variable> vars = std::move(captured_vars_);
// Need to unlock before we call markCompleted to avoid holding locks
// when the callbacks are called.
lock.unlock();
future_result_->markCompleted(std::move(vars)); // 通知主執行緒
} catch (std::exception& e) {
future_result_->setErrorIfNeeded(std::current_exception());
}
}
6.2.1 後續操作
後續操作,如果之前有註冊了 callback,則進行呼叫。也會進行流同步。
void GraphTask::exec_post_processing() {
if (!not_ready_.empty()) {
throw std::runtime_error("could not compute gradients for some functions");
}
// set the thread_local current_graph_task_ as more callbacks can be installed
// by existing final callbacks.
GraphTaskGuard guard(shared_from_this());
// Lock mutex during each iteration for accessing final_callbacks.size()
// Unlocking is necessary, because the callback can register
// more callbacks (or they can be registered from other threads
// while it's waiting.
std::unique_lock<std::mutex> cb_lock(final_callbacks_lock_);
// WARNING: Don't use a range-for loop here because more callbacks may be
// added in between callback calls, so iterators may become invalidated.
for (size_t i = 0; i < final_callbacks_.size(); ++i) {
cb_lock.unlock();
final_callbacks_[i]();
cb_lock.lock();
}
// Syncs leaf streams with default streams (if necessary)
// See note "Streaming backwards"
for (const auto& leaf_stream : leaf_streams) {
const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
const auto default_stream = guard.getDefaultStream(leaf_stream.device());
if (leaf_stream != default_stream) {
auto event = c10::Event{c10::DeviceType::CUDA};
event.record(leaf_stream);
default_stream.wait(event);
}
}
}
6.2.2 通知主執行緒
之前在 execute 之中會用 fut->wait() 來等待任務完成。下面我們省略了部分程式碼。
auto Engine::execute(const edge_list& roots,
const variable_list& inputs,
bool keep_graph,
bool create_graph,
bool accumulate_grad,
const edge_list& outputs) -> variable_list {
// Queue the root
if (skip_dummy_node) {
execute_with_graph_task(graph_task, graph_root, std::move(input_buffer));
} else {
execute_with_graph_task(graph_task, graph_root, InputBuffer(variable_list()));
}
auto& fut = graph_task->future_result_;
fut->wait();
return fut->value().toTensorVector();
}
在 mark_as_completed_and_run_post_processing 會用如下程式碼來通知主執行緒。
future_result_->markCompleted(std::move(vars)); // 通知主執行緒
6.3 通知其他執行緒
如果這個task是來自其它work thread,即 worker_device != base_owner,則向那個worker thread的queue傳送一個dummy function task,讓那個工作執行緒也執行起來。
local_graph_task
表示我們從佇列中檢索的 graph_task
。外部graph_
任務表示我們需要執行的可重入執行的總體graph_任務。
在 thread_main 之中,有一個 work around。就是:當前工作執行緒完成 graph_task,但此時,擁有graph_task的執行緒可能正在pop()上等待休眠。因此,我們需要向所屬執行緒傳送一個仿造的函式任務,以喚醒它,這樣我們可以退出thread_main。
這種情況發生在可重入反向傳播的情形。
// If worker_device is any devices (i.e. CPU, CUDA): this is a re-entrant
// backward call from that device.
graph_task->owner_ = worker_device;
具體程式碼如下:
// Check if we've completed execution.
if (local_graph_task->completed()) {
local_graph_task->mark_as_completed_and_run_post_processing();
auto base_owner = local_graph_task->owner_; // 當前裝置
if (worker_device != base_owner) {
// 不是同一個裝置
// Synchronize outstanding_tasks_ with queue mutex
std::atomic_thread_fence(std::memory_order_release);
ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
->push(NodeTask(local_graph_task, nullptr, InputBuffer(0))); // dummy task
}
}
其他執行緒當收到了 dummy task 之後,不會處理,因為 function 是 nullptr,然後就呼叫 local_ready_queue->pop() 繼續從自己的queue 中讀取下一個 task。
具體如下:
- 主執行緒等待。
- 如果工作執行緒發現GraphTask 已經結束,就通知主執行緒。
- 如果需要喚醒其他執行緒,就向該執行緒對應的 queue 插入 NodeTask。
- 對應執行緒取出 NodeTask 進行執行。
+------------------------------------------------+
| Worker Thread 1 |
| |
| thread_main{ |
| |
| mark_as_completed_and_run_post_processing |
2 markCompleted() | { |
+-------------------+ |
| | } |
| | |
+---------------+ | | push(NodeTask) +-----+ |
| Main Thread | | | | |
| | | | } | |
| | | | | |
| | | +------------------------------------------------+
| | | |
| | | 3 |
| | v v
| | +-------+-------+
| | 1 +----------------+ | |
| | wait() | | | ReadyQueue |
| +------------> | future_result_ | | |
| | | | +-------+-------+
| | +----------------+ |
| | |
| | 4 | pop(NodeTask)
| | |
| | v
| | +--------+---------------------+
| | | Worker Thread 2 |
| | | |
| | | |
+---------------+ | |
| |
| |
+------------------------------+
至此,後向傳播已經分析完畢,從下一篇開始,我們正式進入 PyTorch 分散式訓練。
0xFF 參考
https://www.zhihu.com/column/gemfield
pytorch筆記(計算圖+autograd)-Node(1)
PyTorch Internals 5:Autograd的實現
A GENTLE INTRODUCTION TO TORCH.AUTOGRAD