[原始碼解析] PyTorch 如何實現後向傳播 (4)---- 具體演算法

羅西的思考發表於2021-11-01

[原始碼解析] PyTorch 如何實現後向傳播 (4)---- 具體演算法

0x00 摘要

前文中我們介紹了反向傳播引擎的動態邏輯,因為具體反向傳播演算法是在裝置執行緒中完成的,所以我們單獨用一章來講解。

img

本系列其他文章如下:

深度學習利器之自動微分(1)

深度學習利器之自動微分(2)

[原始碼解析]深度學習利器之自動微分(3) --- 示例解讀

[原始碼解析]PyTorch如何實現前向傳播(1) --- 基礎類(上)

[原始碼解析]PyTorch如何實現前向傳播(2) --- 基礎類(下)

[原始碼解析] PyTorch如何實現前向傳播(3) --- 具體實現

[原始碼解析] Pytorch 如何實現後向傳播 (1)---- 呼叫引擎

[原始碼解析] Pytorch 如何實現後向傳播 (2)---- 引擎靜態結構

[原始碼解析] Pytorch 如何實現後向傳播 (3)---- 引擎動態邏輯

0x01 工作執行緒主體

thread_main是工作執行緒的主體函式,主要邏輯就是圍繞著 ReadyQueue 執行一個 while 迴圈,工作執行緒阻塞在 ReadyQueue -> pop 這裡,如果主執行緒或者其他執行緒插入了一個 NodeTask,則 pop 會返回取出一個 NodeTask,工作執行緒處理這個 NodeTask,完成後向計算的一個環節,如果有需要就繼續往某一ReadyQueue插入新的 NodeTask,驅動引擎繼續執行後向計算其他環節。

thread_main 從如下途徑被呼叫:

  1. CUDA, XLA 裝置的 autograd threads 會呼叫。
  2. CPU 之上的反向傳播主執行緒會呼叫。
  3. 前兩個case 進行可重入反向傳播,也會呼叫。

1.1 執行緒主體程式碼

工作執行緒的計算始於動態圖的GraphRoot函式,反向傳播就以 Node 的edge為紐帶,層層從前向後計算,直到來到了leaf節點,最終完成了反向計算,具體如下:

  • local_graph_task表示我們從佇列中檢索的graph_task。外部graph_ 任務表示我們需要執行的可重入執行的總體 graph_任務。
  • 從自己的ReadyQueue之中取出NodeTask例項,使用 local_graph_task 為引數來執行evaluate_function(反向傳播函式)。
  • outstanding_tasks 自減 1。
  • 如果本 local_graph_task 已經結束(可重入反向傳播會執行多個 GraphTask),即:
    • 執行後續操作 exec_post_processing,然後使用 future_result_->markCompleted。
    • 如果這個task是來自其它worker thread,即 worker_device != base_owner,則向那個worker thread的queue傳送一個dummy function task,讓那個工作執行緒也執行起來。

具體程式碼如下:

// thread_main is used by:
// 1). autograd threads for devices (i.e. CUDA, XLA)
// 2). the caller/owning thread of the backward call on CPU (sync mode)
// 3). Renetrant backward that invoked by either 1) or 2)
// The exit conditions are different for the above three cases.
// For 1), we are spinning on running the thread_main on device autograd
//         threads throughout the Engine lifetime, thread_main will get
//         terminated during Engine destruction by pushing shutdown tasks
// For 2), the owning thread of the backward call drives the thread_main
//         synchronously until the graph_task of that owning thread is
//         completed and exit the thread_main to continue executing the
//         result of caller's code.
// For 3), the reentrant backward that invokes
//         thread_main, either from 1) or 2), will not spin and will exit as
//         long as graph_task is completed and notify the owning thread as
//         needed.
auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
  // When graph_task is nullptr, this is a long running thread that processes
  // tasks (ex: device threads). When graph_task is non-null (ex: reentrant
  // backwards, user thread), this function is expected to exit once that
  // graph_task complete.

  // local_ready_queue should already been initialized when we get into thread_main
  while (graph_task == nullptr || !graph_task->future_result_->completed()) {
    // local_graph_task represents the graph_task we retrieve from the queue.
    // The outer graph_task represents the overall graph_task we need to execute
    // for reentrant execution.
    std::shared_ptr<GraphTask> local_graph_task;
    {
      // Scope this block of execution since NodeTask is not needed after this
      // block and can be deallocated (release any references to grad tensors
      // as part of inputs_).
      NodeTask task = local_ready_queue->pop(); // 阻塞等待
      // This will only work if the worker is running a non backward task
      // TODO Needs to be fixed this to work in all cases
      if (task.isShutdownTask_) {
        break;
      }

      if (!(local_graph_task = task.base_.lock())) {
        // GraphTask for function is no longer valid, skipping further
        // execution.
        continue;
      }

      if (task.fn_ && !local_graph_task->has_error_.load()) {
       // 利用grad_mode_來配置AutoGradMode,整個反向計算期間的程式碼都靠GradMode::is_enabled()來判斷當前是否是要計算grad  
        AutoGradMode grad_mode(local_graph_task->grad_mode_);
        try {
          // The guard sets the thread_local current_graph_task on construction
          // and restores it on exit. The current_graph_task variable helps
          // queue_callback() to find the target GraphTask to append final
          // callbacks.
          GraphTaskGuard guard(local_graph_task);
          NodeGuard ndguard(task.fn_);
          // 執行後向計算
          evaluate_function(local_graph_task, task.fn_.get(), task.inputs_, local_graph_task->cpu_ready_queue_);
        } catch (std::exception& e) {
          thread_on_exception(local_graph_task, task.fn_, e);
        }
      }
    }

    // Decrement the outstanding tasks.
    --local_graph_task->outstanding_tasks_;

    // Check if we've completed execution.
    if (local_graph_task->completed()) { // 已經結束了,進行後續處理
      local_graph_task->mark_as_completed_and_run_post_processing();

      auto base_owner = local_graph_task->owner_; // 後續是需要在 GraphTask 的 owner_ 處理
      // The current worker thread finish the graph_task, but the owning thread
      // of the graph_task might be sleeping on pop() if it does not have work.
      // So we need to send a dummy function task to the owning thread just to
      // ensure that it's not sleeping, so that we can exit the thread_main.
      // If it has work, it might see that graph_task->outstanding_tasks_ == 0
      // before it gets to the task, but it's a no-op anyway.
      //
      // NB: This is not necessary if the current thread is the owning thread.
      if (worker_device != base_owner) {
        // Synchronize outstanding_tasks_ with queue mutex
        std::atomic_thread_fence(std::memory_order_release);
        // 獲取後續工作的queue
        ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
            ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
      }
    }
  }
}

1.2 使用 Ready Queue

上述程式碼之中,最後使用 ready_queue_by_index 獲取到後續工作對應的queue。

ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
    ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));

如何獲取Ready Queue?具體策略是:

  • 如果下一個 需要執行的裝置是 CPU,則選用cpu_ready_queue。
  • 否則從device_ready_queues_選取一個GPU對應的 ReadyQueue。

程式碼如下:

auto Engine::ready_queue_by_index(std::shared_ptr<ReadyQueue> cpu_ready_queue, int device_index) -> std::shared_ptr<ReadyQueue> {
  if (device_index == CPU_DEVICE) {
    // return the cpu ready queue passed in
    TORCH_INTERNAL_ASSERT(cpu_ready_queue);
    return cpu_ready_queue;
  } else {
    // Static cast is ok here as the number of device should never overflow an int.
    TORCH_INTERNAL_ASSERT(0 <= device_index && device_index < static_cast<int>(device_ready_queues_.size()));
    // See Note [Allocating GPUs to autograd threads]
    // NB: This function would become obsolete if we truly allocated a CPU thread
    // per device, rather than colocate.
    return device_ready_queues_.at(device_index);
  }
}

邏輯如下:

+---------------------------------------------------------------------+
|  Main Thread                                                        |
|                                                                     |
|            push(NodeTask)+--------------+                           |
|                                         |                           |
+---------------------------------------------------------------------+
                                          |
                                          |
                                          v
                                   +------+-----+
                                   |            |
                                   | ReadyQueue |
                                   |            |
                                   +------+-----+
                                          |
                                          |
                                          |
+---------------------------------------------------------------------+
| Worker Thread 1                         |                           |
|                                         |                           |
|  thread_main{                           |                           |
|                                         v                           |
|     NodeTask task = local_ready_queue->pop()                        |
|                                                                     |
|     evaluate_function(task.fn_.get(),task.inputs_)                  |
|  }                                                                  |
+---------------------------------------------------------------------+

0x02 反向計算總體邏輯

evaluate_function 方法完成了反向計算的邏輯,總體邏輯如下:

  • 準備工作:如果exec_info需要處理,則處理 captured_vars_。
  • 反向計算:呼叫 call_function(graph_task, func, inputs),這是反向傳播中計算相關的核心邏輯:
    • 呼叫pre hooks。
    • 呼叫fn進行計算。
    • 呼叫post hooks。
  • 掃尾工作:
    • 如果不需要keep graph,則fn.release_variables();
    • 依據 call_function的輸出 outputs,進行計算 num_outputs = outputs.size(),得到 num_outputs的元素數量(該數量等同於當前fn的next_edge()返回的list中的元素數量)。
  • 準備下一步工作,具體就是查詢後續需要計算的NodeTask,num_outputs 就是在這裡被用到。這部分比較複雜。

總體程式碼如下:

void Engine::evaluate_function(
    std::shared_ptr<GraphTask>& graph_task,
    Node* func, // 導數計算方法
    InputBuffer& inputs, // 當前Node的輸入梯度
    const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
    
  // 進行準備工作  
  // If exec_info_ is not empty, we have to instrument the execution
  auto& exec_info_ = graph_task->exec_info_;
  if (!exec_info_.empty()) {
    auto& fn_info = exec_info_.at(func); // 取出當前的進行處理
    if (auto* capture_vec = fn_info.captures_.get()) {
      // Lock mutex for writing to graph_task->captured_vars_.
      std::lock_guard<std::mutex> lock(graph_task->mutex_);
      for (const auto& capture : *capture_vec) {
        // captured_grad 就是臨時儲存下,每次node計算都會更新,最終輸出給呼叫者,相當於引用
        // 1. captured_grad 引用了captured_vars_[capture.output_idx_],
        auto& captured_grad = graph_task->captured_vars_[capture.output_idx_];
        // 2. 給 captured_vars_[capture.output_idx_] 賦值 inputs[capture.input_idx_]
        captured_grad = inputs[capture.input_idx_];
        // 遍歷hooks,鏈式呼叫hook進行計算,captured_grad 不停的作為輸入和輸出在流水線中流淌
        // 就是針對 captured_vars_[capture.output_idx_]不停的計算,最終結果還是在 captured_vars_[capture.output_idx_] 之中。
        for (auto& hook : capture.hooks_) {
          captured_grad = (*hook)(captured_grad);
        }
      }
    }
    if (!fn_info.needed_) {
      // Skip execution if we don't need to execute the function.
      return;
    }
  }

  // Set the ThreadLocalState before calling the function.
  // NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask
  // always saves ThreadLocalState without grad_mode.
  at::ThreadLocalStateGuard tls_guard(graph_task->thread_locals_);

  // Switches to a function's CUDA stream (if applicable) before calling it
  const auto opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
  c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};

  // 進行反向計算
  auto outputs = call_function(graph_task, func, inputs);

  // 如果不需要保持計算圖,則本節點釋放變數
  auto& fn = *func;
  if (!graph_task->keep_graph_) {
    fn.release_variables();
  }

  // 得到 num_outputs的元素數量(該數量等同於當前fn的next_edge()返回的list中的元素數量),後續遍歷本節點輸出時候會用到
  int num_outputs = outputs.size();
  if (num_outputs == 0) { // Note: doesn't acquire the mutex
    // Records leaf stream (if applicable)
    // See note "Streaming backwards"
    if (opt_parent_stream) {
      std::lock_guard<std::mutex> lock(graph_task->mutex_);
      graph_task->leaf_streams.emplace(*opt_parent_stream);
    }
    return;
  }

  if (AnomalyMode::is_enabled()) {
    AutoGradMode grad_mode(false);
    for (int i = 0; i < num_outputs; ++i) {
      auto& output = outputs[i];
      at::OptionalDeviceGuard guard(device_of(output));
      if (output.defined() && isnan(output).any().item<uint8_t>()) {
        std::stringstream ss;
      }
    }
  }

  // 準備下一步工作
  // Lock mutex for the accesses to GraphTask dependencies_, not_ready_ and cpu_ready_queue_ below
  std::lock_guard<std::mutex> lock(graph_task->mutex_);
  for (int i = 0; i < num_outputs; ++i) {
    auto& output = outputs[i];
    const auto& next = fn.next_edge(i); // next_edge是該node在前向傳播圖中的輸入,在反向傳播時候就是本節點的輸出,所以next就是下一個可能運算的節點

    if (!next.is_valid()) continue;

    // Check if the next function is ready to be computed
    bool is_ready = false;
    auto& dependencies = graph_task->dependencies_;
    auto it = dependencies.find(next.function.get()); // 找到下一個節點的依賴

    if (it == dependencies.end()) {
      auto name = next.function->name();
      throw std::runtime_error(std::string("dependency not found for ") + name);
    } else if (--it->second == 0) {
      dependencies.erase(it);
      is_ready = true; // 下一個節點沒有入度了,那麼說明計算該節點梯度依賴的其他節點梯度都已經計算完成
    }

    // 要去 not_ready裡面看看,是否已經儲存了
    auto& not_ready = graph_task->not_ready_;
    auto not_ready_it = not_ready.find(next.function.get());
    if (not_ready_it == not_ready.end()) {
      // 下一個節點的梯度還沒有進行計算
      // Skip functions that aren't supposed to be executed
      // 跳過不需要計算的節點
      if (!exec_info_.empty()) {
        auto it = exec_info_.find(next.function.get());
        if (it == exec_info_.end() || !it->second.should_execute()) {
          continue;
        }
      }
      // No buffers have been allocated for the function
      InputBuffer input_buffer(next.function->num_inputs()); // 下一個節點前置梯度的buffer,就是下一個節點的輸入梯度

      // Accumulates into buffer
      // 下一個節點的輸入梯度就是當前節點的輸出,所以要拷貝過去
      const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
      input_buffer.add(next.input_nr,
                       std::move(output),
                       opt_parent_stream,
                       opt_next_stream);

      if (is_ready) {
        auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
        // 既然依賴全部完成,就插入到ReadyQueue 之中
        queue->push(
            NodeTask(graph_task, next.function, std::move(input_buffer)));
      } else {
        // 下一個節點的輸入依賴還沒有完成,就放到not_ready之中。
        not_ready.emplace(next.function.get(), std::move(input_buffer));
      }
    } else {
      // 如果下一個節點已經開始計算,但是沒有完成(就是依賴梯度還有),此時應該在not_ready之中
      // The function already has a buffer
      auto &input_buffer = not_ready_it->second;

      // Accumulates into buffer
      const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
      input_buffer.add(next.input_nr,
                       std::move(output),
                       opt_parent_stream,
                       opt_next_stream);
        
      // Graph中每一個node(fn)的輸出是下一個node(fn)的輸入,下面4句程式碼來將前一個fn的輸出轉化為下一個fn的輸入  
      if (is_ready) {
        // 如果此時已經沒有輸入依賴,就放入新的NodeTask,就是下一個需要計算梯度的NodeTask
        auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
        queue->push(
            NodeTask(graph_task, next.function, std::move(input_buffer)));
        //已經完成下一個節點前置梯度計算,從not_ready中移除相應的buffer
        not_ready.erase(not_ready_it);
      }
    }
  }
}

因為這部分程式碼十分複雜,我們逐一進行分析。

0x03 準備工作

首先我們看看準備工作,具體如下:

  • 取出當前 Node 的 ExecInfo。
  • 取出其 captures_,遍歷其中每一個 Capture。
  • 遍歷Capture 的 hooks,鏈式呼叫hook進行計算。
    • captured_grad 不停的作為輸入和輸出在流水線中流淌,針對 captured_vars_[capture.output_idx_]陸續計算。
    • 最終結果儲存在 captured_vars_[capture.output_idx_] 之中。

程式碼中有一個細節,就是captured_grad 只是臨時儲存,每次node計算都會更新,最終輸出給呼叫者,相當於引用

void Engine::evaluate_function(
    std::shared_ptr<GraphTask>& graph_task,
    Node* func, // 導數計算方法
    InputBuffer& inputs, // 當前Node的輸入梯度
    const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
    
  // 進行準備工作  
  // If exec_info_ is not empty, we have to instrument the execution
  auto& exec_info_ = graph_task->exec_info_;
  if (!exec_info_.empty()) {
    auto& fn_info = exec_info_.at(func); // 取出當前的進行處理
    if (auto* capture_vec = fn_info.captures_.get()) {
      // Lock mutex for writing to graph_task->captured_vars_.
      std::lock_guard<std::mutex> lock(graph_task->mutex_);
      for (const auto& capture : *capture_vec) {
        // captured_grad 就是臨時儲存下,每次node計算都會更新,最終輸出給呼叫者,相當於引用
        // 1. captured_grad 引用了captured_vars_[capture.output_idx_],
        auto& captured_grad = graph_task->captured_vars_[capture.output_idx_];
        // 2. 給 captured_vars_[capture.output_idx_] 賦值 inputs[capture.input_idx_]
        captured_grad = inputs[capture.input_idx_];
        // 遍歷hooks,鏈式呼叫hook進行計算,captured_grad 不停的作為輸入和輸出在流水線中流淌
        // 就是針對 captured_vars_[capture.output_idx_]不停的計算,最終結果還是在 captured_vars_[capture.output_idx_] 之中。
        for (auto& hook : capture.hooks_) {
          captured_grad = (*hook)(captured_grad);
        }
      }
    }
    if (!fn_info.needed_) {
      // Skip execution if we don't need to execute the function.
      return;
    }
  }

0x04 核心邏輯

call_function是反向傳播中計算相關的核心邏輯。

  • 呼叫註冊在本 node上的pre_hooks;
  • 呼叫node本身,比如MeanBackward0、MulBackward0等。
    • 輸入是InputBuffer::variables(std::move(inputBuffer)),一組Variable的例項。當動態圖剛開始進行反向計算時,引擎首先執行的是圖的根節點——graph_root,它的輸入是task.inputs——InputBuffer(0)。
    • 呼叫的是fn的apply(),apply是多型實現,針對不同的operation會dispatch到operation對應的apply實現上。
    • 輸出也是一組Variable的例項 outputs = fn(std::move(inputs_copy)),outputs 要作為下一個fn的輸入。
  • 呼叫註冊在node上的post hooks。
  • 返回當前節點對應的導數,這是一個variable_list。

具體程式碼如下:

static variable_list call_function(
    std::shared_ptr<GraphTask>& graph_task,
    Node* func,
    InputBuffer& inputBuffer) {
  CheckpointValidGuard cpvguard(graph_task);
  auto& fn = *func;
  auto inputs =
      call_pre_hooks(fn, InputBuffer::variables(std::move(inputBuffer)));

  if (!graph_task->keep_graph_) {
    fn.will_release_variables();
  }

  const auto has_post_hooks = !fn.post_hooks().empty();
  variable_list outputs;

  if (has_post_hooks) {
    // In functions/accumulate_grad.cpp, there is some logic to check the
    // conditions under which the incoming gradient can be stolen directly
    // (which elides a deep copy) instead of cloned. One of these conditions
    // is that the incoming gradient's refcount must be 1 (nothing else is
    // referencing the same data).  Stashing inputs_copy here bumps the
    // refcount, so if post hooks are employed, it's actually still ok for
    // accumulate_grad.cpp to steal the gradient if the refcount is 2.
    //
    // "new_grad.use_count() <= 1 + !post_hooks().empty()" in
    // accumulate_grad.cpp accounts for this, but also creates a silent
    // dependency between engine.cpp (ie, this particular engine
    // implementation) and accumulate_grad.cpp.
    //
    // If you change the logic here, make sure it's compatible with
    // accumulate_grad.cpp.
    auto inputs_copy = inputs;
    outputs = fn(std::move(inputs_copy));
  } else {
    outputs = fn(std::move(inputs));
  }

  validate_outputs(fn.next_edges(), outputs, [&](const std::string& msg) {
    std::ostringstream ss;
    return ss.str();
  });

  if(has_post_hooks){
    return call_post_hooks(fn, std::move(outputs), inputs);
  }
  return outputs;
}

0x05 準備下一步工作

這部分是反向傳播的複雜之處。

現在呼叫 call_function,得到了後向傳播的輸出,記錄到了 outputs 之中。

auto outputs = call_function(graph_task, func, inputs);

所以,後半部分就是從 outputs 之中尋找後續可以計算的Node

總體思路就是:遍歷後向傳播的輸出節點(就是該節點在前向計算圖中的入邊連線的節點),逐一衡量輸出節點。遍歷迴圈中分為兩段程式碼,對於每一個輸出節點做如下操作:

  • 第一段是依據依賴排查這個節點,得到這個節點是否就緒。核心就是看看這個輸出節點在GraphTask的dependencies的計數是否降為0
    • 如果是0,就說明這個節點就緒了,說明這個node不會被未來的計算所依賴了。
    • 如果非0,就說明這個節點有多個輸入,即,被多個node連線,而且有的輸入還沒有計算完成梯度。
  • 第二段是依據是否就緒來處理這個節點,比如放入哪一個queue

5.1 依據依賴排查節點

第一段程式碼功能是依據依賴關係來 排查節點,得到這個節點是否就緒,具體如下:

  • 假定某一個節點是 output,我們得到對應的邊,遍歷輸出邊。

    • 每次把一個輸出邊記錄為 next,func 是 NodeTask 之中的函式。

    • 利用 dependencies_ 的資訊,next 是否可以計算。dependencies_ 裡面記錄的是圖中所有節點的依賴。

    • 從 dependencies_ 之中找到 next 對應的依賴數目,把依賴數目減一(通常因為有多個 input)。

      • 如果--it->second == 0,說明該前置節點計算梯度所依賴的其他節點梯度都已經完成計算。則
        • 把該前置節點對應的資訊GraphTask中移除,即從GraphTask的dependencies中移除(後續也會從GraphTask的 not_ready 成員變數之中移除)。
        • 將is_ready 置為true,後續會依據這個 is_ready 的數值進行操作。
    • 從 not_ready_ 之中得到 next 對應的輸入buffer(後續程式碼就是對此進行操作);

      • std::unordered_map<Node*, InputBuffer> not_ready_;
        

    程式碼如下:

  for (int i = 0; i < num_outputs; ++i) { // 遍歷輸出節點,逐一衡量
    auto& output = outputs[i];
    const auto& next = fn.next_edge(i); // 獲得一個輸出節點
      
    if (!next.is_valid()) continue;

    // Check if the next function is ready to be computed
    bool is_ready = false;
    auto& dependencies = graph_task->dependencies_; // 拿到GraphTask的依賴關係
    auto it = dependencies.find(next.function.get()); // 找到輸出節點的依賴項

    if (it == dependencies.end()) {
      auto name = next.function->name(); // 沒找到
      throw std::runtime_error(std::string("dependency not found for ") + name);
    } else if (--it->second == 0) {
      dependencies.erase(it);  // 找到了,並且已經計算完畢
      is_ready = true;
    }

    auto& not_ready = graph_task->not_ready_; 
    auto not_ready_it = not_ready.find(next.function.get()); // 找到輸入buffer     

現在已經找到了某一個輸出節點,也知道其是否計算完畢(依據有沒有依賴項),也拿到了其存在"未就緒佇列"的輸入buffer(如果存在的話)。

5.2 處理這個節點

第二段是依據是否就緒來處理這個節點,比如放入哪一個queue,是就緒佇列?還是未就緒佇列?核心是:

  • 如果就緒,就放到該節點對應的 ReadyQueue 去處理。
  • 如果沒有就緒,就新建立一個NodeTask放到 GraphTask的 not_ready 等待後續處理。需要注意的是,這個新的NodeTask 是在 worker thread 之中建立的。
  • 如何找到 ReadyQueue?需要看這個 Node 節點的 input_buffer.device() ,即,這個新 NodeTask 應該傳送到 input_buffer.device() 那個 device 對應的 ReadyQueue。

我們具體看看如何依據 is_ready 的數值來對 not_ready 進行操作。

  • 如果在 未就緒佇列 not_ready 之中 沒有找到 next_edge 對應的元素,則:
    • 如果 exec_info_ 不為空,則在 exec_info_ 之中查詢 next_edge 對應的元素,如果有元素且註明了不需要執行,就跳到for迴圈的下一個。
    • 用 next_edge 的流,inut_nr 等資訊構建一個 input_buffer。
    • 如果 is_ready 是 True,就用 本 GraphTask,next.function,input_buffer構建一個NodeTask,放入 ReadyQueue(利用 input_buffer.device() 來得到對應的 queue)。這就要喚醒下一個 worker 執行緒
    • 如果 is_ready 是 False,這通常表明這個node有多個輸入(被更多的node連線,使用num_inputs()可以獲得數量),也說明此次處理的是這個node的第一個輸入,後續還需要使用這個 next_edge,所以這個 next_edge 需要被放到 not_ready 之中。則把 next.function,input_buffer 放入到 not_ready 之中,這個input_buffer 就是 next_edge 後續執行時候需要的各種輸入。
  • 如果在 未就緒佇列 not_ready 之中找到了 next_edge 對應的元素,則:
    • 拿出來該元素對應的 input_buffer,把資訊累積到 input_buffer 之中。此次累積的是該節點的其他輸入。 input_buffer.add(next.input_nr, std::move(output), opt_parent_stream, opt_next_stream) 完成了累積操作,next.input_nr 就表明當前的node是反向傳播中要流向的node(next)的第幾個輸入。
    • 如果is_ready 是 True,就用 本 GraphTask,next.function,input_buffer構建一個NodeTask,放入 ReadyQueue。這就要喚醒下一個 worker 執行緒
    • 從 not_ready 之中移除此元素,就是從 GraphTask 的依賴關係之中去除。

程式碼如下:

    if (not_ready_it == not_ready.end()) {
      // Skip functions that aren't supposed to be executed
      if (!exec_info_.empty()) {
        auto it = exec_info_.find(next.function.get());
        if (it == exec_info_.end() || !it->second.should_execute()) {
          continue;
        }
      }
      // No buffers have been allocated for the function
      InputBuffer input_buffer(next.function->num_inputs());

      // Accumulates into buffer
      const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
      input_buffer.add(next.input_nr,
                       std::move(output),
                       opt_parent_stream,
                       opt_next_stream);

      if (is_ready) {
        // 找出了下一個Node的queue
        auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
        queue->push( //
            NodeTask(graph_task, next.function, std::move(input_buffer)));
      } else {
        not_ready.emplace(next.function.get(), std::move(input_buffer));
      }
    } else {
      // The function already has a buffer
      auto &input_buffer = not_ready_it->second;

      // Accumulates into buffer
      const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
      input_buffer.add(next.input_nr,
                       std::move(output),
                       opt_parent_stream,
                       opt_next_stream);
      if (is_ready) {
        // 找出了下一個Node的queue
        auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
        queue->push(
            NodeTask(graph_task, next.function, std::move(input_buffer)));
        not_ready.erase(not_ready_it);
      }
    }

具體邏輯圖如下:

  1. func 指向了目前正在進行反向計算的 Node。
  2. func 呼叫自己的 apply 方法進行計算,得出了 outputs,假設有3個輸出,遍歷,我們選擇第三個為 output。
  3. func 的邊是 next_edges_ 成員變數,遍歷,我們選擇第三個邊為next。
  4. 用 next 和 GraphTask 的 dependencies_ 來判斷 next 是不是就緒。
  5. 如果就緒,把 output 構建一個 input_buffer,然後生成一個 NodeTask,插入到對應的 ReadyQuieue。
  6. 如果沒就緒,把 output 構建一個 input_buffer,和 next 一起放入 GraphTask 的 not_ready_,後續會使用。
       1  +---------------+
func +--> | Node          |              +---> ...
          |               |              |
          |               |              |
          |  apply() +------> outputs +------> ...  2
          |               |              |
          |               |              |
          |               |              |                 +--------------+
          |               |              +---> output +--> | input_buffer +--+
          |               |                                +--------------+  |
          |               |                                                  |
          |               |                                                  |
          |               |                                                  | 5
          |               |                                                  |
          |               |                                                  |
          |               |   +----> ...                                     |
          |               |   |                                              +---------+
          |               |   |                                              |         |
          |  next_edges_+---> +----> ...  3                                  |         |
          |               |   |                                              |         |
          |               |   |                                              |         |
          |               |   |                                         5    v         |
          |               |   +----> next +------>+              YES                   |     +------------+
          +---------------+                       |             +---> push(NodeTask) +-----> | ReadyQueue |
                                                  |      4      |                      |     +------------+
                                                  |             |                      |
          +---------------+                       +--> Ready? +-+                      |
          | GraphTask     |                       |             |       6              |
          |               |                       |             | NO                   | 6
          |               |                       |             +----> next.function   |
          | dependencies_+--> map<Node*, int> +-->+                          +         |
          |               |                                                  |         |
          |               |                                                  |         |
          |               |                              6                   v         v
          | not_ready_ +--------------------------------------------->  map<Node*, InputBuffer>
          |               |
          +---------------+

手機如下:

0x06 掃尾操作

在 thread_main 之中,如果本task已經結束,即做後續操作,具體程式碼如下。

auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
  
    // 忽略前面程式碼
  
    // Check if we've completed execution.
	  if (local_graph_task->completed()) { // 判斷是否結束
      // 如果結束了,就進行後續操作
      local_graph_task->mark_as_completed_and_run_post_processing();

      auto base_owner = local_graph_task->owner_;
      // The current worker thread finish the graph_task, but the owning thread
      // of the graph_task might be sleeping on pop() if it does not have work.
      // So we need to send a dummy function task to the owning thread just to
      // ensure that it's not sleeping, so that we can exit the thread_main.
      // If it has work, it might see that graph_task->outstanding_tasks_ == 0
      // before it gets to the task, but it's a no-op anyway.
      //
      // NB: This is not necessary if the current thread is the owning thread.
      if (worker_device != base_owner) {
        // Synchronize outstanding_tasks_ with queue mutex
        std::atomic_thread_fence(std::memory_order_release);
        ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
            ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
      }
    }

我們接下來分析這些掃尾工作。注意,這裡是 thread_main 之中的掃尾工作

6.1 判斷結束

以下程式碼用來判斷本 GraphTask是否結束,其實就是 ReadyQueue 之中是否還有待執行的 NodeTask。

outstanding_tasks_ 是待處理 NodeTask的數量,用來判斷該GrapTask是否還需要執行,其數值總是先加再減,如果數目為0,則說明任務結束了。

  • 當 GraphTask 被建立出來時候,此數值為0。
  • 如果有一個NodeTask被送入到 ReadyQueue,則outstanding_tasks_ 增加 1。
  • 如果在工作執行緒作執行一次 evaluate_function(task)後,outstanding_tasks的值減 1。
  • 如果這個數量不為0,則此GraphTask依然需要執行。
bool GraphTask::completed() {
  // outstanding_tasks在evaluate_function中可能會被改變
  return outstanding_tasks_.load() == 0 ||
      (exit_on_error_ && has_error_.load());
}

6.2 後續&通知

mark_as_completed_and_run_post_processing 就是進行後續處理。

執行後續操作 exec_post_processing,然後使用 future_result_->markCompleted 通知主執行緒。

void GraphTask::mark_as_completed_and_run_post_processing() {
  // Allow only one thread one attempt to process this logic.
  if (future_completed_.exchange(true)) {
    // Future is already marked complete, or being marked as such.
    // In case the marking complete is only in progress, we add a
    // wait() to guarantee the future is marked complete on exit.
    future_result_->wait();
    return;
  }

  try {
    // Run post processing, before marking the future as complete.
    // Drop lock prior to completing, to avoid holding across callbacks.
    std::unique_lock<std::mutex> lock(mutex_);

    exec_post_processing(); // 進行後續操作
    std::vector<Variable> vars = std::move(captured_vars_);

    // Need to unlock before we call markCompleted to avoid holding locks
    // when the callbacks are called.
    lock.unlock();
    future_result_->markCompleted(std::move(vars));  // 通知主執行緒
  } catch (std::exception& e) {
    future_result_->setErrorIfNeeded(std::current_exception());
  }
}

6.2.1 後續操作

後續操作,如果之前有註冊了 callback,則進行呼叫。也會進行流同步。

void GraphTask::exec_post_processing() {
  if (!not_ready_.empty()) {
    throw std::runtime_error("could not compute gradients for some functions");
  }

  // set the thread_local current_graph_task_ as more callbacks can be installed
  // by existing final callbacks.
  GraphTaskGuard guard(shared_from_this());
  // Lock mutex during each iteration for accessing final_callbacks.size()
  // Unlocking is necessary, because the callback can register
  // more callbacks (or they can be registered from other threads
  // while it's waiting.
  std::unique_lock<std::mutex> cb_lock(final_callbacks_lock_);
  // WARNING: Don't use a range-for loop here because more callbacks may be
  // added in between callback calls, so iterators may become invalidated.
  for (size_t i = 0; i < final_callbacks_.size(); ++i) {
    cb_lock.unlock();
    final_callbacks_[i]();
    cb_lock.lock();
  }

  // Syncs leaf streams with default streams (if necessary)
  // See note "Streaming backwards"
  for (const auto& leaf_stream : leaf_streams) {
    const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
    const auto default_stream = guard.getDefaultStream(leaf_stream.device());
    if (leaf_stream != default_stream) {
      auto event = c10::Event{c10::DeviceType::CUDA};
      event.record(leaf_stream);
      default_stream.wait(event);
    }
  }
}

6.2.2 通知主執行緒

之前在 execute 之中會用 fut->wait() 來等待任務完成。下面我們省略了部分程式碼。

auto Engine::execute(const edge_list& roots,
                     const variable_list& inputs,
                     bool keep_graph,
                     bool create_graph,
                     bool accumulate_grad,
                     const edge_list& outputs) -> variable_list {

  
  // Queue the root
  if (skip_dummy_node) {
    execute_with_graph_task(graph_task, graph_root, std::move(input_buffer));
  } else {
    execute_with_graph_task(graph_task, graph_root, InputBuffer(variable_list()));
  }
  auto& fut = graph_task->future_result_;
  fut->wait();
  return fut->value().toTensorVector();
}

在 mark_as_completed_and_run_post_processing 會用如下程式碼來通知主執行緒。

future_result_->markCompleted(std::move(vars));  // 通知主執行緒

6.3 通知其他執行緒

如果這個task是來自其它work thread,即 worker_device != base_owner,則向那個worker thread的queue傳送一個dummy function task,讓那個工作執行緒也執行起來。

local_graph_task 表示我們從佇列中檢索的 graph_task。外部graph_ 任務表示我們需要執行的可重入執行的總體graph_任務。

在 thread_main 之中,有一個 work around。就是:當前工作執行緒完成 graph_task,但此時,擁有graph_task的執行緒可能正在pop()上等待休眠。因此,我們需要向所屬執行緒傳送一個仿造的函式任務,以喚醒它,這樣我們可以退出thread_main。

這種情況發生在可重入反向傳播的情形。

// If worker_device is any devices (i.e. CPU, CUDA): this is a re-entrant
//    backward call from that device.
graph_task->owner_ = worker_device;

具體程式碼如下:

    // Check if we've completed execution.
    if (local_graph_task->completed()) {
      local_graph_task->mark_as_completed_and_run_post_processing();
      auto base_owner = local_graph_task->owner_; // 當前裝置
        
      if (worker_device != base_owner) {
          
        // 不是同一個裝置
          
        // Synchronize outstanding_tasks_ with queue mutex
        std::atomic_thread_fence(std::memory_order_release);
        ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
            ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0))); // dummy task
      }
    }

其他執行緒當收到了 dummy task 之後,不會處理,因為 function 是 nullptr,然後就呼叫 local_ready_queue->pop() 繼續從自己的queue 中讀取下一個 task

具體如下:

  1. 主執行緒等待。
  2. 如果工作執行緒發現GraphTask 已經結束,就通知主執行緒。
  3. 如果需要喚醒其他執行緒,就向該執行緒對應的 queue 插入 NodeTask。
  4. 對應執行緒取出 NodeTask 進行執行。
                                         +------------------------------------------------+
                                         | Worker Thread 1                                |
                                         |                                                |
                                         |  thread_main{                                  |
                                         |                                                |
                                         |     mark_as_completed_and_run_post_processing  |
                       2 markCompleted() |     {                                          |
                                 +-------------------+                                    |
                                 |       |     }                                          |
                                 |       |                                                |
+---------------+                |       |     push(NodeTask) +-----+                     |
| Main Thread   |                |       |                          |                     |
|               |                |       |   }                      |                     |
|               |                |       |                          |                     |
|               |                |       +------------------------------------------------+
|               |                |                                  |
|               |                |                                3 |
|               |                v                                  v
|               |                                           +-------+-------+
|               |   1      +----------------+               |               |
|               | wait()   |                |               |  ReadyQueue   |
|           +------------> | future_result_ |               |               |
|               |          |                |               +-------+-------+
|               |          +----------------+                       |
|               |                                                   |
|               |                                                 4 | pop(NodeTask)
|               |                                                   |
|               |                                                   v
|               |                                          +--------+---------------------+
|               |                                          | Worker Thread 2              |
|               |                                          |                              |
|               |                                          |                              |
+---------------+                                          |                              |
                                                           |                              |
                                                           |                              |
                                                           +------------------------------+

至此,後向傳播已經分析完畢,從下一篇開始,我們正式進入 PyTorch 分散式訓練。

0xFF 參考

https://www.zhihu.com/column/gemfield

【PyTorch】聊聊 backward 背後的程式碼

pytorch筆記(計算圖+autograd)-Node(1)

詳解Pytorch中的網路構造

PyTorch的優化器

PyTorch的分散式

PyTorch的Tensor(下)

PyTorch的Tensor(中)

PyTorch的Tensor(上)

PyTorch的動態圖(下)

PyTorch的動態圖(上)

PyTorch Internals 5:Autograd的實現

A GENTLE INTRODUCTION TO TORCH.AUTOGRAD

PyTorch學習筆記(12)——PyTorch中的Autograd機制介紹

PyTorch 的 Autograd

相關文章