[Source code analysis] PyTorch Distributed Autograd (6) ---- Engine (Part 2)

Posted by 羅西的思考 on 2021-12-06


0x00 Abstract

In the previous article we introduced how the engine obtains the dependencies of the backward computation graph; in this article we look at how the engine performs backward propagation based on those dependencies. Through this article, you will:

  • Understand how RecvRpcBackward sends RPC messages to the corresponding downstream nodes, and review the backward-propagation interaction flow between workers.
  • Understand how AccumulateGrad accumulates gradients in the context.

Other articles in the PyTorch distributed series are as follows:

Automatic Differentiation, a Deep Learning Tool (1)

Automatic Differentiation, a Deep Learning Tool (2)

[Source code analysis] Automatic Differentiation, a Deep Learning Tool (3) --- Example Walkthrough

[Source code analysis] How PyTorch Implements Forward Propagation (1) --- Base Classes (Part 1)

[Source code analysis] How PyTorch Implements Forward Propagation (2) --- Base Classes (Part 2)

[Source code analysis] How PyTorch Implements Forward Propagation (3) --- Concrete Implementation

[Source code analysis] How PyTorch Implements Backward Propagation (1) ---- Invoking the Engine

[Source code analysis] How PyTorch Implements Backward Propagation (2) ---- Static Structure of the Engine

[Source code analysis] How PyTorch Implements Backward Propagation (3) ---- Dynamic Logic of the Engine

[Source code analysis] How PyTorch Implements Backward Propagation (4) ---- The Concrete Algorithm

[Source code analysis] PyTorch Distributed (1) ------ History and Overview

[Source code analysis] PyTorch Distributed (2) ----- DataParallel (Part 1)

[Source code analysis] PyTorch Distributed (3) ----- DataParallel (Part 2)

[Source code analysis] PyTorch Distributed (4) ------ Basic Concepts of Distributed Applications

[Source code analysis] PyTorch Distributed (5) ------ DistributedDataParallel: Overview & How to Use

[Source code analysis] PyTorch Distributed (6) --- DistributedDataParallel -- Initialization & Store

[Source code analysis] PyTorch Distributed (7) ----- DistributedDataParallel: Process Groups

[Source code analysis] PyTorch Distributed (8) -------- DistributedDataParallel: The Paper

[Source code analysis] PyTorch Distributed (9) ----- DistributedDataParallel: Initialization

[Source code analysis] PyTorch Distributed (10) ------ DistributedDataParallel: Static Architecture of the Reducer

[Source code analysis] PyTorch Distributed (11) ----- DistributedDataParallel: Building the Reducer and the Join Operation

[Source code analysis] PyTorch Distributed (12) ----- DistributedDataParallel: Forward Propagation

[Source code analysis] PyTorch Distributed (13) ----- DistributedDataParallel: Backward Propagation

[Source code analysis] PyTorch Distributed Autograd (1) ---- Design

[Source code analysis] PyTorch Distributed Autograd (2) ---- RPC Foundations

[Source code analysis] PyTorch Distributed Autograd (3) ---- Context

[Source code analysis] PyTorch Distributed Autograd (4) ---- How to Enter the Engine

[Source code analysis] PyTorch Distributed Autograd (5) ---- Engine (Part 1)

For clarity, the code in this article is simplified as appropriate.

0x01 Review

Let us first review the FAST-mode algorithm, listed below; this article discusses its latter steps.

  1. We start from the worker that has the roots of the backward pass (all roots must be local).
  2. Look up all the send functions of the current Distributed Autograd Context.
  3. Starting from the provided roots and all the send functions we retrieved, compute the dependencies locally.
  4. After computing the dependencies, kick off the local autograd engine with the provided roots.
  5. When the autograd engine executes a recv function, that recv function sends the input gradients via RPC to the appropriate worker. Each recv function knows the target worker id, since it was recorded as part of the forward pass. The recv function is sent to the remote host via the autograd_context_id and autograd_message_id.
  6. When the remote host receives this request, we use the autograd_context_id and autograd_message_id to look up the appropriate send function.
  7. If this is the first time the worker has received a request for the given autograd_context_id, it computes the dependencies locally as described in points 1-3 above.
  8. The send function received in point 6 is then enqueued for execution on that worker's local autograd engine.
  9. Finally, instead of accumulating gradients on the Tensor's .grad, we accumulate gradients separately on each Distributed Autograd Context. The gradients are stored in a Dict[Tensor, Tensor], which is basically a map from a Tensor to its associated gradient, and this map can be retrieved with the get_gradients() API (a toy sketch of this per-context map follows right after this list).
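
As a reading aid, here is a toy C++ sketch of the idea in point 9. This is not the actual DistAutogradContext class (the real map type, key handling, and locking appear in section 0x04); it only shows that gradients go into a per-context map keyed by tensor instead of into Tensor.grad.

#include <ATen/ATen.h>
#include <unordered_map>

// Toy stand-in for a distributed autograd context: each context owns its own
// tensor -> gradient map, so different backward passes never touch Tensor.grad.
struct ToyDistAutogradContext {
  std::unordered_map<at::TensorImpl*, at::Tensor> accumulated_grads;

  void accumulate(const at::Tensor& variable, const at::Tensor& grad) {
    auto* key = variable.unsafeGetTensorImpl();
    auto it = accumulated_grads.find(key);
    if (it == accumulated_grads.end()) {
      accumulated_grads.emplace(key, grad.clone()); // first gradient for this tensor
    } else {
      it->second += grad;                           // accumulate later gradients in place
    }
  }
};

Conceptually, get_gradients(context_id) then simply returns the map held by that context.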

Second, let us look at the overall execution code. The overall execution is done in DistEngine::execute and consists of the following steps:

  • Use contextId to get the forward-pass context.
  • Use validateRootsAndRetrieveEdges for validation.
  • Construct a GraphRoot to drive the backward pass; it can be regarded as a virtual root.
  • Use computeDependencies to compute the dependencies.
  • Use runEngineAndAccumulateGradients to run the backward computation.
  • Use clearAndWaitForOutstandingRpcsAsync to wait for the RPCs to complete.
void DistEngine::execute(
    int64_t contextId,
    const variable_list& roots,
    bool retainGraph) {
  // Retrieve the context for the given context_id. This will throw if the
  // context_id is invalid.
  auto autogradContext =
      DistAutogradContainer::getInstance().retrieveContext(contextId);

  // Perform initial pre-processing.
  edge_list rootEdges;
  variable_list grads;
  validateRootsAndRetrieveEdges(roots, rootEdges, grads); 

  // Construct a GraphRoot to drive the backward pass; it can be regarded as a virtual root
  std::shared_ptr<Node> graphRoot =
      std::make_shared<GraphRoot>(rootEdges, grads);
  edge_list outputEdges;
  // Compute dependencies locally, starting from all roots and all 'send'
  // functions.
  {
    std::lock_guard<std::mutex> guard(initializedContextIdsLock_);
    // Context should not have been initialized already.
    TORCH_INTERNAL_ASSERT(
        initializedContextIds_.find(autogradContext->contextId()) ==
        initializedContextIds_.end());

    // Compute dependencies
    computeDependencies(
        autogradContext, rootEdges, grads, graphRoot, outputEdges, retainGraph);

    // Mark the autograd context id as initialized.
    initializedContextIds_.insert(autogradContext->contextId());
  }

  BackwardPassCleanupGuard guard(autogradContext);

  // This needs to be blocking and as a result we wait for the future to
  // complete.
  runEngineAndAccumulateGradients(autogradContext, graphRoot, outputEdges)
      ->waitAndThrow(); // run the backward-pass computation

  // Wait for all of the outstanding rpcs to complete.
  autogradContext->clearAndWaitForOutstandingRpcsAsync()->waitAndThrow();
}

Third, from the previous article we know that the dependencies have already been processed in computeDependencies, and the information about all functions that need to be computed lives in GraphTask.exec_info_. Next we look at how the computation is actually performed, namely the two methods runEngineAndAccumulateGradients and clearAndWaitForOutstandingRpcsAsync.

0x02 Executing the GraphTask

Let us first look at how runEngineAndAccumulateGradients runs the backward computation and accumulates gradients.

2.1 runEngineAndAccumulateGradients

Inside the engine, runEngineAndAccumulateGradients is called first. It mainly wraps a NodeTask and uses it to call execute_graph_task_until_ready_queue_empty, where at::launch is used to start a thread.

c10::intrusive_ptr<c10::ivalue::Future> DistEngine::
    runEngineAndAccumulateGradients(
        const ContextPtr& autogradContext,
        const std::shared_ptr<Node>& graphRoot,
        const edge_list& outputEdges,
        bool incrementOutstandingTasks) {
  // Cleanup previous state for outstanding RPCs. Outstanding RPCs could be
  // lingering if we're running backward multiple times and some of the
  // passes ran into errors.
  autogradContext->clearOutstandingRpcs();
    
  // Get the GraphTask
  auto graphTask = autogradContext->retrieveGraphTask();
  
  // Launch a thread to run execute_graph_task_until_ready_queue_empty
  at::launch([this, graphTask, graphRoot, incrementOutstandingTasks]() {
    execute_graph_task_until_ready_queue_empty(
        /*node_task*/ NodeTask(graphTask, graphRoot, InputBuffer(0)),
        /*incrementOutstandingTasks*/ incrementOutstandingTasks);
  });
    
  // Use a reference here to avoid refcount bump on futureGrads.
  // Handle the results
  auto& futureGrads = graphTask->future_result_;

  // Build a future that waits for the callbacks to execute (since callbacks
  // execute after the original future is completed). This ensures we return a
  // future that waits for all gradient accumulation to finish.
  auto accumulateGradFuture =
      c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());

  futureGrads->addCallback(
      [autogradContext, outputEdges, accumulateGradFuture](c10::ivalue::Future& futureGrads) {
        if (futureGrads.hasError()) {
          // error handling omitted
          return;
        }

        try {
          const variable_list& grads =
              futureGrads.constValue().toTensorVector();
          // mark as finished
          accumulateGradFuture->markCompleted(c10::IValue());
        } catch (std::exception& e) {
          accumulateGradFuture->setErrorIfNeeded(std::current_exception());
        }
      });

  return accumulateGradFuture;
}
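
A side note on the pattern above: the engine does not return graphTask->future_result_ directly. Instead, a second future (accumulateGradFuture) is completed from inside the callback attached to futureGrads, so anyone waiting on the returned future also waits for the callback to run. A minimal, hedged sketch of this chaining pattern (assuming a libtorch build; an illustration, not engine code):

#include <ATen/core/ivalue.h>

// Returns a future that completes only after the callback attached to 'inner'
// has executed, mirroring how accumulateGradFuture wraps futureGrads above.
c10::intrusive_ptr<c10::ivalue::Future> chainFutures() {
  auto inner = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
  auto outer = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());

  inner->addCallback([outer](c10::ivalue::Future& f) {
    if (f.hasError()) {
      outer->setError(f.exception_ptr());   // propagate the failure downstream
    } else {
      outer->markCompleted(c10::IValue());  // signal completion downstream
    }
  });

  // In the engine, this completion happens when the local autograd pass finishes.
  inner->markCompleted(c10::IValue());
  return outer;  // waitAndThrow() on this also waits for the callback above
}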

at::launch is located in aten/src/ATen/ParallelThreadPoolNative.cpp; it invokes the passed-in func on a thread-pool thread.

void launch(std::function<void()> func) {

  internal::launch_no_thread_state(std::bind([](
    std::function<void()> f, ThreadLocalState thread_locals) {
      ThreadLocalStateGuard guard(std::move(thread_locals));
      f();
    },
    std::move(func),
    ThreadLocalState()
  ));
}

namespace internal {
    void launch_no_thread_state(std::function<void()> fn) {
    #if AT_EXPERIMENTAL_SINGLE_THREAD_POOL
      intraop_launch(std::move(fn));
    #else
      get_pool().run(std::move(fn));
    #endif
    }
} 
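
A minimal usage sketch of at::launch (an assumption-laden illustration: a libtorch build with the ATen/Parallel.h header). The closure runs on ATen's inter-op thread pool, and the caller waits on a separate handle for the result, just as the engine waits on graphTask->future_result_ after launching execute_graph_task_until_ready_queue_empty.

#include <ATen/Parallel.h>
#include <future>
#include <iostream>

int main() {
  std::promise<int> done;
  std::future<int> result = done.get_future();

  at::launch([&done]() {          // executed on a pool thread, not on the caller's thread
    done.set_value(42);
  });

  std::cout << "task produced " << result.get() << std::endl; // blocks until the task has run
  return 0;
}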

Next, let us go through these internal methods one by one.

2.2 execute_graph_task_until_ready_queue_empty

This function is similar to Engine::thread_main: it completes the execution of this GraphTask via a NodeTask, during which evaluate_function keeps inserting new NodeTasks into cpu_ready_queue. The method will:

  • First, initialize the native engine's device thread pool.
  • Next, create a cpu_ready_queue per call, used to traverse the graph_task starting from root_to_execute. This allows different threads to execute the same GraphTask in parallel; it is a CPU-specific queue.
  • Push the incoming node_task into the cpu_ready_queue.
  • Walk the backward computation graph from the root all the way to the leaf nodes.
    • The leaf nodes here are all AccumulateGrad or RecvRpcBackward.

    • Intermediate nodes are computed normally.

    • For a RecvRpcBackward node, an RPC message is sent to the corresponding downstream node.

    • For an AccumulateGrad node, the gradient is accumulated in the context.

The code is as follows:

void DistEngine::execute_graph_task_until_ready_queue_empty(
    NodeTask&& node_task,
    bool incrementOutstandingTasks) {
  
  // Initialize the native engine's device threads
  engine_.initialize_device_threads_pool();
  
  // Create a ready queue per call to traverse the graph_task from
  // root_to_execute This allow concurrent execution of the same GraphTask from
  // different threads
  // A ready queue is created per call to traverse the graph_task starting from root_to_execute; this allows different threads to execute the same GraphTask in parallel. It is a CPU-specific queue.
  std::shared_ptr<ReadyQueue> cpu_ready_queue = std::make_shared<ReadyQueue>();
  auto graph_task = node_task.base_.lock();
  if (graph_task == nullptr) {
    LOG(ERROR) << "GraphTask has expired for NodeTask: "
               << node_task.fn_->name() << ", skipping execution.";
    return;
  }

  cpu_ready_queue->push(std::move(node_task), incrementOutstandingTasks);

  torch::autograd::set_device(torch::autograd::CPU_DEVICE);
  graph_task->owner_ = torch::autograd::CPU_DEVICE;
  while (!cpu_ready_queue->empty()) {
    std::shared_ptr<GraphTask> local_graph_task;
    {
      // Scope this block of execution since NodeTask is not needed after this
      // block and can be deallocated (release any references to grad tensors
      // as part of inputs_)
      NodeTask task = cpu_ready_queue->pop(); // pop one NodeTask
      if (!(local_graph_task = task.base_.lock())) {
        continue;
      }
      if (task.fn_ && !local_graph_task->has_error_.load()) {
        AutoGradMode grad_mode(local_graph_task->grad_mode_);
        try {
          GraphTaskGuard guard(local_graph_task);
          engine_.evaluate_function( // this invokes the function of the specific Node
              local_graph_task, task.fn_.get(), task.inputs_, cpu_ready_queue);
        } catch (std::exception& e) {
          engine_.thread_on_exception(local_graph_task, task.fn_, e);
          // break the loop in error so that we immediately stop the execution
          // of this GraphTask, mark it completed if necessary and return the
          // future with proper ErrorMessage
          break;
        }
      }
    }
    // Decrement the outstanding task.
    --local_graph_task->outstanding_tasks_; // one NodeTask has been processed
  }
  // Check if we've completed execution.
  if (graph_task->completed()) {
    // We don't need to explicitly notify the owner thread, since
    // 'mark_as_completed_and_run_post_processing' would mark the Future as
    // completed and this would notify the owner thread that the task has been
    // completed.
    graph_task->mark_as_completed_and_run_post_processing();
  }
}

In addition, execute_graph_task_until_ready_queue_empty is called from three places (items 1-3 below), and within it Engine::evaluate_function then handles the two kinds of leaf nodes (items 4-5 below).

  1. Called by runEngineAndAccumulateGradients: the case where the user explicitly calls backward, which is what this section describes.
  2. Called by executeSendFunctionAsync: this corresponds to a node's processing after it receives gradients from the upstream node of the backward pass; we cover it in the next section.
  3. Called by globalCpuThread: the dedicated CPU worker thread, which we introduce shortly.
  4. In Engine::evaluate_function, gradients are accumulated for AccumulateGrad nodes.
  5. In Engine::evaluate_function, RecvRpcBackward is invoked to send messages to the downstream of the backward pass.

Let us summarize these gradient-computation flows; they correspond to the numbers in the diagram below.

 User Training Script             RPC BACKWARD_AUTOGRAD_REQ
     +                                         +
     |                                         |
     | 1                                       | 2
     v                                         v
 backward                         RequestCallbackNoPython.processRpc
     +                                         +
     |                                         |
     |                                         |
     v                                         v
 DistEngine.execute               RequestCallbackNoPython.processBackwardAutogradReq
     +                                         +
     |                                         |
     |                                         |
     |                                         v
     |              +----------+  DistEngine.executeSendFunctionAsync
     |              |                               +
     |              |                               |
     v              v                               |
DistEngine.computeDependencies                      |
     |                                              |
     |                                              |
     v                                              |
 DistEngine.runEngineAndAccumulateGradients         |     DistEngine.globalCpuThread
     +                                              |                   +
     |                           +------------------+                   |
     |                           |                                      | 3
     |                           |             +------------------------+
     |                           |             |
     |                           |             |
     v                           v             v
 DistEngine.execute_graph_task_until_ready_queue_empty
     +
     |
     |
     v
 DistEngine.evaluate_function
     +
     |
     +--------------------------------------------------------------+
     |                                                              |
     |  4 AccumulateGrad                                            | 5  RecvRpcBackward
     v                                                              v

(*hook)(captured_grad)                            call_function(graph_task, func, inputs)

2.3 evaluate_function

The code above actually calls the native engine's evaluate_function to do the work.

Let us see how exec_info_ is used: entries not marked as needing execution are not processed. Here we can also see how the recvBackwardEdges mentioned in the previous article interact with exec_info_.

When traversing recvBackwardEdges, for each recvBackward the corresponding entry in GraphTask.exec_info_ is marked as needing execution.

The code is as follows; it will:

  • Accumulate gradients for AccumulateGrad.
  • Call RecvRpcBackward to send messages to the downstream of the backward pass.
void Engine::evaluate_function(
    std::shared_ptr<GraphTask>& graph_task,
    Node* func,
    InputBuffer& inputs,
    const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
  // If exec_info_ is not empty, we have to instrument the execution
  auto& exec_info_ = graph_task->exec_info_;
  if (!exec_info_.empty()) {
    auto& fn_info = exec_info_.at(func);
    if (auto* capture_vec = fn_info.captures_.get()) {
      // Lock mutex for writing to graph_task->captured_vars_.
      std::lock_guard<std::mutex> lock(graph_task->mutex_);
      for (const auto& capture : *capture_vec) {
        auto& captured_grad = graph_task->captured_vars_[capture.output_idx_];
        captured_grad = inputs[capture.input_idx_];
        for (auto& hook : capture.hooks_) {
          captured_grad = (*hook)(captured_grad); // invoke the hook, i.e., DistAccumulateGradCaptureHook's operator(); captured_grad is the accumulated gradient
        }
      }
    }
    if (!fn_info.needed_) { 
      // Skip execution if we don't need to execute the function.
      return; // If not marked as needing execution, return directly. recvBackward entries are marked as needed.
    }
  }
  
  // this is where recvBackward is invoked
  auto outputs = call_function(graph_task, func, inputs);
    
  // subsequent code omitted

2.4 globalCpuThread

For globalCpuThread, see the [GPU to CPU continuations] section in the previous article. globalCpuThread is a worker thread that simply pops NodeTasks from the ready queue and executes them.

For globalCpuThread, its ready_queue argument is global_cpu_ready_queue_.

void DistEngine::globalCpuThread(
    const std::shared_ptr<ReadyQueue>& ready_queue) {
  while (true) {
    NodeTask task = ready_queue->pop();
    if (task.isShutdownTask_) {
      // Need to shutdown this thread.
      break;
    }

    auto graphTask = task.base_.lock();
    if (graphTask == nullptr) {
      // GraphTask has expired, ignore and continue processing.
      continue;
    }

    // Launch the execution on a JIT thread.
    at::launch([this,
                graphTask,
                graphRoot = task.fn_,
                variables =
                    InputBuffer::variables(std::move(task.inputs_))]() mutable {
      InputBuffer inputs(variables.size());
      for (size_t i = 0; i < variables.size(); i++) {
        inputs.add(i, std::move(variables[i]), c10::nullopt, c10::nullopt);
      }
      execute_graph_task_until_ready_queue_empty( // called here
          /*node_task*/ NodeTask(graphTask, graphRoot, std::move(inputs)),
          /*incrementOutstandingTasks*/ false);
    });
  }
}

The regular engine also sets up a dedicated CPU queue:

auto graph_task = std::make_shared<GraphTask>(
    /* keep_graph */ keep_graph,
    /* create_graph */ create_graph,
    /* depth */ not_reentrant_backward_call ? 0 : total_depth + 1,
    /* cpu_ready_queue */ local_ready_queue);

2.5 Summary

For the distributed engine, the main differences from the regular engine in the computation phase are:

  • For a RecvRpcBackward node, an RPC message is sent to the corresponding downstream node.

  • For an AccumulateGrad node, the gradient is accumulated in the context.

So next we look at how these two parts are handled.

0x03 RPC Calls

In previous articles we saw how the receiver handles a backward-propagation RPC call. Next we look at how the engine initiates such a call, i.e., how the recv method is invoked.

This applies to the case below where worker 0 calls recv and execution moves to worker 1, which corresponds to the following passage in the design document:

When the autograd engine executes the recv function, the recv function sends the input gradients via RPC to the appropriate worker. Each recv function knows the target worker id, since it was recorded as part of the forward pass. The recv function is sent to the remote host via the autograd_context_id and autograd_message_id.


Let us see how the recv function is executed.

Concretely, in the distributed engine, when the engine finds that a Node is a RecvRpcBackward, it calls its apply function.

void Engine::evaluate_function(
    std::shared_ptr<GraphTask>& graph_task,
    Node* func,
    InputBuffer& inputs,
    const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
  // If exec_info_ is not empty, we have to instrument the execution
  auto& exec_info_ = graph_task->exec_info_;
  if (!exec_info_.empty()) {
    // gradient-accumulation code omitted; see the previous section
    if (!fn_info.needed_) { 
      // Skip execution if we don't need to execute the function.
      return; // If not marked as needing execution, return directly. recvBackward entries are marked as needed.
    }
  }
  
  // this is where recvBackward.apply is invoked
  auto outputs = call_function(graph_task, func, inputs);
    
  // subsequent code omitted

3.1 RecvRpcBackward

3.1.1 Definition

RecvRpcBackward is defined as follows:

class TORCH_API RecvRpcBackward : public torch::autograd::Node {
 public:
  explicit RecvRpcBackward(
      const AutogradMetadata& autogradMetadata,
      std::shared_ptr<DistAutogradContext> autogradContext,
      rpc::worker_id_t fromWorkerId,
      std::unordered_map<c10::Device, c10::Device> deviceMap);

  torch::autograd::variable_list apply(
      torch::autograd::variable_list&& grads) override;

 private:
  const AutogradMetadata autogradMetadata_;

  // Hold a weak reference to the autograd context to avoid circular
  // dependencies with the context (since it holds a reference to
  // RecvRpcBackward).
  std::weak_ptr<DistAutogradContext> autogradContext_;

  // The worker id from which the RPC was received. During the backward pass,
  // we need to propagate the gradients to this workerId.
  rpc::worker_id_t fromWorkerId_;

  // Device mapping for tensors sent over RPC.
  const std::unordered_map<c10::Device, c10::Device> deviceMap_;
};

3.1.2 Construction

The constructor is as follows.

RecvRpcBackward::RecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    ContextPtr autogradContext,
    rpc::worker_id_t fromWorkerId,
    std::unordered_map<c10::Device, c10::Device> deviceMap)
    : autogradMetadata_(autogradMetadata),
      autogradContext_(std::move(autogradContext)),
      fromWorkerId_(fromWorkerId),
      deviceMap_(std::move(deviceMap)) {}

3.1.3 apply

torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp defines its apply function, whose job is to:

  • Put the incoming gradients grads into outputGrads, to be passed on to the next stage.
  • Build a PropagateGradientsReq, i.e., a BACKWARD_AUTOGRAD_REQ.
  • Send the RPC to the next stage.
variable_list RecvRpcBackward::apply(variable_list&& grads) {
  std::vector<Variable> outputGrads;
  for (size_t i = 0; i < grads.size(); i++) { // put the incoming gradients grads into outputGrads
    const auto& grad = grads[i];
    if (grad.defined()) {
      outputGrads.emplace_back(grad);
    } else {
      // Put in zeros for a tensor with no grad.
      outputGrads.emplace_back(input_metadata(i).zeros_like());
    }
  }
 
  auto sharedContext = autogradContext_.lock();
  // Send the gradients over the wire and record the future in the autograd
  // context.
  PropagateGradientsReq gradCall( // build the PropagateGradientsReq
      autogradMetadata_,
      outputGrads,
      sharedContext->retrieveGraphTask()->keep_graph_);

  // Send the gradients over to the appropriate node.
  auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent();
  auto jitFuture = rpcAgent->send( // send the RPC
      rpcAgent->getWorkerInfo(fromWorkerId_),
      std::move(gradCall).toMessage(), // this calls toMessageImpl
      rpc::kUnsetRpcTimeout,
      deviceMap_);

  // Record the future in the context.
  sharedContext->addOutstandingRpc(jitFuture);

  // 'recv' function sends the gradients over the wire using RPC, it doesn't
  // need to return anything for any downstream autograd function.
  return variable_list();
}

Since a PropagateGradientsReq is sent here, let us look at it next.

3.2 PropagateGradientsReq

3.2.1 Definition

PropagateGradientsReq extends RpcCommandBase.

// Used to propagate gradients from one node to another during a distributed
// backwards pass. This RPC call is invoked when we hit a `recv` autograd
// function during backward pass execution.
class TORCH_API PropagateGradientsReq : public rpc::RpcCommandBase {
 public:
  PropagateGradientsReq(
      const AutogradMetadata& autogradMetadata,
      std::vector<torch::autograd::Variable> grads,
      bool retainGraph = false);

  const AutogradMetadata& getAutogradMetadata();

  const std::vector<torch::autograd::Variable>& getGrads();

  // Serialization and deserialization methods.
  rpc::Message toMessageImpl() && override;
  static std::unique_ptr<PropagateGradientsReq> fromMessage(
      const rpc::Message& message);

  // Whether or not to retain the autograd graph.
  bool retainGraph();

 private:
  AutogradMetadata autogradMetadata_;
  std::vector<torch::autograd::Variable> grads_;
  bool retainGraph_;
};

Its toMessageImpl specifies that this message is a BACKWARD_AUTOGRAD_REQ.

Message PropagateGradientsReq::toMessageImpl() && {
  std::vector<at::IValue> ivalues;
  // Add all the grad tensors.
  for (const auto& grad : grads_) {
    ivalues.emplace_back(grad);
  }

  // Now add autograd metadata.
  ivalues.emplace_back(autogradMetadata_.autogradContextId);
  ivalues.emplace_back(autogradMetadata_.autogradMessageId);

  // Add retain graph.
  ivalues.emplace_back(retainGraph_);

  // Now pickle using JIT pickler.
  std::vector<torch::Tensor> tensorTable;
  std::vector<char> payload =
      jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensorTable);

  return Message(
      std::move(payload),
      std::move(tensorTable),
      MessageType::BACKWARD_AUTOGRAD_REQ); // the message type is specified here
}
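
The class also declares a matching fromMessage for the receiving side. Below is a hedged sketch of what that deserialization roughly looks like, simply reversing the order used in toMessageImpl above; the exact unpickling helper and its signature vary across PyTorch versions, so treat this as an illustration rather than the verbatim source.

// Sketch only: the last three IValues are retainGraph, autogradMessageId and
// autogradContextId (pushed in that order by toMessageImpl); every earlier
// IValue is a gradient tensor.
std::unique_ptr<PropagateGradientsReq> PropagateGradientsReq::fromMessage(
    const rpc::Message& message) {
  IValue tuple = jit::unpickle(
      static_cast<const char*>(message.payload().data()),
      message.payload().size(),
      /*type_resolver=*/nullptr,
      message.tensors());
  std::vector<at::IValue> elems = tuple.toTuple()->elements();

  bool retainGraph = elems.back().toBool();           // added last in toMessageImpl
  elems.pop_back();
  int64_t autogradMessageId = elems.back().toInt();
  elems.pop_back();
  int64_t autogradContextId = elems.back().toInt();
  elems.pop_back();

  std::vector<torch::autograd::Variable> grads;
  grads.reserve(elems.size());
  for (const auto& iv : elems) {                      // the remaining entries are the gradients
    grads.emplace_back(iv.toTensor());
  }
  return std::make_unique<PropagateGradientsReq>(
      AutogradMetadata(autogradContextId, autogradMessageId),
      std::move(grads),
      retainGraph);
}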

3.3 The Receiver

For completeness, let us now look at how the receiver handles the backward pass.

3.3.1 Receiving the Message

When the TensorPipeAgent is created, RequestCallbackImpl is configured as its callback function; this is the agent's unified response function. As mentioned earlier when we discussed the agent's receive logic, execution enters RequestCallbackNoPython::processRpc, where we can see the handling logic for BACKWARD_AUTOGRAD_REQ.

This is the normal RPC flow.

void RequestCallbackNoPython::processRpc(
    RpcCommandBase& rpc,
    const MessageType& messageType,
    const int64_t messageId,
    const c10::intrusive_ptr<JitFuture>& responseFuture,
    std::shared_ptr<LazyStreamContext> ctx) const {

  switch (messageType) {

    case MessageType::BACKWARD_AUTOGRAD_REQ: { 
      processBackwardAutogradReq(rpc, messageId, responseFuture); // called here
      return;
    };

3.3.2 processBackwardAutogradReq

In processBackwardAutogradReq, the code will:

  • Get the DistAutogradContainer.
  • Get the context.
  • Call executeSendFunctionAsync for engine processing.

From this we can see that there are two paths into the engine:

  • One is the example code explicitly calling backward, which ends up calling DistEngine::getInstance().execute; this is worker 0.
  • The other is a passive call to DistEngine::getInstance().executeSendFunctionAsync; this is worker 1.
void RequestCallbackNoPython::processBackwardAutogradReq(
    RpcCommandBase& rpc,
    const int64_t messageId,
    const c10::intrusive_ptr<JitFuture>& responseFuture) const {
  auto& gradientsCall = static_cast<PropagateGradientsReq&>(rpc);
  const auto& autogradMetadata = gradientsCall.getAutogradMetadata();

  // Retrieve the appropriate autograd context.
  auto autogradContext = DistAutogradContainer::getInstance().retrieveContext(
      autogradMetadata.autogradContextId); // get the sender's context id

  // Lookup the appropriate 'send' function to enqueue.
  std::shared_ptr<SendRpcBackward> sendFunction = // look up sendFunction using the sender's context id and message id
      autogradContext->retrieveSendFunction(autogradMetadata.autogradMessageId);

  // Attach the gradients to the send function.
  sendFunction->setGrads(gradientsCall.getGrads()); // set the gradients

  // Now execute the autograd graph using the "distributed engine."
  auto execFuture = DistEngine::getInstance().executeSendFunctionAsync( // invoke the engine
      autogradContext, sendFunction, gradientsCall.retainGraph());

  // Our response is satisfied when the rpcs come back.
  execFuture->addCallback([responseFuture, messageId](JitFuture& execFuture) {
    if (!execFuture.hasError()) {
      Message m = std::move(PropagateGradientsResp()).toMessage();
      m.setId(messageId);
      responseFuture->markCompleted(
          IValue(c10::make_intrusive<Message>(std::move(m))));
    } else {
      responseFuture->setError(execFuture.exception_ptr());
    }
  });
}

3.3.3 executeSendFunctionAsync

executeSendFunctionAsync is where execution enters the engine. Note that here the receiver also enters the engine and performs the computation on the receiver side. executeSendFunctionAsync may call execute_graph_task_until_ready_queue_empty directly, or it may first compute dependencies and then continue execution. This corresponds to the following points of the design:

  • 6) When the remote host receives this request, we use the autograd_context_id and autograd_message_id to look up the appropriate send function.
  • 7) If this is the first time this worker has received a request for the given autograd_context_id, it computes the dependencies locally as described in points 1-3 above.
  • 8) The send function received in point 6 is then enqueued for execution on that worker's local autograd engine.

The code is as follows:

c10::intrusive_ptr<c10::ivalue::Future> DistEngine::executeSendFunctionAsync(
    const ContextPtr& autogradContext,
    const std::shared_ptr<SendRpcBackward>& sendFunction,
    bool retainGraph) {

  // Typically the local autograd engine ensures stream synchronizations between
  // nodes in the graph. However, for distributed autograd the sendFunction
  // inputs might have been retrieved over the wire on a separate stream and the
  // sendFunction itself runs on a different stream. As a result, we need to
  // manually synchronize those two streams here.
  const auto& send_backward_stream = sendFunction->stream(c10::DeviceType::CUDA);
  if (send_backward_stream) { // the stream corresponding to this execution
    for (const auto& grad : sendFunction->getGrads()) {
        const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
        const auto default_stream = guard.getStream(grad.device());
        if (send_backward_stream != default_stream) {
          auto event = c10::Event{c10::DeviceType::CUDA};
          event.record(default_stream);
          send_backward_stream->wait(event); // synchronize to make sure the current operation has finished
        }
    }
  }

  std::unique_lock<std::mutex> lock(initializedContextIdsLock_);
  if (initializedContextIds_.find(autogradContext->contextId()) ==
      initializedContextIds_.end()) { // check whether the context corresponding to sendFunction has already been recorded on this node
    // Context not found; dependencies need to be computed
    edge_list outputEdges;
    // Pass in a dummy graphRoot since all send functions are the roots.
    auto dummyRoot = std::make_shared<GraphRoot>(edge_list(), variable_list());
    computeDependencies( // compute dependencies
        autogradContext, {}, {}, dummyRoot, outputEdges, retainGraph);

    // Mark the autograd context id as initialized and unlock.
    initializedContextIds_.insert(autogradContext->contextId());
    lock.unlock();

    // Enqueue the current send function.
    auto graphTask = autogradContext->retrieveGraphTask();
    // Run the autograd engine.
    auto accumulateGradFuture = runEngineAndAccumulateGradients( // compute gradients
        autogradContext,
        sendFunction,
        outputEdges,
        /*incrementOutstandingTasks=*/false);

    // Build the 'uber' future that waits for everything.
    auto callbackFuture =
        c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
    // register the callback
    accumulateGradFuture->addCallback([autogradContext,
                                       callbackFuture](c10::ivalue::Future& accumulateGradFuture) {
      try {
        if (accumulateGradFuture.hasError()) {
          // Perform cleanup at the end of the backward pass (before we mark
          // the future as completed).
          DistEngine::getInstance().cleanupBackwardPass(autogradContext);

          // Skip any further processing on errors.
          callbackFuture->setError(accumulateGradFuture.exception_ptr());
          return;
        }

        // Wait for all RPCs after the autograd engine is done.
        auto rpcFuture = autogradContext->clearAndWaitForOutstandingRpcsAsync();
        rpcFuture->addCallback([callbackFuture, autogradContext](c10::ivalue::Future& rpcFuture) {
          try {
            // Perform cleanup at the end of the backward pass (before
            // we mark the future as completed).
            DistEngine::getInstance().cleanupBackwardPass(autogradContext);
          } catch (std::exception& e) {
            callbackFuture->setErrorIfNeeded(std::current_exception());
            return;
          }

          // Finally mark the 'uber' future as completed.
          if (!rpcFuture.hasError()) {
            callbackFuture->markCompleted(c10::IValue());
          } else {
            callbackFuture->setError(rpcFuture.exception_ptr());
          }
        });
      } catch (std::exception& e) {
        callbackFuture->setErrorIfNeeded(std::current_exception());
      }
    });

    // Return the future which waits for all async processing to be done.
    return callbackFuture;
  } else { // the context was found on the current node
    lock.unlock();
    auto graphTask = autogradContext->retrieveGraphTask();
    at::launch([this, graphTask, sendFunction]() {
      execute_graph_task_until_ready_queue_empty(
          /*node_task*/ NodeTask(graphTask, sendFunction, InputBuffer(0)),
          /*incrementOutstandingTasks*/ false);
    });
    auto fut = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
    fut->markCompleted(c10::IValue());
    return fut;
  }
}

The flow is shown in the figure below:

                                                                  +
                                                         worker 0 | worker 1
                                                                  |
  Engine            RecvRpcBackward              RpcAgent         |     RequestCallbackNoPython             DistEngine
    +                    +                          +             |              +                              +
    |                    |                          |             |              |                              |
    |                    |                          |             |              |                              |
evaluate_function        |                          |             |              |                              |
    +                    |                          |             |              |                              |
    |                    |                          |             |              |                              |
    +                    |                          |             |              |                              |
  call_function          |                          |             |              |                              |
    +                    |                          |             |              |                              |
    |      grads         v                          |             |              |                              |
    +----------------> apply                        |             |              |                              |
    |                    +                          |             |              |                              |
    |                    |                          |             |              |                              |
    |                    +                          |             |              |                              |
    |                 gradCall                      |             |              |                              |
    |                    +                          |             |              |                              |
    |                    |  PropagateGradientsReq   |             |              |                              |
    |                    +------------------------> |             |              |                              |
    |                    |                          |             +              |                              |
    |                    |                          +   BACKWARD_AUTOGRAD_REQ    |                              |
    |                    |                        send  +---------+--------->    |                              |
    |                    |                          +             |              |                              |
    |                    |                          |             |              +                              |
    |                    |                          |             |     processBackwardAutogradReq              |
    |                    |                          |             |              +                              |
    |                    |                          |             |              |                              +
    |                    |                          |             |              +------------> executeSendFunctionAsync
    |                    |                          |             |              |                              +
    |                    |                          |             |              |                              |
    |                    |                          |             |              |                              |
    v                    v                          v             +              v                              v



0x04 DistAccumulateGradCaptureHook

At this point the overall logic seems complete, but one piece is actually still missing, corresponding to the following passage in the design document:

Finally, instead of accumulating gradients on the Tensor's .grad, we accumulate gradients separately on each Distributed Autograd Context. The gradients are stored in a Dict[Tensor, Tensor], which is basically a map from a Tensor to its associated gradient, and this map can be retrieved with the get_gradients() API.

That is, remote and local gradients are accumulated into the local context, so let us take another look at DistAccumulateGradCaptureHook.

4.1 Definition

DistAccumulateGradCaptureHook has three roles:

  1. Call the original AccumulateGrad's pre hooks to modify the input gradient.

  2. Accumulate the grad into the RPC context.

  3. Call the original AccumulateGrad's post hooks.

It is defined as follows:

// This hook does 3 things:
//   1. Call pre hooks of the original AccumulateGrad to modify the input grad.
//   2. Accumuate the gard to RPC context.
//   3. Call post hooks of the original AccumulateGrad.
class DistAccumulateGradCaptureHook
    : public GraphTask::ExecInfo::Capture::GradCaptureHook {
 public:
  DistAccumulateGradCaptureHook(
      std::shared_ptr<AccumulateGrad> accumulateGrad,
      ContextPtr autogradContext)
      : accumulateGrad_(std::move(accumulateGrad)),
        autogradContext_(std::move(autogradContext)) {}

  at::Tensor operator()(const at::Tensor& grad) override {
    ThreadLocalDistAutogradContext contextGuard{ContextPtr(autogradContext_)};
    variable_list inputGrads = {grad};
    // It's intended that pre/post hooks are still called even if the grad is
    // undenfined here.
    for (const auto& hook : accumulateGrad_->pre_hooks()) {
      inputGrads = (*hook)(inputGrads); // call the pre-hooks
    }

    // It is possible that the grad is not defined since a separate
    // invocation of the autograd engine on the same node might actually
    // compute this gradient.
    if (inputGrads[0].defined()) {
      // There are 3 internal references to 'inputGrads[0]' at this moment:
      //   1. 'inputGrads[0]' in this function.
      //   2. 'graph_task->captured_vars_' on the callsite in the local engine.
      //   3. 'InputBuffer& inputs' on the callsite as the inputs of the
      //   function node.
      autogradContext_->accumulateGrad( // accumulate the gradient
          accumulateGrad_->variable, inputGrads[0], 3 /* num_expected_refs */);
    }
    const variable_list kEmptyOuput;
    for (const auto& hook : accumulateGrad_->post_hooks()) {
      (*hook)(kEmptyOuput, inputGrads); // call the post-hooks
    }
    return inputGrads[0];
  }

 private:
  std::shared_ptr<AccumulateGrad> accumulateGrad_; // the AccumulateGrad node holding the target variable; subsequent operations act on it
  ContextPtr autogradContext_;
};

4.2 Creation

How is a DistAccumulateGradCaptureHook created? It is created while computing dependencies and is recorded via capture.hooks_.push_back.

This is done in order to handle AccumulateGrad:

  • AccumulateGrad must be a leaf node; it does not need to be executed, but gradients need to be accumulated on it, whereas RecvRpcBackward does need to be executed.

  • The AccumulateGrad is stored inside the DistAccumulateGradCaptureHook.

void DistEngine::computeDependencies(
    const ContextPtr& autogradContext,
    const edge_list& rootEdges,
    const variable_list& grads,
    const std::shared_ptr<Node>& graphRoot,
    edge_list& outputEdges,
    bool retainGraph) {
  
  if (!outputEdges.empty()) {
    // Compute 'needed execution' starting from all 'send' functions and the
    // original graphRoot.
    edge_list edges;
    // Create some dummy edges (input_nr not important for init_to_execute).
    for (const auto& mapEntry : sendFunctions) {
      edges.emplace_back(mapEntry.second, 0);
    }

    // Add the original graphRoot as an edge.
    edges.emplace_back(graphRoot, 0);

    // Create a dummy GraphRoot and run init_to_execute with it.
    GraphRoot dummyRoot(edges, {});
    graphTask->init_to_execute(dummyRoot, outputEdges, /*accumulate_grad=*/false, /*min_topo_nr=*/0);
    for (auto& mapEntry : graphTask->exec_info_) {
      auto& execInfo = mapEntry.second;
      if (!execInfo.captures_) {
        continue;
      }
      auto fn = mapEntry.first;
      // There may be nodes other than 'AccumulateGrad', e.g. RecvRPCBackward,
      // to be captured.
      if (auto accumulateGradFn = dynamic_cast<AccumulateGrad*>(fn)) {
        for (auto& capture : *execInfo.captures_) {
          capture.hooks_.push_back( // created here
              std::make_unique<DistAccumulateGradCaptureHook>(
                  std::dynamic_pointer_cast<AccumulateGrad>( // the AccumulateGrad is stored
                      accumulateGradFn->shared_from_this()),
                  autogradContext));
        }
      }
    }

    // Mark all 'RecvRPCBackward' as needing execution.
    for (const auto& recvBackwardEdge : recvBackwardEdges) {
      graphTask->exec_info_[recvBackwardEdge.function.get()].needed_ = true;
    }
  }  
}

4.3 Usage

The code below is abridged.

First, execute_graph_task_until_ready_queue_empty calls into the original engine's engine_.evaluate_function.

void DistEngine::execute_graph_task_until_ready_queue_empty(
    NodeTask&& node_task,
    bool incrementOutstandingTasks) {

  while (!cpu_ready_queue->empty()) {
    std::shared_ptr<GraphTask> local_graph_task;
    {
      NodeTask task = cpu_ready_queue->pop();

      if (task.fn_ && !local_graph_task->has_error_.load()) {
        AutoGradMode grad_mode(local_graph_task->grad_mode_);
        GraphTaskGuard guard(local_graph_task);
        engine_.evaluate_function( // call the original engine
              local_graph_task, task.fn_.get(), task.inputs_, cpu_ready_queue);
      }
    }
    // Decrement the outstanding task.
    --local_graph_task->outstanding_tasks_;
  }

}

Second, within the original engine code, the hooks are invoked.

void Engine::evaluate_function(
    std::shared_ptr<GraphTask>& graph_task,
    Node* func,
    InputBuffer& inputs,
    const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
  // If exec_info_ is not empty, we have to instrument the execution
  auto& exec_info_ = graph_task->exec_info_;
  if (!exec_info_.empty()) {
    auto& fn_info = exec_info_.at(func);
    if (auto* capture_vec = fn_info.captures_.get()) {
      // Lock mutex for writing to graph_task->captured_vars_.
      std::lock_guard<std::mutex> lock(graph_task->mutex_);
      for (const auto& capture : *capture_vec) {
        auto& captured_grad = graph_task->captured_vars_[capture.output_idx_];
        captured_grad = inputs[capture.input_idx_];
        for (auto& hook : capture.hooks_) {
          captured_grad = (*hook)(captured_grad); // invoke the hook, i.e., DistAccumulateGradCaptureHook's operator(); captured_grad is the accumulated gradient
        }
      }
    }
  }
  
  // remainder omitted

Inside DistAccumulateGradCaptureHook's operator() method, the following call is made to accumulate the gradient:

  autogradContext_->accumulateGrad(
      accumulateGrad_->variable, inputGrads[0], 3 /* num_expected_refs */);

4.4 Accumulating Gradients

4.4.1 Accumulation in the Context

void DistAutogradContext::accumulateGrad(
    const torch::autograd::Variable& variable, // variable is the target variable
    const torch::Tensor& grad, // grad is the gradient to be accumulated onto variable
    size_t num_expected_refs) {

  std::lock_guard<std::mutex> guard(lock_);
  auto it = accumulatedGrads_.find(variable);
  at::Tensor old_grad;
  if (it != accumulatedGrads_.end()) {
    // Accumulate multiple grads on the same variable.
    old_grad = it->value();
  }

  // Gradients are computed using the forward streams. Local autograd
  // engine uses AccumulateGrad function to retrieve and apply forward
  // stream during the backward computation. In distributed autograd,
  // we directly call AccumulateGrad::accumulateGrad, and skip the
  // CUDA stream restoration from autograd function. Hence, we manually
  // call it here to get the streams correct.
  auto forward_stream =
      torch::autograd::impl::grad_accumulator(variable)->stream(
          grad.device().type());
  c10::OptionalStreamGuard stream_guard(forward_stream);

  // No higher order gradients supported in distributed autograd.
  AutoGradMode grad_mode(false);
  at::Tensor new_grad = AccumulateGrad::callHooks(variable, grad); // run the hooks

  AccumulateGrad::accumulateGrad( // call the operator function to accumulate the gradient
      variable,
      old_grad,
      new_grad,
      // Add +1 here since we can't std::move(grad) when call
      // AccumulateGrad::callHooks, since it is a const ref, and that incurs a
      // refcount bump for the new_grad.
      num_expected_refs + 1,
      [this, &variable](at::Tensor&& grad_update) {
        auto device = grad_update.device();
        accumulatedGrads_.insert(variable, std::move(grad_update));
        recordGradEvent(device);
      });
}

4.4.2 Accumulation in the Operator

The code is located in torch/csrc/autograd/functions/accumulate_grad.h. AccumulateGrad is defined as follows:

struct TORCH_API AccumulateGrad : public Node {
  explicit AccumulateGrad(Variable variable_);

  variable_list apply(variable_list&& grads) override;

  static at::Tensor callHooks(
      const Variable& variable,
      at::Tensor new_grad) {
    for (auto& hook : impl::hooks(variable)) {
      new_grad = (*hook)({new_grad})[0];
    }
    return new_grad;
  }

  // Given a variable with its current grad as variable_grad, accumulates
  // new_grad into variable_grad if in place accumulation is possible.
  // Otherwise, uses 'update_grad' to update the grad for the variable.

  // "Gradient Layout Contract"
  //
  // AccumulateGrad tries to stash strided (non-sparse) grads with memory layout
  // (strides) such that variables and grads interact efficiently in later
  // optimizer kernels, and grads interact efficiently with c10d::Reducer.cpp.
  //
  // Specifically, AccumulateGrad tries to ensure the following
  // (cf torch/csrc/autograd/utils/grad_layout_contract.h):
  //   (1) if variable.is_non_overlapping_and_dense(), the stashed grad's
  //       strides match variable.
  //   (2) else, stashed grad is rowmajor contiguous.
  // If variable's grad does not exist (!variable_grad.defined())
  // AccumulateGrad steals new_grad if it's stealable and obeys the contract
  // already, otherwise it deep copies new_grad into an obedient clone.
  //
  // If variable's grad already exists (variable_grad.defined()), new_grad must
  // be added to variable_grad.  If we aren't setting up for double backward
  // (!GradMode::is_enabled()), AccumulateGrad performs "variable_grad += new_grad"
  // in-place, which keeps variable_grad's layout. We assume (hope) variable_grad
  // was created obeying (1) or (2) at some point in the past.
  //
  // If we are setting up for double backward, AccumulateGrad updates the grad
  // out-of-place via "variable_grad + new_grad."  TensorIterator operator+ decides
  // result's layout.  Typically TensorIterator matches strides of the first arg,
  // so we once again assume (hope) variable_grad was originally created obeying
  // (1) or (2).
  //
  // AccumulateGrad does not enforce the contract with 100% certainty.  Examples:
  //  - If a user manually permutes a param or its grad, then runs a fwd+bwd,
  //    variable_grad += new_grad keeps variable_grad's layout without rechecking
  //    the contract.
  //  - If TensorIterator changes its corner cases about operator+'s result
  //    (for example, giving more or less priority to channels_last inputs, see
  //    https://github.com/pytorch/pytorch/pull/37968) the result may not obey.
  //
  // Fortunately, if a given grad doesn't satisfy (1) or (2), the penalty is
  // degraded performance in Reducer.cpp or optimizer kernels, not death by
  // assert or silently bad numerics.

  // variable: the variable whose grad we're accumulating.
  // variable_grad: the current grad for the variable.
  // new_grad: new grad we want to acummulate for the variable.
  // num_expected_refs: the number of refs we expect to hold internally
  //                    such that it is safe to avoid cloning the grad
  //                    if use_count() of the grad is less than or equal
  //                    to this value (in addition to post_hooks).
  // update_grad: Function that is used to update grad for the variable.
  //              The argument to the function is a Tensor which
  //              is used to set a new value for the grad.
  template <typename T>
  static void accumulateGrad( // this performs the actual gradient accumulation
      const Variable& variable,
      at::Tensor& variable_grad,
      const at::Tensor& new_grad,
      size_t num_expected_refs,
      const T& update_grad) {
    if (!variable_grad.defined()) {
      if (!GradMode::is_enabled() &&
          !new_grad.is_sparse() &&
          new_grad.use_count() <= num_expected_refs &&
          (new_grad.is_mkldnn() || utils::obeys_layout_contract(new_grad, variable))) {
        // we aren't setting up for double-backward
        // not sparse
        // no other user-visible tensor references new_grad
        // new_grad obeys the "Gradient Layout Contract", there has a special case,
        // For MKLDNN tensor, which is a opaque tensor, assuming it obeys layout_contract.
        // Under these conditions, we can steal new_grad without a deep copy.
        update_grad(new_grad.detach());
      } else if (
          !GradMode::is_enabled() && new_grad.is_sparse() &&
          new_grad._indices().is_contiguous() &&
          new_grad._values().is_contiguous() &&
          // Use count for indices and values should always be <=1 since the
          // SparseTensor should be the only one holding a reference to these.
          new_grad._indices().use_count() <= 1 &&
          new_grad._values().use_count() <= 1 &&
          new_grad.use_count() <= num_expected_refs) {
        // Can't detach sparse tensor (since metadata changes are not allowed
        // after detach), so just create a new one for the grad which is a
        // shallow copy. We need a shallow copy so that modifying the original
        // grad tensor doesn't modify the grad we accumulate.
        // We only skip clone if indices and values themselves are contiguous
        // for backward compatiblity reasons. Since without this optimization,
        // earlier we would clone the entire SparseTensor which cloned indices
        // and values.
        // For details see https://github.com/pytorch/pytorch/issues/34375.
        update_grad(at::_sparse_coo_tensor_unsafe(
            new_grad._indices(),
            new_grad._values(),
            new_grad.sizes(),
            new_grad.options()));
      } else {
        if (new_grad.is_sparse()) {
          update_grad(new_grad.clone());
        } else {
          if (new_grad.is_mkldnn()) {
            update_grad(new_grad.clone());
          } else {
            // Deep copies new_grad according to the "Gradient Layout Contract."
            update_grad(utils::clone_obey_contract(new_grad, variable));
          }
        }
      }
    } else if (!GradMode::is_enabled()) {
      // This case is not strictly necessary, but it makes the first-order only
      // case slightly more efficient.
      if (variable_grad.is_sparse() && !new_grad.is_sparse()) {
        // If `variable_grad` is sparse and `new_grad` is not sparse, their
        // sum is not sparse, and we must change the TensorImpl type of
        // `variable_grad` for it to store the result. However, changing the
        // TensorImpl type of a tensor requires changing the tensor itself, and
        // thus in this case we have to change the grad tensor.
        auto result = new_grad + variable_grad;
        CHECK_RESULT(result, variable);
        update_grad(std::move(result));
      } else if (!at::inplaceIsVmapCompatible(variable_grad, new_grad)) {
        // Ideally we'd perform an in-place operation to avoid changing
        // the grad tensor. However, if that's impossible because the grads
        // are vmap-incompatible (See NOTE: [vmap-incompatible in-place operations]),
        // then we just add them out-of-place.
        auto result = variable_grad + new_grad;
        CHECK_RESULT(result, variable);
        update_grad(std::move(result));
      } else {
        // In this case we can avoid changing the grad tensor. There are three
        // scenarios when we'll hit this case:
        //
        // 1. `variable_grad` is sparse, and `new_grad` is sparse.
        // 2. `variable_grad` is dense, and `new_grad` is sparse.
        // 3. `variable_grad` is dense, and `new_grad` is dense.
        // 4. `variable_grad` is mkldnn, and `new_grad` is mkldnn.
        //
        // In all of these four cases, `variable_grad += new_grad` is a
        // valid operation which adds `new_grad` to `variable_grad` in
        // place. `variable_grad` is thus still referring to the same tensor
        // after the operation.
        // Also DistributedDataParallel(DDP) package relies on grad being
        // mutated in place for saving peak memory usage. DDP will still
        // work correctly if it is mutated out of place here, but DDP will
        // maintain one extra copy of grad tensors in buffer and thus
        // increase peak memory usage.
        variable_grad += new_grad;
        CHECK_RESULT(variable_grad, variable);
        // ^ We could enforce the contract more aggressively here by writing:
        // if (variable_grad.is_sparse() || new_grad.is_sparse()) {
        //   variable_grad += new_grad;
        // } else if (obeys_layout_contract(variable_grad, variable)) {
        //   variable_grad += new_grad;
        // } else {
        //   result = at::empty_strided(variable.sizes(), variable.strides(),
        //                              variable.options().memory_format(c10::nullopt));
        //   update_grad(at::native::add_out(result, variable_grad, new_grad, 1.0);
        // }
        // However, that accumulation is sometimes in place and sometimes not,
        // which may break user code.
      }
    } else {
      at::Tensor result;
      if (variable_grad.is_sparse() && !new_grad.is_sparse()) {
        // CPU backend throws an error on sparse + dense, so prefer dense + sparse here.
        result = new_grad + variable_grad;
      } else {
        // Assumes operator+ result typically matches strides of first arg,
        // and hopes variable_grad was originally created obeying layout contract.
        result = variable_grad + new_grad;
      }
      CHECK_RESULT(result, variable);
      update_grad(std::move(result));
      // ^ We could enforce the contract more aggressively here by saying
      // if (obeys_layout_contract(new_grad, variable)) {
      //   update_grad(new_grad + variable_grad);
      // } else {
      //   update_grad(variable_grad + new_grad);
      // }
      // such that the stashed grad is likely to have the right strides if
      // either variable_grad or new_grad already has the right strides.
      // We could enforce the contract with certainty by saying
      // auto result = variable_grad + new_grad (or vice versa), checking result's
      // layout, and copying to an obedient clone if necessary before update_grad.
      // The copy would require another gmem pass.  We can't create empty result with
      // the right layout then add_out into it with a single kernel, because GradMode
      // is enabled in this branch, and add_out isn't differentiable.
      // Maybe more trouble than it's worth.
    }
  }

  Variable variable;
};

This can be illustrated by the figure below: the left side shows the data structures and the right side the algorithm flow; the numbers on the right indicate top-to-bottom execution order. During execution the data structures on the left are used, and the calls between the algorithm and the data structures are indicated by the horizontal arrows.

  1. The distributed engine calls execute_graph_task_until_ready_queue_empty to execute a specific GraphTask.
  2. Engine::evaluate_function accesses the ExecInfo inside the GraphTask.
  3. It then reaches the GradCaptureHook and invokes the hook; the hook's operator() calls autogradContext_->accumulateGrad.
  4. autogradContext_ runs accumulateGrad, operating on the accumulateGrad_ stored in the hook (DistAccumulateGradCaptureHook).
  5. AccumulateGrad::accumulateGrad performs the final gradient update.
                                     DATA STRUCTURE   +  ALGORITHM
                                                      |
+-----------------------------------------------+     |
| GraphTask                                     |     |  DistEngine::execute_graph_task_until_ready_queue_empty
|                                               |     |      +                |
|   unordered_map<Node*, ExecInfo> exec_info_   |     |      |                |
|                            +                  | <----------+                |
|                            |                  |     |                       |
+-----------------------------------------------+     |                       | 1
                             |                        |                       |
                             |                        |                       |
                             v                        |                       |
       +---------------------+------------------+     |                       v
       | ExecInfo                               | <-------------+  Engine::evaluate_function
       |                                        |     |                       +
       |       < vector<Capture> > captures_    |     |                       |
       |                   +                    |     |                       |
       |                   |                    |     |                       | 2
       +----------------------------------------+     |                       |
                           |                          |                       v
                           |                          |
                           v                          |      +--+ captured_grad = (*hook)(captured_grad)
       +-------------------+--------------------+     |      |                +
       | Capture                                |     |      |                |
       |                                        |     |      |                |
       |   vector< <GradCaptureHook> > hooks_ <--------------+                | 3
       |                   +                    |     |                       |
       +----------------------------------------+     |                       v
                           |                          |
                           |                          |   +--+ autogradContext_->accumulateGrad(
                           v                          |   |         accumulateGrad_-> variable, inputGrads[0], 3)
       +-------------------+--------------------+     |   |                   +
       | DistAccumulateGradCaptureHook          |     |   |                   |
       |                                        |     |   |                   |
       |      ContextPtr autogradContext_    <------------+                   | 4
       |                                        |     |   |                   |
       |      AccumulateGrad accumulateGrad_ <------------+                   v
       |                          +             |     |
       +----------------------------------------+     |   +-+ new_grad = AccumulateGrad::callHooks(variable, grad)
                                  |                   |   |                   +
                                  |                   |   |                   |
                                  v                   |   |                   | 5
              +-------------------+------+            |   |                   v
              | AccumulateGrad           |            |   |
              |                          |            |   |      AccumulateGrad::accumulateGrad(
              |      Variable variable <------------------+------+   variable, old_grad, new_grad,)
              |                          |            |
              +--------------------------+            +


0x05 Waiting for Completion

Finally, the distributed engine calls clearAndWaitForOutstandingRpcsAsync to wait for the processing to complete.

c10::intrusive_ptr<c10::ivalue::Future> DistAutogradContext::
    clearAndWaitForOutstandingRpcsAsync() {
  std::unique_lock<std::mutex> lock(lock_);
  auto outStandingRpcs = std::move(outStandingRpcs_);
  lock.unlock();

  struct State {
    explicit State(int32_t count)
        : future(
              c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get())),
          remaining(count) {}
    c10::intrusive_ptr<c10::ivalue::Future> future;
    std::atomic<int32_t> remaining;
    std::atomic<bool> alreadySentError{false};
  };
  auto state = std::make_shared<State>(outStandingRpcs.size());
  if (outStandingRpcs.empty()) {
    state->future->markCompleted(c10::IValue());
  } else {
    for (auto& rpc : outStandingRpcs) {
      rpc->addCallback([state](rpc::JitFuture& future) {
        if (future.hasError()) {
          // If there's an error, we want to setError() on the future,
          // unless another error has already been sent - use a CAS to
          // guard.
          //
          // Don't decrement num remaining here! (We don't need to, since
          // memory handling is separate). If we simply don't decrement on
          // errors, reaching 0 means that there were no errors - and hence,
          // we can just markCompleted() without any other checking there.
          bool expectedAlreadySent = false;
          if (state->alreadySentError.compare_exchange_strong(
                  expectedAlreadySent, true)) {
            state->future->setError(future.exception_ptr());
          }
          return;
        }

        if (--state->remaining == 0) {
          state->future->markCompleted(c10::IValue());
        }
      });
    }
  }
  return state->future;
}

At this point, the analysis of distributed autograd is complete. As mentioned earlier, distributed processing has four pillars; we have introduced RPC and RRef and analyzed the distributed engine. Starting from the next article, we will analyze the remaining piece, the distributed optimizer; that part of the series will probably span 4 to 6 articles.

0xFF References

Distributed Autograd Design

Remote Reference Protocol

PyTorch source code interpretation: an introduction to distributed training

https://pytorch.org/docs/stable/distributed.html

https://pytorch.apachecn.org/docs/1.7/59.html

https://pytorch.org/docs/stable/distributed.html#module-torch.distributed

https://pytorch.org/docs/master/notes/autograd.html

https://pytorch.org/docs/master/rpc/distributed_autograd.html
https://pytorch.org/docs/master/rpc/rpc.html

https://www.w3cschool.cn/pytorch/pytorch-cdva3buf.html

PyTorch Distributed Autograd Design (Chinese translation)

Getting started with Distributed RPC Framework

Implementing a Parameter Server using Distributed RPC Framework

Combining Distributed DataParallel with Distributed RPC Framework

Profiling RPC-based Workloads

Implementing batch RPC processing

Distributed Pipeline Parallel
