[原始碼解析] PyTorch 分散式(2) --- 資料載入之DataLoader

羅西的思考發表於2021-08-18

[原始碼解析] PyTorch 分散式(2) --- 資料載入之DataLoader

0x00 摘要

為了更好的介紹引數伺服器Paracel的資料載入,我們臨時插入兩篇PyTorch的資料載入,主要是從分散式的角度進行切入。本文只算是開胃甜點,後續會有專門系列分析PyTorch分散式。

引數伺服器系列其他文章如下:

[原始碼解析] 機器學習引數伺服器ps-lite 之(1) ----- PostOffice

[原始碼解析] 機器學習引數伺服器ps-lite(2) ----- 通訊模組Van

[原始碼解析] 機器學習引數伺服器ps-lite 之(3) ----- 代理人Customer

[原始碼解析]機器學習引數伺服器ps-lite(4) ----- 應用節點實現

[原始碼解析] 機器學習引數伺服器 Paracel (1)-----總體架構

[原始碼解析] 機器學習引數伺服器 Paracel (2)--------SSP控制協議實現

[原始碼解析] PyTorch 分散式(1) --- 資料載入之DistributedSampler

0x01 前情回顧

關於資料載入,上回書我們說到了 DistributedSampler,本文接下來就進行 DataLoader的分析。

為了更好說明,我們首先給出上文的流水線圖,本文會對這個圖進行細化。

                    +------------+
+--------+          |            |
|        |          | Process 1  |
| Data 1 +--------> |            +------+
|        |          | Load Data  |      |
+--------+          |            |      |
                    +------------+      |
                                        |
                                        |
                                        |
                    +------------+      |        +-----------------------------------+
+--------+          |            |      |        |                                   |
|        |          | Process 2  |      +------> | Pin-memory process                |
| Data 2 +--------> |            |               |                                   |
|        |          | Load Data  +-------------> |                                   |
+--------+          |            |               |        Transfer to Pinned Memory  |
                    +------------+       +-----> |                                   |
                                         |       |                                   |
                                         |       +-----------------------------------+
                                         |
+--------+          +------------+       |
|        |          |            |       |
| Data 3 +--------> | Process 3  +-------+
|        |          |            |
+--------+          | Load Data  |
                    |            |
                    +------------+

其次,我們再看看資料載入總體邏輯,具體如下圖,簡要說就是:

  1. DataSet 把資料集數目發給DistributedSampler。
  2. Sampler 按照某種規則生成資料indices併傳送給DataLoader。
  3. DataLoader 依據indices來從DataSet之中載入資料(其內部的DataLoaderIter物件負責協調單程式/多程式載入Dataset)。
  4. DataLoader 把資料發給模型,進行訓練。
+------------------------+                     +-----------+
|DistributedSampler      |                     |DataLoader |
|                        |     2 indices       |           |
|    Some strategy       +-------------------> |           |
|                        |                     |           |
|-------------+----------|                     |           |
              ^                                |           |  4 data  +-------+
              |                                |       -------------->+ train |
            1 | length                         |           |          +-------+
              |                                |           |
+-------------+----------+                     |           |
|DataSet                 |                     |           |
|        +---------+     |      3 Load         |           |
|        |  Data   +-------------------------> |           |
|        +---------+     |                     |           |
|                        |                     |           |
+------------------------+                     +-----------+

接下來,我們就正式進入 DataLoader。

0x02 DataLoader

DataLoader的作用是:結合Dataset和Sampler之後,在資料集上提供了一個迭代器

可以這麼理解:

DataSet 是原始資料,Sampler 提供瞭如何切分資料的策略(或者說是提供了切分資料的維度),DataLoader就是依據策略來具體打工幹活的,其中單程式載入就是一個人幹活,多程式載入就是多拉幾個人一起幹活

2.1 初始化

初始化的主要引數如下:

  • dataset (Dataset) :所載入的資料集。
  • batch_size (int, optional) :每個批次載入多少個樣本。
  • shuffle (bool, optional) :如果為 True,則每個epoch 都會再打亂資料。
  • sampler (Sampler or Iterable, optional) :定義瞭如何從樣本取樣的策略。可以是任何實現了 __len__的迭代器。
  • batch_sampler (Sampler or Iterable, optional) :與sampler類似,但是每次返回一個批次的資料索引。
  • num_workers (int, optional) :資料載入的子程式數目。如果是 0,表示從主程式載入資料。
  • collate_fn (callable, optional):從一個小批次( mini-batch)張量中合併出一個樣本列表。當從 map-style 資料集做批量載入時候使用。
  • pin_memory (bool, optional) : 如果為true,則在返回張量之前把張量拷貝到CUDA固定記憶體之中。
  • drop_last (bool, optional) :當資料集不能被均勻分割時,如果為true,丟掉最後一個不完整的批次。如果為False,那麼最後一個批次的資料較小。
  • timeout (numeric, optional): 如果是整數,則是worker收集批次資料的超時值。
  • worker_init_fn (callable, optional):如果非空,則會在seeding和資料載入之前被每個子程式呼叫,以Iworker id ([0, num_workers - 1])作為輸入引數。
  • generator (torch.Generator, optional):如果非空,則被RandomSampler 用來產生隨機索引,也被多程式用來產生 base_seed
  • prefetch_factor (int, optional, keyword-only arg):每個 worker 提前載入 的 sample 數量。
  • persistent_workers (bool, optional):如果為 True, 則在消費一次之後,data loader也 不會關掉worker程式。這允許workerDataset例項維持活動狀態。

具體初始化程式碼如下,主要就是各種設定,為了更好的說明,去除了異常處理程式碼:

class DataLoader(Generic[T_co]):

    dataset: Dataset[T_co]
    batch_size: Optional[int]
    num_workers: int
    pin_memory: bool
    drop_last: bool
    timeout: float
    sampler: Sampler
    prefetch_factor: int
    _iterator : Optional['_BaseDataLoaderIter']
    __initialized = False

    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
                 batch_sampler: Optional[Sampler[Sequence[int]]] = None,
                 num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = False):
        torch._C._log_api_usage_once("python.data_loader")

        self.dataset = dataset
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.multiprocessing_context = multiprocessing_context

        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
			# 省略異常處理
        else:
            self._dataset_kind = _DatasetKind.Map

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
			# 省略異常處理
            batch_size = None
            drop_last = False
        elif batch_size is None:
            # no auto_collation
            if drop_last:
                raise ValueError('batch_size=None option disables auto-batching '
                                 'and is mutually exclusive with drop_last')

        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    sampler = RandomSampler(dataset, generator=generator)  
                else:
                    sampler = SequentialSampler(dataset) 

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.generator = generator

        if collate_fn is None:
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert

        self.collate_fn = collate_fn
        self.persistent_workers = persistent_workers
        self.__initialized = True
        self._IterableDataset_len_called = None 
        self._iterator = None
        self.check_worker_number_rationality()

2.2 關鍵函式

這裡關鍵函式之一就是_index_sampler,用來讓迭代器呼叫sampler,我們接下來就會講到

    @property
    def _index_sampler(self):
        # The actual sampler used for generating indices for `_DatasetFetcher`
        # (see _utils/fetch.py) to read data at each time. This would be
        # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
        # We can't change `.sampler` and `.batch_sampler` attributes for BC
        # reasons.
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler

2.3 單程式載入

單程式模式下,Data Loader會在計算程式內載入資料,所以載入過程中可能會阻塞計算。

for 語句會呼叫enumerate 會返回一個迭代器,以此來遍歷資料集。在eumerate之中,dataloader 的 __next__(self) 方法會被呼叫,逐一獲取下一個物件,從而遍歷資料集。

    cuda0 = torch.device('cuda:0')  # CUDA GPU 0
    for i, x in enumerate(train_loader):
        x = x.to(cuda0)

2.3.1 區分生成

當多程式載入時候,在DataLoader宣告週期之中,迭代器只被建立一次,這樣worker可以重用迭代器。

在單程式載入時候,應該每次生成,以避免重置狀態。

    def __iter__(self) -> '_BaseDataLoaderIter':
        if self.persistent_workers and self.num_workers > 0: # 如果是多程式或者設定了持久化
            if self._iterator is None: # 如果沒有,才會新生成
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else: # 單程式
            return self._get_iterator() # 每次都直接生成新的

具體會依據是否是多程式來區別生成。

    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)

2.3.2 迭代器基類

_BaseDataLoaderIter 是迭代器基類,我們挑選關鍵函式看看。

這裡關鍵成員變數就是:

  • _index_sampler:這裡設定了loader 的 sampler,所以迭代器可以據此獲取取樣策略。
  • _sampler_iter:得到 sampler 的迭代器。
class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        # 初始化引數
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler # 得到取樣策略
        self._num_workers = loader.num_workers
        self._prefetch_factor = loader.prefetch_factor
        self._pin_memory = loader.pin_memory and torch.cuda.is_available()
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler) # 得到sampler的迭代器
        self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
        self._persistent_workers = loader.persistent_workers
        self._num_yielded = 0
        self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__)


    def __next__(self) -> Any:
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                self._reset()
            data = self._next_data() # 獲取資料
            self._num_yielded += 1
            if self._dataset_kind == _DatasetKind.Iterable and \
                    self._IterableDataset_len_called is not None and \
                    self._num_yielded > self._IterableDataset_len_called:
					# 忽略錯誤提示處理
            	warnings.warn(warn_msg)
            return data

2.3.3 單程式迭代器

_SingleProcessDataLoaderIter 繼承了 _BaseDataLoaderIter,可以看到,其增加了 _dataset_fetcher,在構造時候傳入了 _collate_fn 等各種引數。

回憶下,__next__會呼叫 self._next_data() 獲取資料,而在這裡,_next_data 就會:

  • 使用 self._next_index(),其又會使用 _sampler_iter(取樣器的迭代器)來獲取indices 。
  • 使用 self._dataset_fetcher.fetch(index)來依據indices獲取資料。
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        # 獲取樣本方法
        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        # 獲取樣本
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data
    
    def _next_index(self): # 得到indices
        return next(self._sampler_iter)  # may raise StopIteration    

2.3.4 獲取樣本

我們接下來看看如何獲取樣本。就是通過索引傳入 fetcher,從而獲取想要的樣本。

fetcher生成如下,這是在_SingleProcessDataLoaderIter初始化時候生成的:

class _DatasetKind(object):
    Map = 0
    Iterable = 1

    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        if kind == _DatasetKind.Map:
            return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
        else:
            return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)

對於Map-style,就使用 _MapDatasetFetcher 處理,就是使用 possibly_batched_index 從資料集之中提取資料,possibly_batched_index 是key。

如果有batch sampler,就使用 batch sampler。

如果需要從一個小批次( mini-batch)張量中合併出一個樣本列表。就使用 collate_fn後處理。

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
			# 如果配置了batch_sampler,_auto_collation就為True,
            # 那麼就優先使用batch_sampler,此時fetcher中傳入的就是一個batch的索引
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

對於 Iterable-style,因為 __init__ 方法內設定了 dataset 初始的迭代器,所以在fetch 方法內獲取元素的時候,如果是常規 sampler,index 其實已經不起作用,直接從dataset迭代器獲取。如果是batch sampler,則index有效果。

class _IterableDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
        self.dataset_iter = iter(dataset)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # 即auto_collation為True,表示使用batch_sampler。
            # 則使用possibly_batched_index,獲取1個batch大小的樣本       
            data = []
            for _ in possibly_batched_index:
                try:
                    data.append(next(self.dataset_iter))
                except StopIteration:
                    break
            if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
                raise StopIteration
        else:
            # sampler則直接往後遍歷,提取1個樣本
            data = next(self.dataset_iter)
        return self.collate_fn(data)

此時總邏輯如下:

     +--------------------------+            +-------------------------------+
     | DataLoader               |            | _SingleProcessDataLoaderIter  |
     |                          |            |                               |
     |                          |            |               __next__        |
+---------------+ Sampler       |            |                               |
|    |                          |            |              _next_data +-----------+
|    |            Dataset       |            |                               |     |
|    |                          |            |              _next_index      |     |
|    |           __iter__       |            |                               |     |
|    |                          |            |             _index_sampler    |     |
|    |       _get_iterator  +--------------> |                    +          |     |
|    |                          |            |                    |          |     |
|    +--------------------------+            +-------------------------------+     |
|                                                                 |                |
|                                                                 |                |
|                                                                 |                |
|                                                                 |                |
|                                                                 |                |
|                           +----------------------------+        |                |
|                           |Sampler                     |        |                |
+------------------------>  |                            | <------+                |
                            |                            |                         |
                            |                            |                         |
                            |                            |                         |
                            +----------------------------+                         |
                                                                                   |
                                                                                   |
                            +----------------------------+                         |
                            |_BaseDatasetFetcher         |                         |
                            |                            |                         |
                            |                            |                         |
                            |          dataset           |                         |
                            |                            |  <----------------------+
                            |          collate_fn        |
                            |                            |
                            +----------------------------+

動態流程如下:

  User              DataLoader    _SingleProcessDataLoaderIter _DatasetKind   Sampler

    +                   +                    +                        +           +
    |                   |                    |                        |           |
    |         1         |                    |                        |           |
 enumerate-------->  __iter__                |                        |           |
    |                   +                    |                        |           |
    |                   |                    |                        |           |
    |                   |                    |                        |           |
    |                   |          2         v            3           v           |
    |              _get_iterator--------> __init__  +----------> create_fetcher   |
    |         4         |                    +                        +           |
    | <-----------------+                    |                        |           |
    |      iterator     |                    |                        |           |
    |                   |          5         |                        |           |
for loop +------------------------------> __next__                    |           |
    |                   |                    |                        |           |
    |                   |                    |                        |           |
    |                   |                    |                        |           |
    |                   |                _next_data                   |           |
    |                   |                    |                        |           |
    |                   |                    |                        |           |
    |                   |                    |           6  next      |           |
    |                   |                _next_index  +-------------------------> |
    |                   |                    |                        |           |
    |                   |                    |  <---------------------------------+
    |                   |                    |           7  index     |           |
    |                   |                    |                        |           |
    |                   |                    |                        |           |
    |                   |                    |        8 fetch(index)  |           |
    |                   |                    | +--------------------> |           |
    |                   |                    |                        |           |
    |                   |                    |  <---------------------+           |
    |                   |                    |         9  data        |           |
    |  <-------------------------------------+                        |           |
    |   10  data        |                    |                        |           |
    |                   |                    |                        |           |
    v                   v                    v                        v           v


2.4 多程式載入

為了加速,PyTorch提供了多程式下載,只要把將引數 num_workers 設定為正整數,系統就會相應生成多程式處理,在這種模式下,每個worker都是一個獨立程式。

由上節我們可以知道,_SingleProcessDataLoaderIter 是單程式載入資料的核心,loader通過它來與sampler,dataset互動。在多程式中,這個核心對應的就是 _MultiProcessingDataLoaderIter。

    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)

我們接下來就從 _MultiProcessingDataLoaderIter 開始分析。

2.4.1 總體邏輯

_MultiProcessingDataLoaderIter 中的註釋十分詳盡,值得大家深讀,而且給出了邏輯流程圖如下,其基本流程是圍繞著三個queue進行的:

  • 主程式把需要獲取的資料 index 放入index_queue,這是指定子程式需要獲取哪些資料的佇列。同時也給子程式傳入結果佇列,關於結果佇列,有兩個分支:
    • 如果設定了pin memory,則傳入的是 worker_result_queue。
    • 否則傳入 data_queue。
  • 子程式從 index_queue 之中讀取 index,進行資料讀取,然後把讀取資料的index放入worker_result_queue,這是向主程式返回結果的佇列
  • 主程式進行處理,這裡有兩個分支:
    • 如果設定了pin memory,則主程式的 pin_memory_thread 會從 worker_result_queue 讀取資料index,依據這個index進行讀取資料,進行處理,把結果放入 data_queue,這是處理結果的佇列
    • 如果不需要pin memory,則結果已經存在 data_queue 之中,不做新操作。

可以看到,每個程式的輸入是一個佇列index_queue ,輸出也是一個佇列worker_result_queue。主程式和子程式通過這2~3個 queue 聯絡了起來,從而達到解耦合和加速的作用

    # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
    #
    # Preliminary:
    #
    # Our data model looks like this (queues are indicated with curly brackets):
    #
    #                main process                              ||
    #                     |                                    ||
    #               {index_queue}                              ||
    #                     |                                    ||
    #              worker processes                            ||     DATA
    #                     |                                    ||
    #            {worker_result_queue}                         ||     FLOW
    #                     |                                    ||
    #      pin_memory_thread of main process                   ||   DIRECTION
    #                     |                                    ||
    #               {data_queue}                               ||
    #                     |                                    ||
    #                data output                               \/
    #
    # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
    #      `pin_memory=False`.

具體如下圖所示,如果不需要 pin memory,則為:

                                               +-----------+
               indices  -------------+ indices | Worker    | Data
             +--------->+index queue +-------->+ Process   +------+
             |          |            |         |           |      |
             |          -------------+         +-----------+      |
             |                                                    |   +------------+
             |                                                    |   |            |
+---------+  |                                                    +--->            |
| Main    |  | indices  -------------+ indices +-----------+          |            |
| Process +------------>+index queue +-------->+ Worker    | Data     | Data Queue |
|         |  |          |            |         | Process   +---------->            |
+---------+  |          -------------+         |           |          |            |
             |                                 +-----------+      +--->            |
             |                                                    |   +------------+
             |                                                    |
             | indices  -------------+ indices +-----------+      |
             +--------->+index queue +-------->+ Worker    | Data |
                        |            |         | Process   +------+
                        -------------+         |           |
                                               +-----------+

當有pin memory時候,則是先進入 result queue,然後 pin_memory_thread 處理之後會轉入到 data queue:

                                               +-----------+
               indices  -------------+ indices | Worker    | Data
             +--------->+index queue +-------->+ Process   +------+
             |          |            |         |           |      |
             |          -------------+         +-----------+      |
             |                                                    |   --------------+
             |                                                    |   |             |
+---------+  |                                                    +--->             |
| Main    |  | indices  -------------+ indices +-----------+          |             |
| Process +------------>+index queue +-------->+ Worker    | Data     | result_queue|
|         |  |          |            |         | Process   +---------->             |
+---------+  |          -------------+         |           |          |             |
             |                                 +-----------+      +--->             |
             |                                                    |   ---------+----+
             |                                                    |            |
             | indices  -------------+ indices +-----------+      |            |
             +--------->+index queue +-------->+ Worker    | Data |  +---------+--------+
                        |            |         | Process   +------+  | pin_memory_thread|
                        -------------+         |           |         |         |        |
                                               +-----------+         |         |        |
                                                                     |         |        |
                                                                     +------------------+
                                                                               |
                                                                               |
                                                                               |
                                                                               v
                                                                         +-----+------+
                                                                         | Data Queue |
                                                                         |            |
                                                                         +------------+

2.4.2 初始化

初始化函式如下,主要是:

  • 配置,生成各種成員變數,配置各種queue。
  • 啟動各個子程式。
  • 啟動主程式中的pin_memory的執行緒。

主要成員變數為:

  • _index_queues: 這是一個queue 列表,列表的每一個元素是一個 queue,就是每個子程式的佇列需要處理的資料index,每個子程式對應一個 queue。
  • _worker_result_queue: 子程式處理完的 (idx, data)。
  • data_queue: 經過主程式 pin_memory 執行緒處理之後的資料佇列,如果不需要pin,則直接會使用 _worker_result_queue
  • _worker_queue_idx_cycle 用以找出下一個工作的worker。

具體程式碼如下:

class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""

    def __init__(self, loader):
        super(_MultiProcessingDataLoaderIter, self).__init__(loader)

        assert self._num_workers > 0
        assert self._prefetch_factor > 0

        if loader.multiprocessing_context is None:
            multiprocessing_context = multiprocessing
        else:
            multiprocessing_context = loader.multiprocessing_context

        self._worker_init_fn = loader.worker_init_fn
        self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
        # No certainty which module multiprocessing_context is
        self._worker_result_queue = multiprocessing_context.Queue()  # 子程式輸出,讀取完資料的index
        self._worker_pids_set = False
        self._shutdown = False
        self._workers_done_event = multiprocessing_context.Event()

        self._index_queues = [] # 子程式輸入,需讀取資料的index
        self._workers = []
        for i in range(self._num_workers):
            # No certainty which module multiprocessing_context is
            index_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
            # Need to `cancel_join_thread` here!
            # See sections (2) and (3b) above.
            index_queue.cancel_join_thread()
            w = multiprocessing_context.Process(
                target=_utils.worker._worker_loop, # worker程式主函式,把各種queue和函式傳進去
                args=(self._dataset_kind, self._dataset, index_queue,
                      self._worker_result_queue, self._workers_done_event,
                      self._auto_collation, self._collate_fn, self._drop_last,
                      self._base_seed, self._worker_init_fn, i, self._num_workers,
                      self._persistent_workers))
            w.daemon = True
            w.start()
            self._index_queues.append(index_queue) # 把這個worker對應的index_queue放到主程式這裡存起來,以後就可以互動了
            self._workers.append(w)

        if self._pin_memory:
            self._pin_memory_thread_done_event = threading.Event()

            # Queue is not type-annotated
            self._data_queue = queue.Queue()  # pin 處理之後的資料結果
            pin_memory_thread = threading.Thread(
                target=_utils.pin_memory._pin_memory_loop,
                args=(self._worker_result_queue, self._data_queue,
                      torch.cuda.current_device(),
                      self._pin_memory_thread_done_event))
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register
            # pin_memory_thread once it is started.
            self._pin_memory_thread = pin_memory_thread
        else:
            self._data_queue = self._worker_result_queue # 如果不需要pin,則直接使用_worker_result_queue

        # .pid can be None only before process is spawned (not the case, so ignore)
        _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))  # type: ignore[misc]
        _utils.signal_handling._set_SIGCHLD_handler()
        self._worker_pids_set = True
        
        self._reset(loader, first_iter=True) # 繼續完善業務

2.4.3 業務重置

__init__ 函式最後會呼叫 _reset 函式,這是進一步完善業務初始化,也用來重置環境。

上小節函式中,已經啟動了worker子程式,但是沒有分配任務,所以_reset函式會進行任務分配,預取。

_MultiProcessingDataLoaderIter有如下 flag 引數來協調各個 worker (包括各種queue)之間的工作:

  • _send_idx: 傳送索引,用來記錄這次要放 index_queue 中 batch 的 idx

  • _rcvd_idx: 接受索引,記錄要從 data_queue 中取出的 batch 的 idx

  • _task_info: 儲存將要產生的 data 資訊的 dict,key為 task idx(由 0 開始的整型索引),value 為 (worker_id,)(worker_id, data),分別對應資料 未取 和 已取 的情況

  • _tasks_outstanding: 整型,代表已經準備好的 task/batch 的數量(可能有些正在準備中)

  • _send_idx: 傳送索引,記錄下一次要放 index_queue 中 task batch 的 idx。

  • _rcvd_idx: 接受索引,記錄下一次要從 data_queue 中取出的 task batch 的 idx。_send_idx_rcvd_idx 主要用來進行流量控制和確保接受索引有意義。

  • _task_info: 儲存將要產生的 data 資訊的 dict,key為 task batch idx(由 0 開始的整型索引),value 為 (worker_id,)(worker_id, data),分別對應資料 未取 和 已取 的情況。_task_info的作用是依據 task batch idx 獲取對應的 worker id 和暫存亂序資料。

  • _tasks_outstanding: 整型,正在準備的 task/batch 的數量,實際上就是進行一些確認工作,沒有太實際的意義。

對於載入資料,每個 worker 一次產生一個 batch 的資料,返回 batch 資料前,會放入下一個批次要處理的資料下標,所以 reset 函式會把 _send_idx_rcvd_idx 都恢復成0,這樣下次迭代就可以重新處理。

在 reset 方法最後,有一個預取資料操作。我們會在後面結合亂序處理進行講解

    def _reset(self, loader, first_iter=False):
        super()._reset(loader, first_iter)
        self._send_idx = 0  # idx of the next task to be sent to workers
        self._rcvd_idx = 0  # idx of the next task to be returned in __next__
        # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
        # map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
        #                  \ (worker_id, data)   if data is already fetched (out-of-order)
        self._task_info = {}
        self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
        # A list of booleans representing whether each worker still has work to
        # do, i.e., not having exhausted its iterable dataset object. It always
        # contains all `True`s if not using an iterable-style dataset
        # (i.e., if kind != Iterable).
        # Not that this indicates that a worker still has work to do *for this epoch*.
        # It does not mean that a worker is dead. In case of `_persistent_workers`,
        # the worker will be reset to available in the next epoch.
        # 每個worker的狀態
        self._workers_status = [True for i in range(self._num_workers)]
        # We resume the prefetching in case it was enabled
        if not first_iter:
            for idx in range(self._num_workers):
                self._index_queues[idx].put(_utils.worker._ResumeIteration())
            resume_iteration_cnt = self._num_workers
            while resume_iteration_cnt > 0:
                return_idx, return_data = self._get_data()
                if isinstance(return_idx, _utils.worker._ResumeIteration):
                    assert return_data is None
                    resume_iteration_cnt -= 1
        # prime the prefetch loop
        
        # 預取若干index,目的是為了配合後續的亂序處理。
        for _ in range(self._prefetch_factor * self._num_workers):
            self._try_put_index()

2.4.4 獲取 index

_try_put_index 函式就是使用sampler獲取下一批次的資料index。這裡 _prefetch_factor 預設值是 2,主要邏輯如下。

  • 從sampler獲取下一批次的index。
  • 通過 _worker_queue_idx_cycle 找出下一個可用的工作worker,然後把index分給它。
  • 並且調整主程式的資訊。
    def _next_index(self): # 定義在基類 _BaseDataLoaderIter 之中,就是獲取下一批index
        return next(self._sampler_iter)  # may raise StopIteration

	def _try_put_index(self):
        
        assert self._tasks_outstanding < self._prefetch_factor * self._num_workers

        try:
            index = self._next_index() # 獲取下一批index
        except StopIteration:
            return
        for _ in range(self._num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]: # 如果已經工作,就繼續找
                break
        else:
            # not found (i.e., didn't break)
            return

        # 以下是主程式進行相關記錄
        # 給下一個工作worker放入 (任務index, 資料index), 就是給queue放入資料,所以worker loop之中就立刻會從queue中得到index,從而開始獲取資料。
        self._index_queues[worker_queue_idx].put((self._send_idx, index)) 
        # 記錄 將要產生的 data 資訊
        self._task_info[self._send_idx] = (worker_queue_idx,)
        # 正在處理的batch個數+1
        self._tasks_outstanding += 1
        # send_idx 記錄從sample_iter中傳送索引到index_queue的次數
        self._send_idx += 1 # 遞增下一批傳送的task index

2.4.5 worker主函式

_worker_loop 是 worker程式的主函式,主要邏輯如其註釋所示:

    # [ worker processes ]
    #   While loader process is alive:
    #     Get from `index_queue`.
    #       If get anything else,
    #          Check `workers_done_event`.
    #            If set, continue to next iteration
    #                    i.e., keep getting until see the `None`, then exit.
    #            Otherwise, process data:
    #                If is fetching from an `IterableDataset` and the iterator
    #                    is exhausted, send an `_IterableDatasetStopIteration`
    #                    object to signal iteration end. The main process, upon
    #                    receiving such an object, will send `None` to this
    #                    worker and not use the corresponding `index_queue`
    #                    anymore.
    #       If timed out,
    #          No matter `workers_done_event` is set (still need to see `None`)
    #          or not, must continue to next iteration.
    #   (outside loop)
    #   If `workers_done_event` is set,  (this can be False with `IterableDataset`)
    #     `data_queue.cancel_join_thread()`.  (Everything is ending here:
    #                                          main process won't read from it;
    #                                          other workers will also call
    #                                          `cancel_join_thread`.)

就是通過index_queue, data_queue與主程式互動。

  • 從 index_queue 獲取新的資料index;
  • 如果沒有設定本worker結束,就使用 fetcher獲取資料
  • 然後把資料放入data_queue,並且通知主程式,這裡需要注意,data_queue是傳入的引數,如果設定了pin memory,則傳入的是 worker_result_queue, 否則傳入 data_queue
def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
                 auto_collation, collate_fn, drop_last, base_seed, init_fn, worker_id,
                 num_workers, persistent_workers):
    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.

    try:
        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
        # module's handlers are executed after Python returns from C low-level
        # handlers, likely when the same fatal signal had already happened
        # again.
        # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
        signal_handling._set_worker_signal_handlers()

        torch.set_num_threads(1)
        seed = base_seed + worker_id
        random.seed(seed)
        torch.manual_seed(seed)
        if HAS_NUMPY:
            np_seed = _generate_state(base_seed, worker_id)
            import numpy as np
            np.random.seed(np_seed)

        global _worker_info
        _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
                                  seed=seed, dataset=dataset)

        from torch.utils.data import _DatasetKind

        init_exception = None

        try:
            if init_fn is not None:
                init_fn(worker_id)

            fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
        except Exception:
            init_exception = ExceptionWrapper(
                where="in DataLoader worker process {}".format(worker_id))

        iteration_end = False
        watchdog = ManagerWatchdog()

        while watchdog.is_alive(): # 等待在這裡
            try:
                # _try_put_index 如果放入了資料index,這裡就被啟用,開始工作
                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue
            if isinstance(r, _ResumeIteration):
                # Acknowledge the main process
                data_queue.put((r, None))
                iteration_end = False
                # Recreate the fetcher for worker-reuse policy
                fetcher = _DatasetKind.create_fetcher(
                    dataset_kind, dataset, auto_collation, collate_fn, drop_last)
                continue
            elif r is None:
                # Received the final signal
                assert done_event.is_set() or iteration_end
                break
            elif done_event.is_set() or iteration_end:
                # `done_event` is set. But I haven't received the final signal
                # (None) yet. I will keep continuing until get it, and skip the
                # processing steps.
                continue
            idx, index = r
            data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
            if init_exception is not None:
                data = init_exception
                init_exception = None
            else:
                try:
                    data = fetcher.fetch(index)
                except Exception as e:
					# 省略處理程式碼
            
            data_queue.put((idx, data)) # 放入資料,通知主程式
            del data, idx, index, r  # save memory
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass
    if done_event.is_set():
        data_queue.cancel_join_thread()
        data_queue.close()

2.4.6 Pin memory thread

在主程式之中,如果設定了需要pin memory,主程式的 pin_memory_thread 會從 worker_result_queue 讀取資料,進行處理(加速CPU和GPU的資料拷貝),把結果放入 data_queue。

    # [ pin_memory_thread ]
    #   # No need to check main thread. If this thread is alive, the main loader
    #   # thread must be alive, because this thread is set as daemonic.
    #   While `pin_memory_thread_done_event` is not set:
    #     Get from `index_queue`.
    #       If timed out, continue to get in the next iteration.
    #       Otherwise, process data.
    #       While `pin_memory_thread_done_event` is not set:
    #         Put processed data to `data_queue` (a `queue.Queue` with blocking put)
    #         If timed out, continue to put in the next iteration.
    #         Otherwise, break, i.e., continuing to the out loop.
    #
    #   NOTE: we don't check the status of the main thread because
    #           1. if the process is killed by fatal signal, `pin_memory_thread`
    #              ends.
    #           2. in other cases, either the cleaning-up in __del__ or the
    #              automatic exit of daemonic thread will take care of it.
    #              This won't busy-wait either because `.get(timeout)` does not
    #              busy-wait.

具體程式碼如下:

def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
    # This setting is thread local, and prevents the copy in pin_memory from
    # consuming all CPU cores.
    torch.set_num_threads(1)

    torch.cuda.set_device(device_id)

    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
    # logic of this function.
    while not done_event.is_set():
        try:
            r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        except queue.Empty:
            continue
        idx, data = r
        if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
            data = pin_memory(data)
            # 省略異常處理程式碼
            r = (idx, data)
        while not done_event.is_set():
            try:
                out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
                break
            except queue.Full:
                continue
        del r  # save memory


def pin_memory(data):
    if isinstance(data, torch.Tensor):
        return data.pin_memory()
    elif isinstance(data, string_classes):
        return data
    elif isinstance(data, collections.abc.Mapping):
        return {k: pin_memory(sample) for k, sample in data.items()}
    elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return type(data)(*(pin_memory(sample) for sample in data))
    elif isinstance(data, collections.abc.Sequence):
        return [pin_memory(sample) for sample in data]
    elif hasattr(data, "pin_memory"):
        return data.pin_memory()
    else:
        return data

2.4.7 使用者獲取data

現在資料已經載入完畢,我們接下來看使用者如何從DataLoader之中獲取資料。

這裡有一個很關鍵的地方:如何保持在不同實驗之中資料讀取順序的一致性。為了讓多次實驗之間可以比對,就需要儘量保證在這些實驗中,每次讀取資料的順序都是一致的,這樣才不會因為資料原因造成結果的誤差。

打破順序一致性的最大可能就是亂序資料。而造成亂序問題的原因就是:多程式讀取,可能某個程式快,某個程式慢。比如,使用者這次需要讀取6-19,16-26,37-46。但是某一個worker慢,6-19不能即時返回,另一個worker 的 16-26 先返回了,於是就會造成亂序。

如何處理亂序資料?PyTorch的具體做法就是:DataLoader嚴格按照Sampler的順序返回資料。如果某一個資料是亂序的,則會把它暫存起來,轉而去獲取下一個資料,見下面程式碼中 "store out-of-order samples" 註釋處。等到應該返回時候(這個資料順序到了)才返回。

但是其風險就是資料返回會比當前請求慢,比如應該獲取 6,但是Data queue裡面沒有這個資料,只有 16,27,於是使用者只能等待 6 載入完成。

解決慢的方法是:預取(prefetch)。就是在reset方法最後,提前提取若干index,讓DataLoader提前去取,這雖然不能保證任意兩次訓練的資料返回順序完全一致,但是可以最大限度保證。

具體程式碼如下,首先,回憶基類的 __next__ 函式 ,可以看到其呼叫了 _next_data 獲取資料。

class _BaseDataLoaderIter(object):
    def __next__(self) -> Any:
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                self._reset()
            data = self._next_data() # 獲取資料
            self._num_yielded += 1
            if self._dataset_kind == _DatasetKind.Iterable and \
                    self._IterableDataset_len_called is not None and \
                    self._num_yielded > self._IterableDataset_len_called:
					# 忽略錯誤提示處理
            	warnings.warn(warn_msg)
            return data

所以,我們要看 _MultiProcessingDataLoaderIter_next_data

  • 因為之前有預取了index,worker程式已經開始獲取資料,所以主程式此時可以得到資料,如果沒有資料,就繼續while True等待。
  • 如果獲取成功,則使用 _process_data 設定下一次的indx,準備下一次迭代。
  • 通過 _task_info 來記錄亂序資料,如果暫時無法處理,就在這裡儲存。
    def _next_data(self):
        while True:
            # If the worker responsible for `self._rcvd_idx` has already ended
            # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
            # we try to advance `self._rcvd_idx` to find the next valid index.
            #
            # This part needs to run in the loop because both the `self._get_data()`
            # call and `_IterableDatasetStopIteration` check below can mark
            # extra worker(s) as dead.
            
            # 找到待取idx
            while self._rcvd_idx < self._send_idx: # 如果 待取batch idx < 已取batch idx
                info = self._task_info[self._rcvd_idx]
                worker_id = info[0]
                if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
                    break # 有資料或者正在工作,就跳出內部這個while
                del self._task_info[self._rcvd_idx]
                self._rcvd_idx += 1
            else:
                # no valid `self._rcvd_idx` is found (i.e., didn't break)
                if not self._persistent_workers:
                    self._shutdown_workers()
                raise StopIteration

            # Now `self._rcvd_idx` is the batch index we want to fetch

            # Check if the next sample has already been generated
            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data) # 設定下一次的indx,進行下一次迭代

            assert not self._shutdown and self._tasks_outstanding > 0
            idx, data = self._get_data() # 從 self._data_queue 中取資料
            self._tasks_outstanding -= 1 # 正在準備的batch個數需要減1
            
            if self._dataset_kind == _DatasetKind.Iterable:
                # Check for _IterableDatasetStopIteration
                if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                    if self._persistent_workers:
                        self._workers_status[data.worker_id] = False
                    else:
                        self._mark_worker_as_unavailable(data.worker_id)
                    self._try_put_index() 
                    continue

            if idx != self._rcvd_idx: # 亂序資料
                # store out-of-order samples
                self._task_info[idx] += (data,)
            else:
                del self._task_info[idx] # 正常資料
                return self._process_data(data) # 設定下一次的indx,進行下一次迭代

其次,我們看看 _get_data 如何從 self._data_queue 中取資料。具體是使用 _try_get_data 來提取。

  • 如果有超時配置,就按照超時讀取。
  • 如果設定了pin memory,則從pin 執行緒處理之後的資料讀取。
  • 否則迴圈讀取worker處理的資料,直至獲取到資料為止。
    def _get_data(self):
        # Fetches data from `self._data_queue`.
        #
        # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
        # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
        # in a loop. This is the only mechanism to detect worker failures for
        # Windows. For other platforms, a SIGCHLD handler is also used for
        # worker failure detection.
        #
        # If `pin_memory=True`, we also need check if `pin_memory_thread` had
        # died at timeouts.
        if self._timeout > 0: # 如果有超時配置,就按照超時讀取
            success, data = self._try_get_data(self._timeout)
            if success:
                return data
            else:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))
        elif self._pin_memory: # 從pin 執行緒處理之後的資料讀取
            while self._pin_memory_thread.is_alive():
                success, data = self._try_get_data()
                if success:
                    return data
            else:
                # while condition is false, i.e., pin_memory_thread died.
                raise RuntimeError('Pin memory thread exited unexpectedly')
            # In this case, `self._data_queue` is a `queue.Queue`,. But we don't
            # need to call `.task_done()` because we don't use `.join()`.
        else:
            while True:
                success, data = self._try_get_data() # 讀取worker處理的資料
                if success:
                    return data

_try_get_data 就是從 _data_queue 讀取。主程式和worker程式通過queue上的put, get進行通訊互動。

    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        # Tries to fetch data from `self._data_queue` once for a given timeout.
        # This can also be used as inner loop of fetching without timeout, with
        # the sender status as the loop condition.
        #
        # This raises a `RuntimeError` if any worker died expectedly. This error
        # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
        # (only for non-Windows platforms), or the manual check below on errors
        # and timeouts.
        #
        # Returns a 2-tuple:
        #   (bool: whether successfully get data, any: data if successful else None)
        try:
            data = self._data_queue.get(timeout=timeout)
            return (True, data)
        except Exception as e:
            # At timeout and error, we manually check whether any worker has
            # failed. Note that this is the only mechanism for Windows to detect
            # worker failures.
            failed_workers = []
            for worker_id, w in enumerate(self._workers):
                if self._workers_status[worker_id] and not w.is_alive():
                    failed_workers.append(w)
                    self._mark_worker_as_unavailable(worker_id)
			# 省略異常處理程式碼
            import tempfile
            import errno
            try:
                # Raise an exception if we are this close to the FDs limit.
                # Apparently, trying to open only one file is not a sufficient
                # test.
                # See NOTE [ DataLoader on Linux and open files limit ]
                fds_limit_margin = 10
                fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
            except OSError as e:
                # 省略異常處理程式碼
            raise

設定下一次迭代是使用_process_data

    def _process_data(self, data):
        self._rcvd_idx += 1
        self._try_put_index() # 設定下一次的indx,進行下一次迭代
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data # 返回資料

2.4.8 小結

我們小結一下多程式邏輯。

總體邏輯如下:

  • 主程式把需要獲取的資料 index 放入index_queue。
  • 子程式從 index_queue 之中讀取 index,進行資料讀取,然後把讀取資料的index放入worker_result_queue。
  • 主程式的 pin_memory_thread 會從 worker_result_queue 讀取資料index,依據這個index進行讀取資料,進行處理,把結果放入 data_queue。

具體流程如下圖:

  1. 在 _MultiProcessingDataLoaderIter 的初始化函式 __init__ 之中會進行初始化:
    • 配置,生成各種成員變數,配置各種queue。
    • 啟動各個子程式。
    • 啟動主程式中的pin_memory的執行緒。
    • 呼叫 _reset 函式,這是進一步完善業務初始化,也用來重置環境。上面已經啟動了worker子程式,但是沒有分配任務,所以reset函式會進行任務分配,預取
  2. 接下來是一個預取操作(在看下圖中一定要留意)。
    • _try_put_index 函式就是使用sampler獲取下一批次的資料index。這裡 _prefetch_factor 預設值是 2,主要邏輯如下。
      • 使用 _next_index 從sampler獲取下一批次的index。
      • 通過 _worker_queue_idx_cycle 找出下一個可用的工作worker,然後把index分給它。
      • 並且調整主程式的資訊。
    • 拿到index之後,回到主執行緒。這裡會進行資料提取。就是通過index_queue, data_queue與主程式互動。
      • 從 index_queue 獲取新的資料index;
      • 如果沒有設定本worker結束,就使用 fetcher獲取資料。
      • 然後把資料放入data_queue,並且通知主程式,這裡需要注意,data_queue是傳入的引數,如果設定了pin memory,則傳入的是 worker_result_queue,否則傳入 data_queue。
  3. 當使用者迭代時,呼叫了Loader基類的 __next__ 函式 ,其呼叫 _next_data 從 DataLoader 之中獲取資料。
    • 使用 _get_data 如何從 self._data_queue 中取資料。
    • 使用_process_data 設定下一次迭代的 index,即使用 _try_put_index_next_index 來進行下一輪設定。

具體如下圖:

user        _MultiProcessingDataLoaderIter   Sampler        Queue(index_queue)    Queue(data_queue)    _worker_loop     Fetcher
 +                       +                      +                  +                     +                  +              +
 |                       |                      |                  |                     |                  |              |
 |                       |                      |                  |                     |                  |              |
 |                       v                      |                  |                     |                  |              |
 |                   __init__                   |                  |                     |                  |              |
 |               1    _reset                    |                  |                     |                  |              |
 |                       +                      |                  |                     |                  |              |
 |                       |                      |                  |                     |                  |              |
 |                       |                      |                  |                     |                  |              |
 |                       v                      |                  |                     |                  |              |
 |            2   _try_put_index     next       |                  |                     |                  |              |
 |                  _next_index  +------------> |                  |                     |                  |              |
 |                       +                      |                  |                     |                  |              |
 |                       |  <-----------------+ |                  |                     |                  |              |
 |                       |           index      |                  |                     |                  |              |
 |                       |                      |                  |                     |                  |              |
 |                       | +------------------------------------>  |                     |                  |              |
 |                       |           put        |                  |                     |       get        |              |
 |                       |                      |                  +--------------------------------------> |              |
 |                       |                      |                  |                     |                  |    index     |
 |                       |                      |                  |                     |                  +------------> |
 |         next          |                      |                  |                     |                  | <----------+ |
 +---------------------> |                      |                  |                     | <----------------+    data      |
 |                       |                      |                  |                     |      data        |              |
 |                       +                      |                  |                     |                  |              |
 |                   _next_data                 |                  |                     |                  |              |
 |              3   _get_data          get      |                  |                     |                  |              |
 |                  _try_get_data  +-------------------------------------------------->  |                  |              |
 |                       +                      |                  |                     |                  |              |
 |                       |  <----------------------------------------------------------+ |                  |              |
 |                       |             data     |                  |                     |                  |              |
 |                       +                      |                  |                     |                  |              |
 |                   _process_data              |                  |                     |                  |              |
 |                  _try_put_index     next     |                  |                     |                  |              |
 |                  _next_index +-------------> |                  |                     |                  |              |
 |                       + <--------------------+                  |                     |                  |              |
 |                       |           index      |                  |                     |                  |              |
 |                       +---------------------------------------> |                     |       get        |              |
 | <-------------------+ |             put      |                  +------------------------------------->  |     index    |
 |        data           |                      |                  |                     |                  | +----------> |
 |                       |                      |                  |                     |                  +<-----------+ |
 v                       v                      v                  v                     v                  v     data     v

手機上如下:

img

2.5 Pipleline

至此,我們把之前的pipeline圖進一步細化,具體如下:

                                                  +------------+
                              +--------+          |            |
                              |        |          | Process 1  |
                      +-----> | Data 1 +--------> |            +------+
                      |       |        |          | Load Data  |      |
                      |       +--------+          |            |      |
                      |                           +------------+      |
                      |                                               |
                      |                                               |
                      |                                               |
+----------------+    |                           +------------+      |                                          +-------------------------+
|Main process    |    |       +--------+          |            |      |                                          |  pin_memory_thread      |
|                |    |       |        |          | Process 2  |      +------>  +------------------------+       |                         |          +------------+
|  index_queue   +----------> | Data 2 +--------> |            |                |                        |       |                         |          |            |
|                |    |       |        |          | Load Data  +------------->  |  _worker_result_queue  +-----> |  Write to pinned memory +--------> | data_queue |
|                |    |       +--------+          |            |                |                        |       |                         |          |            |
+----------------+    |                           +------------+       +----->  |                        |       |                         |          +------------+
                      |                                                |        +------------------------+       |                         |
                      |                                                |                                         +-------------------------+
                      |                                                |
                      |       +--------+          +------------+       |
                      |       |        |          |            |       |
                      +-----> | Data 3 +--------> | Process 3  +-------+
                              |        |          |            |
                              +--------+          | Load Data  |
                                                  |            |
                                                  +------------+

手機如下:

img

至此,PyTorch 分散式的資料載入部分分析完畢,下一篇我們迴歸到 Paracel 如何處理資料載入。

0xFF 參考

卷積神經網路的並行化模型--One weird trick for parallelizing convolutional neural networks

AI框架中資料處理的挑戰與解決思路

PyTorch 原始碼解讀之 torch.utils.data:解析資料處理全流程

談談你對大規模機器學習這個領域的理解和認識?

Nvidia-DALI 從放棄到入門

pytorch(分散式)資料並行個人實踐總結——DataParallel/DistributedDataParallel

Pytorch資料Pipeline設計總結

深度學習框架資料Pipeline設計

相關文章