[原始碼解析] PyTorch 分散式 Autograd (3) ---- 上下文相關

羅西的思考發表於2021-12-01

[原始碼解析] PyTorch 分散式 Autograd (3) ---- 上下文相關

0x00 摘要

我們已經知道 dist.autograd 如何傳送和接受訊息,本文再來看看如何其他支撐部分,就是如何把傳送接受兩個動作協調起來,如何確定每個傳送/接受節點,如何確定每一個訊息互動Session。

通過本文大家可以瞭解:AutogradMetadata 用來在不同節點間傳遞 autograd 元資訊,DistAutogradContext 代表一個分散式autograd 相關資訊,DistAutogradContainer 負責在一個worker之上儲存 DistAutogradContext。

PyTorch分散式其他文章如下:

深度學習利器之自動微分(1)

深度學習利器之自動微分(2)

[原始碼解析]深度學習利器之自動微分(3) --- 示例解讀

[原始碼解析]PyTorch如何實現前向傳播(1) --- 基礎類(上)

[原始碼解析]PyTorch如何實現前向傳播(2) --- 基礎類(下)

[原始碼解析] PyTorch如何實現前向傳播(3) --- 具體實現

[原始碼解析] Pytorch 如何實現後向傳播 (1)---- 呼叫引擎

[原始碼解析] Pytorch 如何實現後向傳播 (2)---- 引擎靜態結構

[原始碼解析] Pytorch 如何實現後向傳播 (3)---- 引擎動態邏輯

[原始碼解析] PyTorch 如何實現後向傳播 (4)---- 具體演算法

[原始碼解析] PyTorch 分散式(1)------歷史和概述

[原始碼解析] PyTorch 分散式(2) ----- DataParallel(上)

[原始碼解析] PyTorch 分散式(3) ----- DataParallel(下)

[原始碼解析] PyTorch 分散式(4)------分散式應用基礎概念

[原始碼解析] PyTorch分散式(5) ------ DistributedDataParallel 總述&如何使用

[原始碼解析] PyTorch分散式(6) ---DistributedDataParallel -- 初始化&store

[原始碼解析] PyTorch 分散式(7) ----- DistributedDataParallel 之程式組

[原始碼解析] PyTorch 分散式(8) -------- DistributedDataParallel之論文篇

[原始碼解析] PyTorch 分散式(9) ----- DistributedDataParallel 之初始化

[原始碼解析] PyTorch 分散式(10)------DistributedDataParallel 之 Reducer靜態架構

[原始碼解析] PyTorch 分散式(11) ----- DistributedDataParallel 之 構建Reducer和Join操作

[原始碼解析] PyTorch 分散式(12) ----- DistributedDataParallel 之 前向傳播

[原始碼解析] PyTorch 分散式(13) ----- DistributedDataParallel 之 反向傳播

[原始碼解析] PyTorch 分散式 Autograd (1) ---- 設計

[原始碼解析] PyTorch 分散式 Autograd (2) ---- RPC基礎

為了更好的說明,本文程式碼會依據具體情況來進行相應精簡。

0x01 設計脈絡

1.1 前文回顧

在前文之中當傳送訊息時候,我們在 sendMessageWithAutograd 通過 getMessageWithAutograd 來獲得了 FORWARD_AUTOGRAD_REQ 型別的訊息。

c10::intrusive_ptr<JitFuture> sendMessageWithAutograd(
    RpcAgent& agent,
    const WorkerInfo& dst,
    torch::distributed::rpc::Message&& wrappedRpcMsg,
    bool forceGradRecording,
    const float rpcTimeoutSeconds,
    bool forceDisableProfiling) {
    
  auto msg = getMessageWithAutograd( // 這裡會與上下文互動,構建了 FORWARD_AUTOGRAD_REQ
      dst.id_,
      std::move(wrappedRpcMsg),
      MessageType::FORWARD_AUTOGRAD_REQ,
      forceGradRecording,
      agent.getDeviceMap(dst));

  c10::intrusive_ptr<JitFuture> fut;
  if (!forceDisableProfiling && torch::autograd::profiler::profilerEnabled()) {
    auto profilerConfig = torch::autograd::profiler::getProfilerConfig();
    auto msgWithProfiling = getMessageWithProfiling(
        std::move(msg),
        rpc::MessageType::RUN_WITH_PROFILING_REQ, //構建訊息
        std::move(profilerConfig));
    // 傳送訊息
    fut = agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds);
  } else {
    // 傳送訊息
    fut = agent.send(dst, std::move(msg), rpcTimeoutSeconds);
  }

  return fut;
}

而 getMessageWithAutograd 會與上下文互動,其程式碼位於 torch/csrc/distributed/autograd/utils.cpp。

Message getMessageWithAutograd(
    const rpc::worker_id_t dstId,
    torch::distributed::rpc::Message&& wrappedRpcMsg,
    MessageType msgType,
    bool forceGradRecording,
    const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
  
  // 獲取到 DistAutogradContainer
  auto& autogradContainer = DistAutogradContainer::getInstance();

  // If there is no valid context and no tensor requires grads, send original
  // rpc message. otherwise, attach grad info and grad functions and send
  // rpcWithAutograd message.
  auto tensorsRequireGrad =
      torch::autograd::compute_requires_grad(wrappedRpcMsg.tensors());
  if (!autogradContainer.hasValidContext() ||
      (!forceGradRecording && !tensorsRequireGrad)) {
    return std::move(wrappedRpcMsg);
  }

  // Retrieve the appropriate context to modify.
  auto autogradContext = autogradContainer.currentContext(); // 獲取到上下文,每個worker都有自己的上下文

  // Wrap the original rpc with autograd information.
  // newAutogradMessageId 會生成一個messageID
  AutogradMetadata autogradMetadata( // 構建了 AutogradMetadata
      autogradContext->contextId(), autogradContainer.newAutogradMessageId());
  auto rpcWithAutograd = std::make_unique<RpcWithAutograd>(
      RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_,
      msgType,
      autogradMetadata,
      std::move(wrappedRpcMsg),
      deviceMap);

  if (tensorsRequireGrad) {
    // Record autograd information for 'send'.
    addSendRpcBackward( // 這裡把本地上下文,autograd 的元資訊等一起打包
        autogradContext, autogradMetadata, rpcWithAutograd->tensors());
  }
  // Record the workerID
  autogradContext->addKnownWorkerId(dstId);

  return std::move(*rpcWithAutograd).toMessage(); // 最終構建了一個message
}

因此,就引出了AutogradMetadata,DistAutogradContainer 和 DistAutogradContext 等一系列基礎類,我們接下來就仔細分析一下。

1.2 總體思路

我們概括一下總體思路。

先看看問題:假如一套系統包括 a,b,c 三個節點,每個節點執行一個 worker,那麼當執行一個傳播操作,我們涉及到在這三個節點之間互相傳播。因此我們需要一個機制,來在這三個節點之中唯一標示這個傳播過程,在這個傳播過程之中,也要在每一個節點之上把每一個send/recv都標示出來,這樣才能讓節點可以支援多個操作並行

再看看解決方案:

  • 使用上下文來唯一標示一個傳播過程。DistAutogradContext 儲存在一個worker之上的每一個分散式autograd的相關資訊,其在分散式 autograd 之中封裝前向和後向傳播,累積梯度,這避免了多個worker在彼此的梯度上互相影響。每個自動微分過程被賦予一個唯一的 autograd_context_id,在容器中,這個微分過程的上下文(DistAutogradContext) 依據這個autograd_context_id 來唯一確認。
  • 使用autogradMessageId 來表示一對 send/recv autograd 函式。每send-recv對被分配一個全域性唯一的autograd_message_id 以唯一地標識該send-recv對。這對於在向後傳播期間查詢遠端節點上的相應函式很有用。
  • 最後,每個worker需要有一個地方來保持上下文和messageid,所以有了DistAutogradContainer這個類。每個worker擁有唯一一個單例DistAutogradContainer,其負責:
    • 對於每一個自動微分過程儲存其分散式上下文。
    • 一旦這個自動微分過程結束,就清除其資料。

0x02 AutogradMetadata

2.1 定義

AutogradMetadata 這個類是用來在不同節點之間傳遞 autograd 的元資訊,就是把上下文等資訊封裝了一下。即,傳送方通知接收方自己的上下文資訊,接收方會依據收到的這些上下文資訊作相應處理。

我們提前劇透,接收方會使用 autogradContextId 和 autogradMessageId 分別作為 上下文 和 訊息 的唯一標示。從註釋之中可以知道。

  • autogradContextId 是全域性唯一整數,用來表示一個唯一的分散式 autograd 傳播過程(包括前向傳播和後向傳播)。一個傳播過程會包括在反向傳播鏈條上的多對send/recv autograd 函式。
  • autogradMessageId 是全域性唯一整數,用來表示一對 send/recv autograd 函式。每send-recv對被分配一個全域性唯一的autograd_message_id 以唯一地標識該send-recv對。這對於在向後傳播期間查詢遠端節點上的相應函式很有用。
// This structure represents autograd metadata that we need to pass across
// different nodes when we call an RPC which needs autograd computation.
struct TORCH_API AutogradMetadata {
  AutogradMetadata(int64_t autogradContextId, int64_t autogradMessageId);

  // autogradContextId_ is a globally unique integer that identifies a
  // particular distributed autograd pass.
  int64_t autogradContextId;
  // autogradMessageId_ is a globally unique integer that identifies a pair
  // of send/recv autograd functions.
  int64_t autogradMessageId;
};

那麼問題來了,autogradContextId 和 autogradMessageId 分別怎麼做到全域性(包括多個節點)唯一呢?

2.2 autogradMessageId

我們先概括一下:autogradMessageId 是由 rank 間接生成的,然後在內部進行遞增,所以可以保證全域性唯一。

我們從後往前推導。

  • 先看 newAutogradMessageId 是如何生成訊息 id,原來是在 DistAutogradContainer 之中的成員變數 next_autograd_message_id_ 遞增得到。
int64_t DistAutogradContainer::newAutogradMessageId() {
  // Check for overflow into workerId_ section.
  TORCH_INTERNAL_ASSERT(next_autograd_message_id_ < max_id_);
  return next_autograd_message_id_++;
}
  • 然後看如何初始化 next_autograd_message_id_?從 DistAutogradContainer 的 init 函式中可以知道,原來是依據 worker_id 來生成 next_autograd_message_id_。work_id 是 init 函式所得到的引數。
DistAutogradContainer& DistAutogradContainer::init(int64_t worker_id) {
  std::lock_guard<std::mutex> guard(dist_container_init_lock_);

  auto& container = getInstanceInternal();
  container.worker_id_ = worker_id;
  container.next_context_id_ = static_cast<int64_t>(worker_id)
      << kAutoIncrementBits;
  container.next_autograd_message_id_ = static_cast<int64_t>(worker_id)
      << kAutoIncrementBits;
  container.max_id_ =
      (kAutoIncrementMask |
       (static_cast<int64_t>(worker_id) << kAutoIncrementBits));
  container.initialized_ = true;
  return container;
}
  • 我們再推導,看看如何設定 worker id,找到了如下,看來需要看看 python 世界的 _init 方法。
module.def(
    "_init",
    [](int64_t worker_id) { DistAutogradContainer::init(worker_id); },
    py::call_guard<py::gil_scoped_release>());

來到 python 世界,可以看到,使用了 rank 來作為引數,而 rank 是每個 worker 唯一的,這樣就保證了 worker ID 唯一,從而 訊息 id 唯一。

    def init_rpc(
        name,
        backend=None,
        rank=-1,
        world_size=None,
        rpc_backend_options=None,
    ):
			dist_autograd._init(rank) # rank是全域性唯一

我們把這些邏輯關係總結下來:

worker_id = rank;

container.worker_id_ = worker_id;

container.next_autograd_message_id_ = static_cast<int64_t>(worker_id) << kAutoIncrementBits

然後 next_autograd_message_id_ 內部遞增。

int64_t DistAutogradContainer::newAutogradMessageId() {
  // Check for overflow into workerId_ section.
  TORCH_INTERNAL_ASSERT(next_autograd_message_id_ < max_id_);
  return next_autograd_message_id_++;
}

所以,AutogradMessageId 是全域性唯一的。我們用圖例來看看:

+----------------------------------------------------------------------------------------+
| worker                                                                                 |
|                       +-------------------------------------+                          |
|                       | DistAutogradContainer               |                          |
|                       |                                     |                          |
|                       |                                     |                          |
|              init()   |                                     |                          |
|      rank +--------------+----> worker_id_                  |                          |
|                1      |  |                                  |   newAutogradMessageId() |
|                       |  +----> next_autograd_message_id_+------------------+          |
|                       |                                     |          2    |          |
|                       +-------------------------------------+               |          |
|                                                                             |          |
|                                                                             |          |
|                                                                             |          |
|                                                                             |          |
|                     +---------------------------------------------------------------+  |
|                     | getMessageWithAutograd                                |       |  |
|                     |                                                       |       |  |
|                     |                                                       v       |  |
|                     |                                                               |  |
|                     |   AutogradMetadata autogradMetadata(contextId(), MessageId()) |  |
|                     |                           4                           3       |  |
|                     |                                                               |  |
|                     +---------------------------------------------------------------+  |
|                                                                                        |
+----------------------------------------------------------------------------------------+

為了看看 autogradContextId 為什麼可以保證唯一,我們需要先分析 DistAutogradContainer 和 DistAutogradContext。

0x03 DistAutogradContainer

每個worker擁有唯一一個單例DistAutogradContainer,其負責:

  • 對於每一個自動微分過程儲存其分散式上下文。
  • 一旦這個自動微分過程結束,就清除其資料。

每個自動微分過程被賦予一個唯一的 autograd_context_id。在每個容器中,這個微分過程的上下文(DistAutogradContext) 依據這個autograd_context_id 來唯一確認。autograd_context_id 是一個 64 bit 的全域性唯一id,前 16 bis 是 worker_id,後 48 位是在每個worker內部自動遞增id。所以可見,一個Container 之中,是有多個Context的。

此容器還負責維護全域性唯一的訊息id,用來關聯傳送/接收自動微分函式對。格式類似於autograd_context_id,是一個64位整數,前16位是工作者id,後48位是worker內部自動遞增的。

因為訊息 id 和 上下文 id 的前16 位是 worker_id,也就是 rank id,再加上後48位內部自增,所以可以保證 訊息 id 和 上下文 id 全域性唯一

3.1 定義

DistAutogradContainer 定義如下,其中:

  • worker_id_ : 本 worker 的 ID,其實就是本 worker 的 rank。
  • next_context_id_ :自增的上下文ID,用來給每個自動微分過程賦予一個唯一的autograd_context_id。在一個傳播鏈條上,其實只有第一個節點的 DistAutogradContainer 用到了 next_context_id_ 來生成 Context,後續節點的 DistAutogradContainer 都是依據第一個 DistAutogradContainer 的 context id 資訊來在本地生成對應 context id 的 Context。
  • next_autograd_message_id_ :維護全域性唯一的訊息id,用來關聯 傳送/接收 自動微分函式對。此變數是在本節點傳送時候會使用到。
// Singleton class per worker which is responsible for storing the distributed
// autograd context for each autograd pass and also cleans up data for an
// autograd pass once its done.
//
// Each autograd pass is assigned a unique autograd_context_id and all data for
// that pass (DistAutogradContext) is stored in this container indexed by the
// autograd_context_id. The autograd_context_id itself is a 64 bit globally
// unique id. The first 16 bits is the worker_id and the next 48 bits is an
// auto-incrementing id for each worker.
//
// This container is also responsible for maintaining a globally unique message
// id, which is used to associate send/recv autograd function pairs. The format
// is similar to the autograd_context_id where we have a 64 bit integer with
// first 16 bits being the worker id and next 48 bits are auto-incrementing.
class TORCH_API DistAutogradContainer {

 private:
  // Number of shards for the map storing autograd contexts. We'd like this
  // to be a power of 2 and we don't expect a value much higher than the
  // number of cores would provide much benefit.
  static constexpr uint32_t kNumDefaultShards = 128;

  // Use cache line size for alignment.
  static constexpr int kCacheLineSize = 64;

  // Structure holding one shard of the sharded autograd context map with its
  // associated lock. Align to cache line size to avoid contention between
  // adjacent entries.
  struct alignas(kCacheLineSize) ContextsShard {
    // Lock for this shard.
    mutable std::mutex lock;

    // Map storing autograd contexts for this shard.
    std::unordered_map<int64_t, ContextPtr> contexts; // 這裡儲存了上下文指標
  };

  // Auto incrementing context id used to identify unique autograd passes.
  // Initialized with the first 16 bits being the worker_id.
  std::atomic<int64_t> next_context_id_; // 新增上下文id

  // Unique id to identify a worker in the distributed setting.
  int16_t worker_id_;

  // Whether or not the container has been initialized appropriately.
  bool initialized_;

  // Sharded autograd context map.
  std::vector<ContextsShard> autograd_contexts_; // 儲存上下文列表

  // Number of shards for the sharded autograd_contexts_ map.
  uint32_t num_shards_;

  // Autograd message id to identify unique send/recv autograd function pairs.
  std::atomic<int64_t> next_autograd_message_id_;

  // Maximum allowed value for autograd_context_id or autograd_message_id.
  int64_t max_id_;
};

3.2 構建

Init 方法構建了 DistAutogradContainer,主要就是利用 worker_id 對本地成員變數進行相關賦值。

DistAutogradContainer& DistAutogradContainer::init(int64_t worker_id) {
  std::lock_guard<std::mutex> guard(dist_container_init_lock_);

  TORCH_CHECK(
      worker_id >= 0 && worker_id <= kMaxWorkerId,
      "worker_id needs to be in the range [0, 65535]")

  auto& container = getInstanceInternal();
  TORCH_CHECK(
      !container.initialized_ || (worker_id == container.worker_id_),
      "Container is already initialized with worker_id: ",
      container.worker_id_,
      ", cannot initialize with different worker_id: ",
      worker_id);

  if (container.initialized_) {
    return container;
  }

  container.worker_id_ = worker_id;
  container.next_context_id_ = static_cast<int64_t>(worker_id)
      << kAutoIncrementBits;
  container.next_autograd_message_id_ = static_cast<int64_t>(worker_id)
      << kAutoIncrementBits;
  container.max_id_ =
      (kAutoIncrementMask |
       (static_cast<int64_t>(worker_id) << kAutoIncrementBits));
  container.initialized_ = true;
  return container;
}

0x04 DistAutogradContext

DistAutogradContext 儲存在一個worker之上的每一個分散式autograd的相關資訊,其在分散式 autograd 之中封裝前向和後向傳播,累積梯度,這避免了多個worker在彼此的梯度上互相影響。

由前面可知道,contextId_ 是全域性唯一。

4.1 定義

這裡僅僅給出 DistAutogradContext 成員變數,忽略其成員函式。其中成員變數最主要的有三個:

  • contextId_ 是上下文 id。
  • sendAutogradFunctions_ 是一個 map 型別變數,會收集所有傳送請求對應的反向傳播運算元 SendRpcBackward。
  • recvAutogradFunctions_ 是一個 map 型別變數,會收集所有接受送請求對應的反向傳播運算元 RecvRpcBackward。

關於 SendRpcBackward 和 RecvRpcBackward,我們後續會結合引擎進行分析。

// DistAutogradContext which stores information for a single distributed
// autograd pass on a worker.
class TORCH_API DistAutogradContext {
 private:
  friend class BackwardPassCleanupGuard;
  friend class DistEngine;
  friend class RecvRpcBackward;
  friend class DistAccumulateGradCaptureHook;

  const int64_t contextId_;

  // Set containing known worker IDs, used in cleaning up autograd context.
  // Whenever a sendRpcBackward is attached to the autograd graph for this
  // context, the destination is added here.
  std::unordered_set<rpc::worker_id_t> knownWorkerIds_;

  // Map from autograd_message_id to appropriate 'send' autograd function.
  std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>>
      sendAutogradFunctions_;

  // Map from autograd_message_id to appropriate 'recv' autograd function.
  std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>>
      recvAutogradFunctions_;

  // Gradients accumulated in this context so far. The key is the variable on
  // which the gradient needs to be accumulated and the value is the gradient
  // that needs to be accumulated on that variable..
  c10::Dict<torch::Tensor, torch::Tensor> accumulatedGrads_;

  // See comments for recordGradEvent(c10::Device device);
  std::unordered_map<c10::Device, c10::Event> gradReadyEvents_;
  const c10::impl::VirtualGuardImpl impl_;

  // The autograd GraphTask for the backward pass on this node for this context.
  std::shared_ptr<torch::autograd::GraphTask> graphTask_;

  // List of futures for RPCs initiated by this node to propagate gradients to
  // other nodes. The distributed autograd engine on this node can return
  // successfully only if all these futures are done and are successful.
  std::vector<c10::intrusive_ptr<rpc::JitFuture>> outStandingRpcs_;

  // Lock to protect concurrent modification of the context.
  mutable std::mutex lock_;
};

4.2 訊息

上下文主要包括幾種訊息型別,比如:

// Messages with autograd info
FORWARD_AUTOGRAD_REQ = 0x0f | MessageTypeFlags::REQUEST_TYPE,
FORWARD_AUTOGRAD_RESP = 0x10 | MessageTypeFlags::RESPONSE_TYPE,

// Messages to propagate gradients on the backward pass.
BACKWARD_AUTOGRAD_REQ = 0x11 | MessageTypeFlags::REQUEST_TYPE,
BACKWARD_AUTOGRAD_RESP = 0x12 | MessageTypeFlags::RESPONSE_TYPE,

4.3 構建

我們首先看看如何構建上下文。

4.3.1 getOrCreateContext

getOrCreateContext 函式是用來得到上下文,如果已經有,就直接獲取,如果沒有,就新構建一個。這是一個被動呼叫,recv 端會用到這個

ContextPtr DistAutogradContainer::getOrCreateContext(int64_t context_id) {
  auto& shard = getShard(context_id);
  std::lock_guard<std::mutex> guard(shard.lock);
  auto it = shard.contexts.find(context_id); // 根據這個context id來查詢
  if (it != shard.contexts.end()) {
    return it->second; // 找到就返回
  }

  auto& context = // 如果沒有,就構建一個 context
      shard.contexts
          .emplace(
              std::piecewise_construct,
              std::forward_as_tuple(context_id),
              std::forward_as_tuple(
                  std::make_shared<DistAutogradContext>(context_id)))
          .first->second;
  return context;
}

4.3.2 newContext

這裡是主動呼叫,send 端會呼叫這個方法

4.3.2.1 Python

當分散式呼叫時候,python世界會生成一個context。

            with dist_autograd.context() as context_id:
                output = model(indices, offsets)
                loss = criterion(output, target)

                # Run distributed backward pass
                dist_autograd.backward(context_id, [loss])

                # Run distributed optimizer. Gradients propagated all the way to the parameter servers
                opt.step(context_id)

當生成時,__enter__ 會呼叫 _new_context() 在C++生成一個context。

class context(object):
    '''
    Context object to wrap forward and backward passes when using
    distributed autograd. The ``context_id`` generated in the ``with``
    statement  is required to uniquely identify a distributed backward pass
    on all workers. Each worker stores metadata associated with this
    ``context_id``, which is required to correctly execute a distributed
    autograd pass.

    Example::
        >>> import torch.distributed.autograd as dist_autograd
        >>> with dist_autograd.context() as context_id:
        >>>   t1 = torch.rand((3, 3), requires_grad=True)
        >>>   t2 = torch.rand((3, 3), requires_grad=True)
        >>>   loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
        >>>   dist_autograd.backward(context_id, [loss])
    '''
    def __enter__(self):
        self.autograd_context = _new_context() # 這裡生成一個上下文
        return self.autograd_context._context_id()

    def __exit__(self, type, value, traceback):
        _release_context(self.autograd_context._context_id())

具體通過如下對映,我們可以看到 C++ 世界之中對應的方法,呼叫到了 DistAutogradContainer::getInstance().newContext()。

  module.def(
      "_new_context",
      []() -> const ContextPtr {
        return DistAutogradContainer::getInstance().newContext();
      },
      py::return_value_policy::reference);
4.3.2.2 C++

我們來到了C++世界。每一個執行緒都有一個autograd_context_id。

constexpr int64_t kInvalidContextId = -1;

// Each thread has a single autograd_context_id valid at any point in time.
static thread_local int64_t current_context_id_ = kInvalidContextId;

newContext 就是生成了一個DistAutogradContext,其中通過 Container 的成員變數 next_context_id_ 的遞增來指定下一個上下文的id。

const ContextPtr DistAutogradContainer::newContext() {

  auto context_id = next_context_id_++; // 遞增
  current_context_id_ = context_id;  // 在這裡設定了本地執行緒的 current_context_id_

  // Check for overflow into workerId_ section.
  TORCH_INTERNAL_ASSERT(context_id < max_id_);

  auto& shard = getShard(context_id);
  std::lock_guard<std::mutex> guard(shard.lock);
  auto& context =
      shard.contexts
          .emplace(
              std::piecewise_construct,
              std::forward_as_tuple(context_id),
              std::forward_as_tuple(
                  std::make_shared<DistAutogradContext>(context_id)))
          .first->second;

  return context;
}

4.4 如何共享上下文

具體使用中,在with語句中生成的context_id可以用作在所有 worker 之上唯一標識一個分散式後向傳播(包括前向傳播和後向傳播)。每個worker儲存與此 context_id關聯的後設資料,這是正確執行分散式自動載入過程所必需的。

因為需要在多個 worker 之中都儲存這個 context_id關聯的後設資料,所以就需要一個 封裝/傳送/接受的機制來在 worker 之間傳遞這個後設資料,封裝機制就是我們前面提到的 AutogradMetadata。我們接下來看看如何傳送/接受上下文元資訊

4.4.1 傳送方

當傳送訊息時候,getMessageWithAutograd 會使用 autogradContainer.currentContext() 獲取當前上下文,進行傳送。

Message getMessageWithAutograd(
    const rpc::worker_id_t dstId,
    torch::distributed::rpc::Message&& wrappedRpcMsg,
    MessageType msgType,
    bool forceGradRecording,
    const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
  auto& autogradContainer = DistAutogradContainer::getInstance();

  // If there is no valid context and no tensor requires grads, send original
  // rpc message. otherwise, attach grad info and grad functions and send
  // rpcWithAutograd message.
  auto tensorsRequireGrad =
      torch::autograd::compute_requires_grad(wrappedRpcMsg.tensors());
  if (!autogradContainer.hasValidContext() ||
      (!forceGradRecording && !tensorsRequireGrad)) {
    return std::move(wrappedRpcMsg);
  }

  // Retrieve the appropriate context to modify.
  auto autogradContext = autogradContainer.currentContext(); // 獲取當前上下文

  // Wrap the original rpc with autograd information.
  AutogradMetadata autogradMetadata( // 使用上下文id和訊息id來構建後設資料
      autogradContext->contextId(), autogradContainer.newAutogradMessageId());
  auto rpcWithAutograd = std::make_unique<RpcWithAutograd>(
      RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_,
      msgType,
      autogradMetadata,
      std::move(wrappedRpcMsg),
      deviceMap);

  if (tensorsRequireGrad) {
    // Record autograd information for 'send'.
    addSendRpcBackward(
        autogradContext, autogradMetadata, rpcWithAutograd->tensors());
  }
  // Record the workerID
  autogradContext->addKnownWorkerId(dstId);

  return std::move(*rpcWithAutograd).toMessage();
}

我們之前的圖現在可以擴充,加入了上下文ID。

+----------------------------------------------------------------------------------------+
| worker                                                                                 |
|                  +------------------------------------------+                          |
|                  |DistAutogradContainer                     |                          |
|          init()  |                                          |                          |
|  rank +-------------+----> worker_id_                       |                          |
|                  |  |                                       |                          |
|                  |  +----> next_context_id_+-------------+  |                          |
|                  |  |                                    |  |                          |
|                  |  +----> next_autograd_message_id_ +----------------------+          |
|                  |                                       |  |               |          |
|                  |                                       |  |               |          |
|                  +------------------------------------------+               |          |
|                                                          |                  |          |
|                                                          |                  |          |
|                                                          |                  |          |
|                  +------------------------------------------------------------------+  |
|                  |getMessageWithAutograd                 |                  |       |  |
|                  |                                       |                  |       |  |
|                  |                                       v                  v       |  |
|                  |                                                                  |  |
|                  |    AutogradMetadata autogradMetadata(contextId(), MessageId())   |  |
|                  |                                                                  |  |
|                  |                                                                  |  |
|                  +------------------------------------------------------------------+  |
|                                                                                        |
+----------------------------------------------------------------------------------------+

addSendRpcBackward 就被傳入當前上下文之中,後續反向傳播時候,會取出這個 addSendRpcBackward。

void addSendRpcBackward(
    const ContextPtr& autogradContext,
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors) {
  // Attach autograd information only for tensors requiring grad.
  std::vector<torch::Tensor> tensors_with_grad;
  std::copy_if(
      tensors.begin(),
      tensors.end(),
      std::back_inserter(tensors_with_grad),
      [](const torch::Tensor& t) { return t.requires_grad(); });

  // Attach the appropriate autograd edges.
  auto grad_fn = std::make_shared<SendRpcBackward>();
  grad_fn->set_next_edges(
      torch::autograd::collect_next_edges(tensors_with_grad));

  // Add the appropriate input metadata for the grad_fn.
  for (const auto& tensor : tensors_with_grad) {
    grad_fn->add_input_metadata(tensor);
  }

  // Record the send autograd function in our current context.
  autogradContext->addSendFunction(grad_fn, autogradMetadata.autogradMessageId);
}

4.4.2 接受方

在 addRecvRpcBackward 之中,會依據傳遞過來的 autogradMetadata.autogradContextId 來構建一個上下文。

ContextPtr addRecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors,
    rpc::worker_id_t fromWorkerId,
    const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
  // Initialize autograd context if necessary.
  auto& autogradContainer = DistAutogradContainer::getInstance();
  // 生成或者得到一個上下文,把傳送方的 autogradContextId 傳入,即利用 autogradContextId 作為key後續可以查詢到這個上下文
  auto autogradContext = 
      autogradContainer.getOrCreateContext(autogradMetadata.autogradContextId);

  if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) {
    // Attach the tensors as inputs to the autograd function.
    auto grad_fn = std::make_shared<RecvRpcBackward>(
        autogradMetadata, autogradContext, fromWorkerId, deviceMap);
    for (auto& tensor : tensors) {
      if (tensor.requires_grad()) {
        torch::autograd::set_history(tensor, grad_fn);
      }
    }

    // Now update the autograd context with the necessary information.
    autogradContext->addRecvFunction(
        grad_fn, autogradMetadata.autogradMessageId);
  }

  return autogradContext;
}

這樣,傳送方和接收方就共享了一個上下文,而且這個上下文的id是全域性唯一的。

具體邏輯如下,上方是傳送端,下方是接收端。

  • 傳送端
    • 利用本地 context_id 構建了 AutogradMetadata,AutogradMetadata含有 ctx_id, msg_id。
    • 利用 AutogradMetadata 構建了 Message。
    • 利用 agent.send 傳送了 Message。
  • 接收端:
    • 收到了 Message。
    • 從 Message 之中解析出 AutogradMetadata。
    • 從 AutogradMetadata 提取出 context_id。
    • 利用 context_id 構建了本地的 DistAutogradContext。
  • 傳送方和接收方就共享了一個上下文(這個上下文的id是全域性唯一的)。
+----------------------------------------------------------------------------------+
| sendMessageWithAutograd                                                          |
|                                                                                  |
|  +----------------------------------------------------------------------------+  |
|  | addSendRpcBackward                                                         |  |
|  |                                                                            |  |
|  |                                                                            |  |
|  |               autogradMetadata = AutogradMetadata(context_id, message_id)  |  |
|  |                          +                                                 |  |
|  |                          |                                                 |  |
|  +----------------------------------------------------------------------------+  |
|                             |                                                    |
|                             v                                                    |
|        agent.send(message(autogradMetadata)                                      |
|                             +                                                    |
|                             |                                                    |
+----------------------------------------------------------------------------------+
                              |
                              |
                              |
                              |                                             Sender
+-----------------------------------------------------------------------------------+
                              |                                             Receiver
                              | message
                              v
                              |
+----------------------------------------------------------------------------------+
| processForwardAutogradReq   |                                                    |
|                             |                                                    |
|                             | message.autogradMetadata                           |
|                             v                                                    |
|  +----------------------------------------------------------------------------+  |
|  | addSendRpcBackward       |                                                 |  |
|  |                          |                                                 |  |
|  |                          +--------------------+                            |  |
|  |                                               |                            |  |
|  |                                               v                            |  |
|  |   autogradContext = getOrCreateContext(autogradMetadata.autogradContextId) |  |
|  |                                                                            |  |
|  |                                                                            |  |
|  +----------------------------------------------------------------------------+  |
|                                                                                  |
+----------------------------------------------------------------------------------+

0x05 前向傳播互動過程

前面的分享過程還是簡略,我們接下來把完整的傳送/接受過程詳細分析一下。

5.1 傳送

這裡對應設計中的如下文字:

在前向傳播期間,我們在上下文中儲存每個 autograd 傳播的sendrecv函式。這確保我們在 autograd 圖中儲存對適當節點的引用以使其保持活動狀態。除此之外,這也使得在後向傳播期間很容易查詢到對應的sendrecv函式。

5.1.1 傳送邏輯

程式碼邏輯如下:

  • 生成一個 grad_fn,其型別是 SendRpcBackward。
  • 呼叫 collect_next_edges 和 set_next_edges 為 SendRpcBackward 新增後續邊,這些函式我們在前面系列中有分析。
  • 呼叫 add_input_metadata 新增輸入後設資料。
  • 呼叫 addSendFunction 往上下文新增 grad_fn。
void addSendRpcBackward(
    const ContextPtr& autogradContext,
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors) {
  // Attach autograd information only for tensors requiring grad.
  std::vector<torch::Tensor> tensors_with_grad;
  std::copy_if(
      tensors.begin(),
      tensors.end(),
      std::back_inserter(tensors_with_grad),
      [](const torch::Tensor& t) { return t.requires_grad(); });

  // Attach the appropriate autograd edges.
  auto grad_fn = std::make_shared<SendRpcBackward>();
  grad_fn->set_next_edges( // 這裡會設定其輸出邊
      torch::autograd::collect_next_edges(tensors_with_grad));

  // Add the appropriate input metadata for the grad_fn.
  for (const auto& tensor : tensors_with_grad) {
    grad_fn->add_input_metadata(tensor);
  }

  // Record the send autograd function in our current context.
  autogradContext->addSendFunction(grad_fn, autogradMetadata.autogradMessageId);
}

5.1.2 設定上下文

我們再回憶一下DistAutogradContext 定義,這裡僅僅給出其部分成員變數。

  • contextId_ 是上下文 id。
  • sendAutogradFunctions_ 是一個 map 型別變數,會收集所有傳送請求對應的反向傳播運算元 SendRpcBackward。
  • recvAutogradFunctions_ 是一個 map 型別變數,會收集所有接受送請求對應的反向傳播運算元 RecvRpcBackward。
// DistAutogradContext which stores information for a single distributed
// autograd pass on a worker.
class TORCH_API DistAutogradContext {

  const int64_t contextId_;

  // Map from autograd_message_id to appropriate 'send' autograd function.
  std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>>
      sendAutogradFunctions_;

  // Map from autograd_message_id to appropriate 'recv' autograd function.
  std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>>
      recvAutogradFunctions_;
};

addSendFunction 就是往 sendAutogradFunctions_ 之中新增SendRpcBackward,後續可以按照 message id 來得到這個 SendRpcBackward。

void DistAutogradContext::addSendFunction(
    const std::shared_ptr<SendRpcBackward>& func,
    int64_t autograd_message_id) {

  std::lock_guard<std::mutex> guard(lock_);
  TORCH_INTERNAL_ASSERT(
      sendAutogradFunctions_.find(autograd_message_id) ==
      sendAutogradFunctions_.end());
  sendAutogradFunctions_.emplace(autograd_message_id, func);
}

前面是從上下文構建的角度看,本次從上下文內容來看。

此時傳送端邏輯如下:

+--------------------------------------------------------------+    +-------------------+
| worker                                                       |    |SendRpcBackward    |
| +---------------------------------------------------------+  |    |                   |
| | DistAutogradContext                                     |  |    |   input_metadata_ |
| |                                                 +-------------> |                   |
| |  contextId_ = context_id_1                      |       |  |    |   next_edges_     |
| |                                                 +       |  |    |                   |
| |  sendAutogradFunctions_ = [msg_id_1, SendRpcBackward_1] |  |    +-------------------+
| |                                                         |  |
| |                                                         |  |
| |  recvAutogradFunctions_                                 |  |
| |                                                         |  |
| +---------------------------------------------------------+  |
|                                                              |
+--------------------------------------------------------------+

                                                                                  sender
+---------------------------------------------------------------------------------------+

5.2 接受

我們略過 agent 的傳送內部處理,轉而看看 FORWARD_AUTOGRAD_REQ 的業務流程。

5.2.1 接收訊息 ---> 接收方

生成 TensorPipeAgent 時候,把 RequestCallbackImpl 配置為回撥函式。這是 agent 的統一響應函式。

前面關於代理接收邏輯時候,我們也提到了,會進入以下函式,其中可以看到有對 processForwardAutogradReq 的處理邏輯。

void RequestCallbackNoPython::processRpc(
    RpcCommandBase& rpc,
    const MessageType& messageType,
    const int64_t messageId,
    const c10::intrusive_ptr<JitFuture>& responseFuture,
    std::shared_ptr<LazyStreamContext> ctx) const {

    case MessageType::FORWARD_AUTOGRAD_REQ: {
      // 會來到這裡
      processForwardAutogradReq(rpc, messageId, responseFuture, std::move(ctx));
      return;
    }
    case MessageType::BACKWARD_AUTOGRAD_REQ: {
      processBackwardAutogradReq(rpc, messageId, responseFuture);
      return;
    };  
  
}  

5.2.2 處理訊息

processForwardAutogradReq 負責具體處理訊息,其處理邏輯如下:

  • 雖然是收到了前向傳播請求,但因為此處是接收端,後續需要進行反向傳播,所以對deviceMap進行轉置。
  • 使用 addRecvRpcBackward 將 rpc 訊息 加入上下文。
  • 可能會有nested命令的可能,所以需要再呼叫一次processRpc。
  • 設定最原始的訊息為處理完畢,進行相關操作。
void RequestCallbackNoPython::processForwardAutogradReq(
    RpcCommandBase& rpc,
    const int64_t messageId,
    const c10::intrusive_ptr<JitFuture>& responseFuture,
    std::shared_ptr<LazyStreamContext> ctx) const {
  
  auto& rpcWithAutograd = static_cast<RpcWithAutograd&>(rpc);

  // Need to reverse the device map for the backward pass of distributed
  // autograd.
  std::unordered_map<c10::Device, c10::Device> reverseDeviceMap;
  // 對deviceMap進行轉置
  for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
    reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
  }

  // Attach 'recv' autograd function.
  auto autogradContext = addRecvRpcBackward( // 呼叫了 addRecvRpcBackward 加入上下文
      rpcWithAutograd.autogradMetadata(),
      rpcWithAutograd.tensors(),
      rpcWithAutograd.fromWorkerId(),
      reverseDeviceMap);
  // For this recv thread on server side, before processRpc(),
  // set current_context_id_ to be context_id passed from client.
  // In this way, if there is nested rpc call in python rpc call, original
  // context_id from client can be passed in the chain calls.
  DistAutogradContextGuard ctxGuard(autogradContext->contextId());

  // Process the original RPC.
  auto wrappedMessageType = rpcWithAutograd.wrappedMessageType();
  // Make an overall future for the wrapped response.
  auto wrappedRpcResponseFuture =
      c10::make_intrusive<JitFuture>(at::AnyClassType::get());
  // Kick off processing for the nested RPC command.
  // wrappedRpcResponseFuture will be a Future<T> to the result.
  processRpc( // 可能會有nested命令的可能,所以需要再處理一次
      rpcWithAutograd.wrappedRpc(),
      wrappedMessageType,
      messageId,
      wrappedRpcResponseFuture,
      std::move(ctx));

  auto fromWorkerId = rpcWithAutograd.fromWorkerId();
  // The original future needs to be marked as completed when the wrapped
  // one completes, with the autograd context information wrapped.
  wrappedRpcResponseFuture->addCallback(
      [responseFuture,
       messageId,
       fromWorkerId,
       ctxId =
           autogradContext->contextId()](JitFuture& wrappedRpcResponseFuture) {
        // As this callback can be invoked by a different thread, we have to
        // make sure that the thread_local states in the previous thread is
        // correctly propagated.
        // NB: The execution of TorchScript functions can also run on a
        // different thread, which is addressed by
        // https://github.com/pytorch/pytorch/pull/36395
        // NB: when adding async UDF support, we should also propagate
        // thread_local states there.
        // TODO: Land on a general solution for RPC ThreadLocalState. See
        // https://github.com/pytorch/pytorch/issues/38510
        DistAutogradContextGuard cbCtxGuard(ctxId);

        if (wrappedRpcResponseFuture.hasError()) {
          // Propagate error to responseFuture if we had one.
          responseFuture->setError(wrappedRpcResponseFuture.exception_ptr());
        } else {
          auto msg = getMessageWithAutograd(
              fromWorkerId,
              std::move(
                  *wrappedRpcResponseFuture.value().toCustomClass<Message>()),
              MessageType::FORWARD_AUTOGRAD_RESP);
          msg.setId(messageId);
          responseFuture->markCompleted(
              IValue(c10::make_intrusive<Message>(std::move(msg))));
        }
      });
}

5.2.3 上下文互動

torch/csrc/distributed/autograd/utils.cpp 之中,addRecvRpcBackward 函式會對上下文進行處理。

這裡對應設計中的:

在前向傳播期間,我們在上下文中儲存每個 autograd 傳播的sendrecv函式。這確保我們在 autograd 圖中儲存對適當節點的引用以使其保持活動狀態。除此之外,這也使得在向後傳播期間很容易查詢到對應的sendrecv函式。

其具體邏輯是:

  • 根據 rpc資訊中的 autogradContextId 拿到本地的上下文。
  • 生成一個 RecvRpcBackward。
  • 用 rpc 資訊中的張量來對 RecvRpcBackward 進行配置,包括torch::autograd::set_history(tensor, grad_fn)。
  • 呼叫 addRecvFunction 把 RecvRpcBackward 加入到上下文。
ContextPtr addRecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors,
    rpc::worker_id_t fromWorkerId,
    const std::unordered_map<c10::Device, c10::Device>& deviceMap) {
  // Initialize autograd context if necessary.
  auto& autogradContainer = DistAutogradContainer::getInstance();
  auto autogradContext =
      autogradContainer.getOrCreateContext(autogradMetadata.autogradContextId);

  if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) {
    // Attach the tensors as inputs to the autograd function.
    auto grad_fn = std::make_shared<RecvRpcBackward>(
        autogradMetadata, autogradContext, fromWorkerId, deviceMap);
    for (auto& tensor : tensors) {
      if (tensor.requires_grad()) {
        torch::autograd::set_history(tensor, grad_fn);
      }
    }

    // Now update the autograd context with the necessary information.
    autogradContext->addRecvFunction(
        grad_fn, autogradMetadata.autogradMessageId);
  }

  return autogradContext;
}

addRecvFunction 的新增操作如下,就是看看 recvAutogradFunctions_之中是否已經存在這個 message id 對應的運算元,如果沒有就新增 。

void DistAutogradContext::addRecvFunction(
    std::shared_ptr<RecvRpcBackward>& func,
    int64_t autograd_message_id) {
  TORCH_INTERNAL_ASSERT(func != nullptr);
  std::lock_guard<std::mutex> guard(lock_);
  TORCH_INTERNAL_ASSERT(
      recvAutogradFunctions_.find(autograd_message_id) ==
      recvAutogradFunctions_.end());
  recvAutogradFunctions_.emplace(autograd_message_id, func);
}

至此,邏輯擴充如下,在傳送端和接收端都有一個 DistAutogradContext,其 id 都是 context_id_1。

在 每個 DistAutogradContext 之內,均以 msg_id_1 作為key,一個是 SendRpcBackward,一個建立了 RecvRpcBackward。

這就對應了設計之中提到的:

每個自動微分過程被賦予一個唯一的 autograd_context_id,在容器中,這個微分過程的上下文(DistAutogradContext) 依據這個autograd_context_id 來唯一確認。autograd_context_id 是一個 64 bit 的全域性唯一id,前 16 bis 是 worker_id,後 48 位是在每個worker內部自動遞增id。所以可見,一個Container 之中,是有多個Context的。

此容器還負責維護全域性唯一的訊息id,用來關聯傳送/接收自動微分函式對。格式類似於autograd_context_id,是一個64位整數,前16位是工作者id,後48位是worker內部自動遞增的。

+----------------------------------------------------------------+
| worker                                                         |    +-------------------+
|                                                                |    |SendRpcBackward    |
|   +---------------------------------------------------------+  |    |                   |
|   | DistAutogradContext                                     |  |    |   input_metadata_ |
|   |                                                 +-------------> |                   |
|   |  contextId_ = context_id_1                      |       |  |    |   next_edges_     |
|   |                                                 +       |  |    |                   |
|   |  sendAutogradFunctions_ = [msg_id_1, SendRpcBackward_1] |  |    +-------------------+
|   |                                                         |  |
|   |  recvAutogradFunctions_                                 |  |
|   |                                                         |  |
|   +---------------------------------------------------------+  |
|                                                                |
|                             +                                  |
|                             |                                  |
+----------------------------------------------------------------+
                              |
                              |
                              |                                                     Sender
+-----------------------------------------------------------------------------------------+
                              |                                                     Receiver
                              |
                              v
+-----------------------------+----------------------------------+
| worker                                                         |
|                                                                |    +-------------------+
|   +---------------------------------------------------------+  |    |RecvRpcBackward    |
|   | DistAutogradContext                                     |  |    |                   |
|   |                                                         |  |    |                   |
|   |   contextId_ = context_id_1                 +-----------------> |   input_metadata_ |
|   |                                             |           |  |    |                   |
|   |   sendAutogradFunctions_                    |           |  |    |   next_edges_     |
|   |                                             +           |  |    |                   |
|   |   recvAutogradFunctions_ = [msg_id_1, RecvRpcBackward_1]|  |    +-------------------+
|   |                                                         |  |
|   +---------------------------------------------------------+  |
|                                                                |
+----------------------------------------------------------------+

我們加入 Container,再擴充一下目前邏輯如下:

  • 每個worker 包括一個DistAutogradContainer。
  • 每個 DistAutogradContainer 包括若干個 DistAutogradContext,依據 context id 提取 DistAutogradContext。
  • 每個 DistAutogradContext 包括 sendAutogradFunctions_ 和 recvAutogradFunctions_,利用 msg id 來獲取 SendRpcBackward 或者 RecvRpcBackward。

這樣這個反向傳播鏈條就構建了出來。

+------------------------------------------------------------------------------------------------------------------------------------+
| worker                                                                                                                             |
|                                                                                                                                    |
| +---------------------------------------+     +---------------------------------------------------------+    +-------------------+ |
| | DistAutogradContainer                 |     | DistAutogradContext                                     |    |SendRpcBackward    | |
| |                                       |     |                                                 +----------> |                   | |
| |   worker_id_                          |     |  contextId_ = ctx_id_1                          |       |    |   input_metadata_ | |
| |                                       |     |                                                 +       |    |                   | |
| |   next_autograd_message_id_     +---------> |  sendAutogradFunctions_ = [msg_id_1, SendRpcBackward_1] |    |   next_edges_     | |
| |                                 |     |     |                                                         |    |                   | |
| |   next_context_id_              |     |     |  recvAutogradFunctions_                                 |    +-------------------+ |
| |                                 +     |     |                                                         |                          |
| |   autograd_contexts_[ctx_id_1 : ctx]  |     +---------------------------------------------------------+                          |
| |                                       |                                                                                          |
| +----------------------------+----------+                                                                                          |
|                              |                                                                                                     |
+------------------------------------------------------------------------------------------------------------------------------------+
                               |
                               |
+-------------------------------------------------------------------------------------------------------------------------------------+
                               |
                               v
+------------------------------+-----------------------------------------------------------------------------------------------------+
| worker                                                                                                                             |
|                                                                                                                                    |
| +---------------------------------------+     +---------------------------------------------------------+    +-------------------+ |
| | DistAutogradContainer                 |     | DistAutogradContext                                     |    |RecvRpcBackward    | |
| |                                       |     |                                                 +----------> |                   | |
| |   worker_id_                          |     |  contextId_ = ctx_id_1                          |       |    |   input_metadata_ | |
| |                                       |     |                                                 |       |    |                   | |
| |   next_autograd_message_id_     +---------> |  sendAutogradFunctions_                         |       |    |   next_edges_     | |
| |                                 |     |     |                                                 +       |    |                   | |
| |   next_context_id_              |     |     |  recvAutogradFunctions_ = [msg_id_1, RecvRpcBackward_1] |    +-------------------+ |
| |                                 +     |     |                                                         |                          |
| |   autograd_contexts_[ctx_id_1 : ctx]  |     +---------------------------------------------------------+                          |
| |                                       |                                                                                          |
| +---------------------------------------+                                                                                          |
|                                                                                                                                    |
+------------------------------------------------------------------------------------------------------------------------------------+

手機如下:

至此,我們初步分析了上下文相關的類,下文我們把目前已經分析的內容結合起來,系統看看業務邏輯。

0xFF 參考

相關文章