[原始碼解析] PyTorch 分散式(2) --- 資料載入之DataLoader
0x00 摘要
為了更好的介紹引數伺服器Paracel的資料載入,我們臨時插入兩篇PyTorch的資料載入,主要是從分散式的角度進行切入。本文只算是開胃甜點,後續會有專門系列分析PyTorch分散式。
引數伺服器系列其他文章如下:
[原始碼解析] 機器學習引數伺服器ps-lite 之(1) ----- PostOffice
[原始碼解析] 機器學習引數伺服器ps-lite(2) ----- 通訊模組Van
[原始碼解析] 機器學習引數伺服器ps-lite 之(3) ----- 代理人Customer
[原始碼解析]機器學習引數伺服器ps-lite(4) ----- 應用節點實現
[原始碼解析] 機器學習引數伺服器 Paracel (1)-----總體架構
[原始碼解析] 機器學習引數伺服器 Paracel (2)--------SSP控制協議實現
[原始碼解析] PyTorch 分散式(1) --- 資料載入之DistributedSampler
0x01 前情回顧
關於資料載入,上回書我們說到了 DistributedSampler,本文接下來就進行 DataLoader的分析。
為了更好說明,我們首先給出上文的流水線圖,本文會對這個圖進行細化。
+------------+
+--------+ | |
| | | Process 1 |
| Data 1 +--------> | +------+
| | | Load Data | |
+--------+ | | |
+------------+ |
|
|
|
+------------+ | +-----------------------------------+
+--------+ | | | | |
| | | Process 2 | +------> | Pin-memory process |
| Data 2 +--------> | | | |
| | | Load Data +-------------> | |
+--------+ | | | Transfer to Pinned Memory |
+------------+ +-----> | |
| | |
| +-----------------------------------+
|
+--------+ +------------+ |
| | | | |
| Data 3 +--------> | Process 3 +-------+
| | | |
+--------+ | Load Data |
| |
+------------+
其次,我們再看看資料載入總體邏輯,具體如下圖,簡要說就是:
- DataSet 把資料集數目發給DistributedSampler。
- Sampler 按照某種規則生成資料indices併傳送給DataLoader。
- DataLoader 依據indices來從DataSet之中載入資料(其內部的DataLoaderIter物件負責協調單程式/多程式載入Dataset)。
- DataLoader 把資料發給模型,進行訓練。
+------------------------+ +-----------+
|DistributedSampler | |DataLoader |
| | 2 indices | |
| Some strategy +-------------------> | |
| | | |
|-------------+----------| | |
^ | | 4 data +-------+
| | -------------->+ train |
1 | length | | +-------+
| | |
+-------------+----------+ | |
|DataSet | | |
| +---------+ | 3 Load | |
| | Data +-------------------------> | |
| +---------+ | | |
| | | |
+------------------------+ +-----------+
接下來,我們就正式進入 DataLoader。
0x02 DataLoader
DataLoader的作用是:結合Dataset和Sampler之後,在資料集上提供了一個迭代器。
可以這麼理解:
DataSet 是原始資料,Sampler 提供瞭如何切分資料的策略(或者說是提供了切分資料的維度),DataLoader就是依據策略來具體打工幹活的,其中單程式載入就是一個人幹活,多程式載入就是多拉幾個人一起幹活。
2.1 初始化
初始化的主要引數如下:
- dataset (Dataset) :所載入的資料集。
- batch_size (int, optional) :每個批次載入多少個樣本。
- shuffle (bool, optional) :如果為 True,則每個epoch 都會再打亂資料。
- sampler (Sampler or Iterable, optional) :定義瞭如何從樣本取樣的策略。可以是任何實現了
__len__
的迭代器。 - batch_sampler (Sampler or Iterable, optional) :與
sampler
類似,但是每次返回一個批次的資料索引。 - num_workers (int, optional) :資料載入的子程式數目。如果是 0,表示從主程式載入資料。
- collate_fn (callable, optional):從一個小批次( mini-batch)張量中合併出一個樣本列表。當從 map-style 資料集做批量載入時候使用。
- pin_memory (bool, optional) : 如果為true,則在返回張量之前把張量拷貝到CUDA固定記憶體之中。
- drop_last (bool, optional) :當資料集不能被均勻分割時,如果為true,丟掉最後一個不完整的批次。如果為False,那麼最後一個批次的資料較小。
- timeout (numeric, optional): 如果是整數,則是worker收集批次資料的超時值。
- worker_init_fn (callable, optional):如果非空,則會在seeding和資料載入之前被每個子程式呼叫,以Iworker id (
[0, num_workers - 1]
)作為輸入引數。 - generator (torch.Generator, optional):如果非空,則被RandomSampler 用來產生隨機索引,也被多程式用來產生
base_seed
。 - prefetch_factor (int, optional, keyword-only arg):每個 worker 提前載入 的 sample 數量。
- persistent_workers (bool, optional):如果為
True
, 則在消費一次之後,data loader也 不會關掉worker程式。這允許workerDataset
例項維持活動狀態。
具體初始化程式碼如下,主要就是各種設定,為了更好的說明,去除了異常處理程式碼:
class DataLoader(Generic[T_co]):
dataset: Dataset[T_co]
batch_size: Optional[int]
num_workers: int
pin_memory: bool
drop_last: bool
timeout: float
sampler: Sampler
prefetch_factor: int
_iterator : Optional['_BaseDataLoaderIter']
__initialized = False
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
batch_sampler: Optional[Sampler[Sequence[int]]] = None,
num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
multiprocessing_context=None, generator=None,
*, prefetch_factor: int = 2,
persistent_workers: bool = False):
torch._C._log_api_usage_once("python.data_loader")
self.dataset = dataset
self.num_workers = num_workers
self.prefetch_factor = prefetch_factor
self.pin_memory = pin_memory
self.timeout = timeout
self.worker_init_fn = worker_init_fn
self.multiprocessing_context = multiprocessing_context
if isinstance(dataset, IterableDataset):
self._dataset_kind = _DatasetKind.Iterable
# 省略異常處理
else:
self._dataset_kind = _DatasetKind.Map
if batch_sampler is not None:
# auto_collation with custom batch_sampler
# 省略異常處理
batch_size = None
drop_last = False
elif batch_size is None:
# no auto_collation
if drop_last:
raise ValueError('batch_size=None option disables auto-batching '
'and is mutually exclusive with drop_last')
if sampler is None: # give default samplers
if self._dataset_kind == _DatasetKind.Iterable:
# See NOTE [ Custom Samplers and IterableDataset ]
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
sampler = RandomSampler(dataset, generator=generator)
else:
sampler = SequentialSampler(dataset)
if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.batch_size = batch_size
self.drop_last = drop_last
self.sampler = sampler
self.batch_sampler = batch_sampler
self.generator = generator
if collate_fn is None:
if self._auto_collation:
collate_fn = _utils.collate.default_collate
else:
collate_fn = _utils.collate.default_convert
self.collate_fn = collate_fn
self.persistent_workers = persistent_workers
self.__initialized = True
self._IterableDataset_len_called = None
self._iterator = None
self.check_worker_number_rationality()
2.2 關鍵函式
這裡關鍵函式之一就是_index_sampler,用來讓迭代器呼叫sampler,我們接下來就會講到
@property
def _index_sampler(self):
# The actual sampler used for generating indices for `_DatasetFetcher`
# (see _utils/fetch.py) to read data at each time. This would be
# `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
# We can't change `.sampler` and `.batch_sampler` attributes for BC
# reasons.
if self._auto_collation:
return self.batch_sampler
else:
return self.sampler
2.3 單程式載入
單程式模式下,Data Loader會在計算程式內載入資料,所以載入過程中可能會阻塞計算。
for 語句會呼叫enumerate 會返回一個迭代器,以此來遍歷資料集。在eumerate之中,dataloader 的 __next__(self)
方法會被呼叫,逐一獲取下一個物件,從而遍歷資料集。
cuda0 = torch.device('cuda:0') # CUDA GPU 0
for i, x in enumerate(train_loader):
x = x.to(cuda0)
2.3.1 區分生成
當多程式載入時候,在DataLoader宣告週期之中,迭代器只被建立一次,這樣worker可以重用迭代器。
在單程式載入時候,應該每次生成,以避免重置狀態。
def __iter__(self) -> '_BaseDataLoaderIter':
if self.persistent_workers and self.num_workers > 0: # 如果是多程式或者設定了持久化
if self._iterator is None: # 如果沒有,才會新生成
self._iterator = self._get_iterator()
else:
self._iterator._reset(self)
return self._iterator
else: # 單程式
return self._get_iterator() # 每次都直接生成新的
具體會依據是否是多程式來區別生成。
def _get_iterator(self) -> '_BaseDataLoaderIter':
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIter(self)
2.3.2 迭代器基類
_BaseDataLoaderIter 是迭代器基類,我們挑選關鍵函式看看。
這裡關鍵成員變數就是:
- _index_sampler:這裡設定了loader 的 sampler,所以迭代器可以據此獲取取樣策略。
- _sampler_iter:得到 sampler 的迭代器。
class _BaseDataLoaderIter(object):
def __init__(self, loader: DataLoader) -> None:
# 初始化引數
self._dataset = loader.dataset
self._dataset_kind = loader._dataset_kind
self._IterableDataset_len_called = loader._IterableDataset_len_called
self._auto_collation = loader._auto_collation
self._drop_last = loader.drop_last
self._index_sampler = loader._index_sampler # 得到取樣策略
self._num_workers = loader.num_workers
self._prefetch_factor = loader.prefetch_factor
self._pin_memory = loader.pin_memory and torch.cuda.is_available()
self._timeout = loader.timeout
self._collate_fn = loader.collate_fn
self._sampler_iter = iter(self._index_sampler) # 得到sampler的迭代器
self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
self._persistent_workers = loader.persistent_workers
self._num_yielded = 0
self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__)
def __next__(self) -> Any:
with torch.autograd.profiler.record_function(self._profile_name):
if self._sampler_iter is None:
self._reset()
data = self._next_data() # 獲取資料
self._num_yielded += 1
if self._dataset_kind == _DatasetKind.Iterable and \
self._IterableDataset_len_called is not None and \
self._num_yielded > self._IterableDataset_len_called:
# 忽略錯誤提示處理
warnings.warn(warn_msg)
return data
2.3.3 單程式迭代器
_SingleProcessDataLoaderIter
繼承了 _BaseDataLoaderIter
,可以看到,其增加了 _dataset_fetcher
,在構造時候傳入了 _collate_fn
等各種引數。
回憶下,__next__
會呼叫 self._next_data()
獲取資料,而在這裡,_next_data
就會:
- 使用
self._next_index()
,其又會使用_sampler_iter
(取樣器的迭代器)來獲取indices 。 - 使用
self._dataset_fetcher.fetch(index)
來依據indices獲取資料。
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
# 獲取樣本方法
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def _next_data(self):
index = self._next_index() # may raise StopIteration
# 獲取樣本
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
def _next_index(self): # 得到indices
return next(self._sampler_iter) # may raise StopIteration
2.3.4 獲取樣本
我們接下來看看如何獲取樣本。就是通過索引傳入 fetcher,從而獲取想要的樣本。
fetcher生成如下,這是在_SingleProcessDataLoaderIter初始化時候生成的:
class _DatasetKind(object):
Map = 0
Iterable = 1
@staticmethod
def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
if kind == _DatasetKind.Map:
return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
else:
return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
對於Map-style,就使用 _MapDatasetFetcher 處理,就是使用 possibly_batched_index 從資料集之中提取資料,possibly_batched_index 是key。
如果有batch sampler,就使用 batch sampler。
如果需要從一個小批次( mini-batch)張量中合併出一個樣本列表。就使用 collate_fn後處理。
class _MapDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
def fetch(self, possibly_batched_index):
if self.auto_collation:
# 如果配置了batch_sampler,_auto_collation就為True,
# 那麼就優先使用batch_sampler,此時fetcher中傳入的就是一個batch的索引
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
對於 Iterable-style,因為 __init__
方法內設定了 dataset 初始的迭代器,所以在fetch 方法內獲取元素的時候,如果是常規 sampler,index 其實已經不起作用,直接從dataset迭代器獲取。如果是batch sampler,則index有效果。
class _IterableDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
self.dataset_iter = iter(dataset)
def fetch(self, possibly_batched_index):
if self.auto_collation:
# 即auto_collation為True,表示使用batch_sampler。
# 則使用possibly_batched_index,獲取1個batch大小的樣本
data = []
for _ in possibly_batched_index:
try:
data.append(next(self.dataset_iter))
except StopIteration:
break
if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
raise StopIteration
else:
# sampler則直接往後遍歷,提取1個樣本
data = next(self.dataset_iter)
return self.collate_fn(data)
此時總邏輯如下:
+--------------------------+ +-------------------------------+
| DataLoader | | _SingleProcessDataLoaderIter |
| | | |
| | | __next__ |
+---------------+ Sampler | | |
| | | | _next_data +-----------+
| | Dataset | | | |
| | | | _next_index | |
| | __iter__ | | | |
| | | | _index_sampler | |
| | _get_iterator +--------------> | + | |
| | | | | | |
| +--------------------------+ +-------------------------------+ |
| | |
| | |
| | |
| | |
| | |
| +----------------------------+ | |
| |Sampler | | |
+------------------------> | | <------+ |
| | |
| | |
| | |
+----------------------------+ |
|
|
+----------------------------+ |
|_BaseDatasetFetcher | |
| | |
| | |
| dataset | |
| | <----------------------+
| collate_fn |
| |
+----------------------------+
動態流程如下:
User DataLoader _SingleProcessDataLoaderIter _DatasetKind Sampler
+ + + + +
| | | | |
| 1 | | | |
enumerate--------> __iter__ | | |
| + | | |
| | | | |
| | | | |
| | 2 v 3 v |
| _get_iterator--------> __init__ +----------> create_fetcher |
| 4 | + + |
| <-----------------+ | | |
| iterator | | | |
| | 5 | | |
for loop +------------------------------> __next__ | |
| | | | |
| | | | |
| | | | |
| | _next_data | |
| | | | |
| | | | |
| | | 6 next | |
| | _next_index +-------------------------> |
| | | | |
| | | <---------------------------------+
| | | 7 index | |
| | | | |
| | | | |
| | | 8 fetch(index) | |
| | | +--------------------> | |
| | | | |
| | | <---------------------+ |
| | | 9 data | |
| <-------------------------------------+ | |
| 10 data | | | |
| | | | |
v v v v v
2.4 多程式載入
為了加速,PyTorch提供了多程式下載,只要把將引數 num_workers
設定為正整數,系統就會相應生成多程式處理,在這種模式下,每個worker都是一個獨立程式。
由上節我們可以知道,_SingleProcessDataLoaderIter 是單程式載入資料的核心,loader通過它來與sampler,dataset互動。在多程式中,這個核心對應的就是 _MultiProcessingDataLoaderIter。
def _get_iterator(self) -> '_BaseDataLoaderIter':
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIter(self)
我們接下來就從 _MultiProcessingDataLoaderIter 開始分析。
2.4.1 總體邏輯
_MultiProcessingDataLoaderIter 中的註釋十分詳盡,值得大家深讀,而且給出了邏輯流程圖如下,其基本流程是圍繞著三個queue進行的:
- 主程式把需要獲取的資料 index 放入index_queue,這是指定子程式需要獲取哪些資料的佇列。同時也給子程式傳入結果佇列,關於結果佇列,有兩個分支:
- 如果設定了pin memory,則傳入的是 worker_result_queue。
- 否則傳入 data_queue。
- 子程式從 index_queue 之中讀取 index,進行資料讀取,然後把讀取資料的index放入worker_result_queue,這是向主程式返回結果的佇列。
- 主程式進行處理,這裡有兩個分支:
- 如果設定了pin memory,則主程式的 pin_memory_thread 會從 worker_result_queue 讀取資料index,依據這個index進行讀取資料,進行處理,把結果放入 data_queue,這是處理結果的佇列。
- 如果不需要pin memory,則結果已經存在 data_queue 之中,不做新操作。
可以看到,每個程式的輸入是一個佇列index_queue ,輸出也是一個佇列worker_result_queue。主程式和子程式通過這2~3個 queue 聯絡了起來,從而達到解耦合和加速的作用。
# NOTE [ Data Loader Multiprocessing Shutdown Logic ]
#
# Preliminary:
#
# Our data model looks like this (queues are indicated with curly brackets):
#
# main process ||
# | ||
# {index_queue} ||
# | ||
# worker processes || DATA
# | ||
# {worker_result_queue} || FLOW
# | ||
# pin_memory_thread of main process || DIRECTION
# | ||
# {data_queue} ||
# | ||
# data output \/
#
# P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
# `pin_memory=False`.
具體如下圖所示,如果不需要 pin memory,則為:
+-----------+
indices -------------+ indices | Worker | Data
+--------->+index queue +-------->+ Process +------+
| | | | | |
| -------------+ +-----------+ |
| | +------------+
| | | |
+---------+ | +---> |
| Main | | indices -------------+ indices +-----------+ | |
| Process +------------>+index queue +-------->+ Worker | Data | Data Queue |
| | | | | | Process +----------> |
+---------+ | -------------+ | | | |
| +-----------+ +---> |
| | +------------+
| |
| indices -------------+ indices +-----------+ |
+--------->+index queue +-------->+ Worker | Data |
| | | Process +------+
-------------+ | |
+-----------+
當有pin memory時候,則是先進入 result queue,然後 pin_memory_thread 處理之後會轉入到 data queue:
+-----------+
indices -------------+ indices | Worker | Data
+--------->+index queue +-------->+ Process +------+
| | | | | |
| -------------+ +-----------+ |
| | --------------+
| | | |
+---------+ | +---> |
| Main | | indices -------------+ indices +-----------+ | |
| Process +------------>+index queue +-------->+ Worker | Data | result_queue|
| | | | | | Process +----------> |
+---------+ | -------------+ | | | |
| +-----------+ +---> |
| | ---------+----+
| | |
| indices -------------+ indices +-----------+ | |
+--------->+index queue +-------->+ Worker | Data | +---------+--------+
| | | Process +------+ | pin_memory_thread|
-------------+ | | | | |
+-----------+ | | |
| | |
+------------------+
|
|
|
v
+-----+------+
| Data Queue |
| |
+------------+
2.4.2 初始化
初始化函式如下,主要是:
- 配置,生成各種成員變數,配置各種queue。
- 啟動各個子程式。
- 啟動主程式中的pin_memory的執行緒。
主要成員變數為:
_index_queues
: 這是一個queue 列表,列表的每一個元素是一個 queue,就是每個子程式的佇列需要處理的資料index,每個子程式對應一個 queue。_worker_result_queue
: 子程式處理完的 (idx, data)。data_queue
: 經過主程式 pin_memory 執行緒處理之後的資料佇列,如果不需要pin,則直接會使用_worker_result_queue
。_worker_queue_idx_cycle
用以找出下一個工作的worker。
具體程式碼如下:
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
def __init__(self, loader):
super(_MultiProcessingDataLoaderIter, self).__init__(loader)
assert self._num_workers > 0
assert self._prefetch_factor > 0
if loader.multiprocessing_context is None:
multiprocessing_context = multiprocessing
else:
multiprocessing_context = loader.multiprocessing_context
self._worker_init_fn = loader.worker_init_fn
self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
# No certainty which module multiprocessing_context is
self._worker_result_queue = multiprocessing_context.Queue() # 子程式輸出,讀取完資料的index
self._worker_pids_set = False
self._shutdown = False
self._workers_done_event = multiprocessing_context.Event()
self._index_queues = [] # 子程式輸入,需讀取資料的index
self._workers = []
for i in range(self._num_workers):
# No certainty which module multiprocessing_context is
index_queue = multiprocessing_context.Queue() # type: ignore[var-annotated]
# Need to `cancel_join_thread` here!
# See sections (2) and (3b) above.
index_queue.cancel_join_thread()
w = multiprocessing_context.Process(
target=_utils.worker._worker_loop, # worker程式主函式,把各種queue和函式傳進去
args=(self._dataset_kind, self._dataset, index_queue,
self._worker_result_queue, self._workers_done_event,
self._auto_collation, self._collate_fn, self._drop_last,
self._base_seed, self._worker_init_fn, i, self._num_workers,
self._persistent_workers))
w.daemon = True
w.start()
self._index_queues.append(index_queue) # 把這個worker對應的index_queue放到主程式這裡存起來,以後就可以互動了
self._workers.append(w)
if self._pin_memory:
self._pin_memory_thread_done_event = threading.Event()
# Queue is not type-annotated
self._data_queue = queue.Queue() # pin 處理之後的資料結果
pin_memory_thread = threading.Thread(
target=_utils.pin_memory._pin_memory_loop,
args=(self._worker_result_queue, self._data_queue,
torch.cuda.current_device(),
self._pin_memory_thread_done_event))
pin_memory_thread.daemon = True
pin_memory_thread.start()
# Similar to workers (see comment above), we only register
# pin_memory_thread once it is started.
self._pin_memory_thread = pin_memory_thread
else:
self._data_queue = self._worker_result_queue # 如果不需要pin,則直接使用_worker_result_queue
# .pid can be None only before process is spawned (not the case, so ignore)
_utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers)) # type: ignore[misc]
_utils.signal_handling._set_SIGCHLD_handler()
self._worker_pids_set = True
self._reset(loader, first_iter=True) # 繼續完善業務
2.4.3 業務重置
__init__
函式最後會呼叫 _reset 函式,這是進一步完善業務初始化,也用來重置環境。
上小節函式中,已經啟動了worker子程式,但是沒有分配任務,所以_reset函式會進行任務分配,預取。
_MultiProcessingDataLoaderIter有如下 flag 引數來協調各個 worker (包括各種queue)之間的工作:
-
_send_idx
: 傳送索引,用來記錄這次要放 index_queue 中 batch 的 idx -
_rcvd_idx
: 接受索引,記錄要從 data_queue 中取出的 batch 的 idx -
_task_info
: 儲存將要產生的 data 資訊的 dict,key為 task idx(由 0 開始的整型索引),value 為(worker_id,)
或(worker_id, data)
,分別對應資料 未取 和 已取 的情況 -
_tasks_outstanding
: 整型,代表已經準備好的 task/batch 的數量(可能有些正在準備中) -
_send_idx
: 傳送索引,記錄下一次要放 index_queue 中 task batch 的 idx。 -
_rcvd_idx
: 接受索引,記錄下一次要從 data_queue 中取出的 task batch 的 idx。_send_idx
和_rcvd_idx
主要用來進行流量控制和確保接受索引有意義。 -
_task_info
: 儲存將要產生的 data 資訊的 dict,key為 task batch idx(由 0 開始的整型索引),value 為(worker_id,)
或(worker_id, data)
,分別對應資料 未取 和 已取 的情況。_task_info
的作用是依據 task batch idx 獲取對應的 worker id 和暫存亂序資料。 -
_tasks_outstanding
: 整型,正在準備的 task/batch 的數量,實際上就是進行一些確認工作,沒有太實際的意義。
對於載入資料,每個 worker 一次產生一個 batch 的資料,返回 batch 資料前,會放入下一個批次要處理的資料下標,所以 reset 函式會把 _send_idx
,_rcvd_idx
都恢復成0,這樣下次迭代就可以重新處理。
在 reset 方法最後,有一個預取資料操作。我們會在後面結合亂序處理進行講解。
def _reset(self, loader, first_iter=False):
super()._reset(loader, first_iter)
self._send_idx = 0 # idx of the next task to be sent to workers
self._rcvd_idx = 0 # idx of the next task to be returned in __next__
# information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
# map: task idx => - (worker_id,) if data isn't fetched (outstanding)
# \ (worker_id, data) if data is already fetched (out-of-order)
self._task_info = {}
self._tasks_outstanding = 0 # always equal to count(v for v in task_info.values() if len(v) == 1)
# A list of booleans representing whether each worker still has work to
# do, i.e., not having exhausted its iterable dataset object. It always
# contains all `True`s if not using an iterable-style dataset
# (i.e., if kind != Iterable).
# Not that this indicates that a worker still has work to do *for this epoch*.
# It does not mean that a worker is dead. In case of `_persistent_workers`,
# the worker will be reset to available in the next epoch.
# 每個worker的狀態
self._workers_status = [True for i in range(self._num_workers)]
# We resume the prefetching in case it was enabled
if not first_iter:
for idx in range(self._num_workers):
self._index_queues[idx].put(_utils.worker._ResumeIteration())
resume_iteration_cnt = self._num_workers
while resume_iteration_cnt > 0:
return_idx, return_data = self._get_data()
if isinstance(return_idx, _utils.worker._ResumeIteration):
assert return_data is None
resume_iteration_cnt -= 1
# prime the prefetch loop
# 預取若干index,目的是為了配合後續的亂序處理。
for _ in range(self._prefetch_factor * self._num_workers):
self._try_put_index()
2.4.4 獲取 index
_try_put_index 函式就是使用sampler獲取下一批次的資料index。這裡 _prefetch_factor 預設值是 2,主要邏輯如下。
- 從sampler獲取下一批次的index。
- 通過 _worker_queue_idx_cycle 找出下一個可用的工作worker,然後把index分給它。
- 並且調整主程式的資訊。
def _next_index(self): # 定義在基類 _BaseDataLoaderIter 之中,就是獲取下一批index
return next(self._sampler_iter) # may raise StopIteration
def _try_put_index(self):
assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
try:
index = self._next_index() # 獲取下一批index
except StopIteration:
return
for _ in range(self._num_workers): # find the next active worker, if any
worker_queue_idx = next(self._worker_queue_idx_cycle)
if self._workers_status[worker_queue_idx]: # 如果已經工作,就繼續找
break
else:
# not found (i.e., didn't break)
return
# 以下是主程式進行相關記錄
# 給下一個工作worker放入 (任務index, 資料index), 就是給queue放入資料,所以worker loop之中就立刻會從queue中得到index,從而開始獲取資料。
self._index_queues[worker_queue_idx].put((self._send_idx, index))
# 記錄 將要產生的 data 資訊
self._task_info[self._send_idx] = (worker_queue_idx,)
# 正在處理的batch個數+1
self._tasks_outstanding += 1
# send_idx 記錄從sample_iter中傳送索引到index_queue的次數
self._send_idx += 1 # 遞增下一批傳送的task index
2.4.5 worker主函式
_worker_loop 是 worker程式的主函式,主要邏輯如其註釋所示:
# [ worker processes ]
# While loader process is alive:
# Get from `index_queue`.
# If get anything else,
# Check `workers_done_event`.
# If set, continue to next iteration
# i.e., keep getting until see the `None`, then exit.
# Otherwise, process data:
# If is fetching from an `IterableDataset` and the iterator
# is exhausted, send an `_IterableDatasetStopIteration`
# object to signal iteration end. The main process, upon
# receiving such an object, will send `None` to this
# worker and not use the corresponding `index_queue`
# anymore.
# If timed out,
# No matter `workers_done_event` is set (still need to see `None`)
# or not, must continue to next iteration.
# (outside loop)
# If `workers_done_event` is set, (this can be False with `IterableDataset`)
# `data_queue.cancel_join_thread()`. (Everything is ending here:
# main process won't read from it;
# other workers will also call
# `cancel_join_thread`.)
就是通過index_queue, data_queue與主程式互動。
- 從 index_queue 獲取新的資料index;
- 如果沒有設定本worker結束,就使用 fetcher獲取資料。
- 然後把資料放入data_queue,並且通知主程式,這裡需要注意,data_queue是傳入的引數,如果設定了pin memory,則傳入的是 worker_result_queue, 否則傳入 data_queue。
def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
auto_collation, collate_fn, drop_last, base_seed, init_fn, worker_id,
num_workers, persistent_workers):
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
# logic of this function.
try:
# Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
# module's handlers are executed after Python returns from C low-level
# handlers, likely when the same fatal signal had already happened
# again.
# https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
signal_handling._set_worker_signal_handlers()
torch.set_num_threads(1)
seed = base_seed + worker_id
random.seed(seed)
torch.manual_seed(seed)
if HAS_NUMPY:
np_seed = _generate_state(base_seed, worker_id)
import numpy as np
np.random.seed(np_seed)
global _worker_info
_worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
seed=seed, dataset=dataset)
from torch.utils.data import _DatasetKind
init_exception = None
try:
if init_fn is not None:
init_fn(worker_id)
fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
except Exception:
init_exception = ExceptionWrapper(
where="in DataLoader worker process {}".format(worker_id))
iteration_end = False
watchdog = ManagerWatchdog()
while watchdog.is_alive(): # 等待在這裡
try:
# _try_put_index 如果放入了資料index,這裡就被啟用,開始工作
r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
continue
if isinstance(r, _ResumeIteration):
# Acknowledge the main process
data_queue.put((r, None))
iteration_end = False
# Recreate the fetcher for worker-reuse policy
fetcher = _DatasetKind.create_fetcher(
dataset_kind, dataset, auto_collation, collate_fn, drop_last)
continue
elif r is None:
# Received the final signal
assert done_event.is_set() or iteration_end
break
elif done_event.is_set() or iteration_end:
# `done_event` is set. But I haven't received the final signal
# (None) yet. I will keep continuing until get it, and skip the
# processing steps.
continue
idx, index = r
data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
if init_exception is not None:
data = init_exception
init_exception = None
else:
try:
data = fetcher.fetch(index)
except Exception as e:
# 省略處理程式碼
data_queue.put((idx, data)) # 放入資料,通知主程式
del data, idx, index, r # save memory
except KeyboardInterrupt:
# Main process will raise KeyboardInterrupt anyways.
pass
if done_event.is_set():
data_queue.cancel_join_thread()
data_queue.close()
2.4.6 Pin memory thread
在主程式之中,如果設定了需要pin memory,主程式的 pin_memory_thread 會從 worker_result_queue 讀取資料,進行處理(加速CPU和GPU的資料拷貝),把結果放入 data_queue。
# [ pin_memory_thread ]
# # No need to check main thread. If this thread is alive, the main loader
# # thread must be alive, because this thread is set as daemonic.
# While `pin_memory_thread_done_event` is not set:
# Get from `index_queue`.
# If timed out, continue to get in the next iteration.
# Otherwise, process data.
# While `pin_memory_thread_done_event` is not set:
# Put processed data to `data_queue` (a `queue.Queue` with blocking put)
# If timed out, continue to put in the next iteration.
# Otherwise, break, i.e., continuing to the out loop.
#
# NOTE: we don't check the status of the main thread because
# 1. if the process is killed by fatal signal, `pin_memory_thread`
# ends.
# 2. in other cases, either the cleaning-up in __del__ or the
# automatic exit of daemonic thread will take care of it.
# This won't busy-wait either because `.get(timeout)` does not
# busy-wait.
具體程式碼如下:
def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
# This setting is thread local, and prevents the copy in pin_memory from
# consuming all CPU cores.
torch.set_num_threads(1)
torch.cuda.set_device(device_id)
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
# logic of this function.
while not done_event.is_set():
try:
r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
except queue.Empty:
continue
idx, data = r
if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
data = pin_memory(data)
# 省略異常處理程式碼
r = (idx, data)
while not done_event.is_set():
try:
out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
break
except queue.Full:
continue
del r # save memory
def pin_memory(data):
if isinstance(data, torch.Tensor):
return data.pin_memory()
elif isinstance(data, string_classes):
return data
elif isinstance(data, collections.abc.Mapping):
return {k: pin_memory(sample) for k, sample in data.items()}
elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple
return type(data)(*(pin_memory(sample) for sample in data))
elif isinstance(data, collections.abc.Sequence):
return [pin_memory(sample) for sample in data]
elif hasattr(data, "pin_memory"):
return data.pin_memory()
else:
return data
2.4.7 使用者獲取data
現在資料已經載入完畢,我們接下來看使用者如何從DataLoader之中獲取資料。
這裡有一個很關鍵的地方:如何保持在不同實驗之中資料讀取順序的一致性。為了讓多次實驗之間可以比對,就需要儘量保證在這些實驗中,每次讀取資料的順序都是一致的,這樣才不會因為資料原因造成結果的誤差。
打破順序一致性的最大可能就是亂序資料。而造成亂序問題的原因就是:多程式讀取,可能某個程式快,某個程式慢。比如,使用者這次需要讀取6-19,16-26,37-46。但是某一個worker慢,6-19不能即時返回,另一個worker 的 16-26 先返回了,於是就會造成亂序。
如何處理亂序資料?PyTorch的具體做法就是:DataLoader嚴格按照Sampler的順序返回資料。如果某一個資料是亂序的,則會把它暫存起來,轉而去獲取下一個資料,見下面程式碼中 "store out-of-order samples" 註釋處。等到應該返回時候(這個資料順序到了)才返回。
但是其風險就是資料返回會比當前請求慢,比如應該獲取 6,但是Data queue裡面沒有這個資料,只有 16,27,於是使用者只能等待 6 載入完成。
解決慢的方法是:預取(prefetch)。就是在reset方法最後,提前提取若干index,讓DataLoader提前去取,這雖然不能保證任意兩次訓練的資料返回順序完全一致,但是可以最大限度保證。
具體程式碼如下,首先,回憶基類的 __next__
函式 ,可以看到其呼叫了 _next_data 獲取資料。
class _BaseDataLoaderIter(object):
def __next__(self) -> Any:
with torch.autograd.profiler.record_function(self._profile_name):
if self._sampler_iter is None:
self._reset()
data = self._next_data() # 獲取資料
self._num_yielded += 1
if self._dataset_kind == _DatasetKind.Iterable and \
self._IterableDataset_len_called is not None and \
self._num_yielded > self._IterableDataset_len_called:
# 忽略錯誤提示處理
warnings.warn(warn_msg)
return data
所以,我們要看 _MultiProcessingDataLoaderIter
的_next_data
。
- 因為之前有預取了index,worker程式已經開始獲取資料,所以主程式此時可以得到資料,如果沒有資料,就繼續while True等待。
- 如果獲取成功,則使用
_process_data
設定下一次的indx,準備下一次迭代。 - 通過
_task_info
來記錄亂序資料,如果暫時無法處理,就在這裡儲存。
def _next_data(self):
while True:
# If the worker responsible for `self._rcvd_idx` has already ended
# and was unable to fulfill this task (due to exhausting an `IterableDataset`),
# we try to advance `self._rcvd_idx` to find the next valid index.
#
# This part needs to run in the loop because both the `self._get_data()`
# call and `_IterableDatasetStopIteration` check below can mark
# extra worker(s) as dead.
# 找到待取idx
while self._rcvd_idx < self._send_idx: # 如果 待取batch idx < 已取batch idx
info = self._task_info[self._rcvd_idx]
worker_id = info[0]
if len(info) == 2 or self._workers_status[worker_id]: # has data or is still active
break # 有資料或者正在工作,就跳出內部這個while
del self._task_info[self._rcvd_idx]
self._rcvd_idx += 1
else:
# no valid `self._rcvd_idx` is found (i.e., didn't break)
if not self._persistent_workers:
self._shutdown_workers()
raise StopIteration
# Now `self._rcvd_idx` is the batch index we want to fetch
# Check if the next sample has already been generated
if len(self._task_info[self._rcvd_idx]) == 2:
data = self._task_info.pop(self._rcvd_idx)[1]
return self._process_data(data) # 設定下一次的indx,進行下一次迭代
assert not self._shutdown and self._tasks_outstanding > 0
idx, data = self._get_data() # 從 self._data_queue 中取資料
self._tasks_outstanding -= 1 # 正在準備的batch個數需要減1
if self._dataset_kind == _DatasetKind.Iterable:
# Check for _IterableDatasetStopIteration
if isinstance(data, _utils.worker._IterableDatasetStopIteration):
if self._persistent_workers:
self._workers_status[data.worker_id] = False
else:
self._mark_worker_as_unavailable(data.worker_id)
self._try_put_index()
continue
if idx != self._rcvd_idx: # 亂序資料
# store out-of-order samples
self._task_info[idx] += (data,)
else:
del self._task_info[idx] # 正常資料
return self._process_data(data) # 設定下一次的indx,進行下一次迭代
其次,我們看看 _get_data
如何從 self._data_queue
中取資料。具體是使用 _try_get_data 來提取。
- 如果有超時配置,就按照超時讀取。
- 如果設定了pin memory,則從pin 執行緒處理之後的資料讀取。
- 否則迴圈讀取worker處理的資料,直至獲取到資料為止。
def _get_data(self):
# Fetches data from `self._data_queue`.
#
# We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
# which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
# in a loop. This is the only mechanism to detect worker failures for
# Windows. For other platforms, a SIGCHLD handler is also used for
# worker failure detection.
#
# If `pin_memory=True`, we also need check if `pin_memory_thread` had
# died at timeouts.
if self._timeout > 0: # 如果有超時配置,就按照超時讀取
success, data = self._try_get_data(self._timeout)
if success:
return data
else:
raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))
elif self._pin_memory: # 從pin 執行緒處理之後的資料讀取
while self._pin_memory_thread.is_alive():
success, data = self._try_get_data()
if success:
return data
else:
# while condition is false, i.e., pin_memory_thread died.
raise RuntimeError('Pin memory thread exited unexpectedly')
# In this case, `self._data_queue` is a `queue.Queue`,. But we don't
# need to call `.task_done()` because we don't use `.join()`.
else:
while True:
success, data = self._try_get_data() # 讀取worker處理的資料
if success:
return data
_try_get_data
就是從 _data_queue
讀取。主程式和worker程式通過queue上的put, get進行通訊互動。
def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
# Tries to fetch data from `self._data_queue` once for a given timeout.
# This can also be used as inner loop of fetching without timeout, with
# the sender status as the loop condition.
#
# This raises a `RuntimeError` if any worker died expectedly. This error
# can come from either the SIGCHLD handler in `_utils/signal_handling.py`
# (only for non-Windows platforms), or the manual check below on errors
# and timeouts.
#
# Returns a 2-tuple:
# (bool: whether successfully get data, any: data if successful else None)
try:
data = self._data_queue.get(timeout=timeout)
return (True, data)
except Exception as e:
# At timeout and error, we manually check whether any worker has
# failed. Note that this is the only mechanism for Windows to detect
# worker failures.
failed_workers = []
for worker_id, w in enumerate(self._workers):
if self._workers_status[worker_id] and not w.is_alive():
failed_workers.append(w)
self._mark_worker_as_unavailable(worker_id)
# 省略異常處理程式碼
import tempfile
import errno
try:
# Raise an exception if we are this close to the FDs limit.
# Apparently, trying to open only one file is not a sufficient
# test.
# See NOTE [ DataLoader on Linux and open files limit ]
fds_limit_margin = 10
fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
except OSError as e:
# 省略異常處理程式碼
raise
設定下一次迭代是使用_process_data
。
def _process_data(self, data):
self._rcvd_idx += 1
self._try_put_index() # 設定下一次的indx,進行下一次迭代
if isinstance(data, ExceptionWrapper):
data.reraise()
return data # 返回資料
2.4.8 小結
我們小結一下多程式邏輯。
總體邏輯如下:
- 主程式把需要獲取的資料 index 放入index_queue。
- 子程式從 index_queue 之中讀取 index,進行資料讀取,然後把讀取資料的index放入worker_result_queue。
- 主程式的 pin_memory_thread 會從 worker_result_queue 讀取資料index,依據這個index進行讀取資料,進行處理,把結果放入 data_queue。
具體流程如下圖:
- 在 _MultiProcessingDataLoaderIter 的初始化函式
__init__
之中會進行初始化:- 配置,生成各種成員變數,配置各種queue。
- 啟動各個子程式。
- 啟動主程式中的pin_memory的執行緒。
- 呼叫 _reset 函式,這是進一步完善業務初始化,也用來重置環境。上面已經啟動了worker子程式,但是沒有分配任務,所以reset函式會進行任務分配,預取。
- 接下來是一個預取操作(在看下圖中一定要留意)。
- _try_put_index 函式就是使用sampler獲取下一批次的資料index。這裡 _prefetch_factor 預設值是 2,主要邏輯如下。
- 使用 _next_index 從sampler獲取下一批次的index。
- 通過 _worker_queue_idx_cycle 找出下一個可用的工作worker,然後把index分給它。
- 並且調整主程式的資訊。
- 拿到index之後,回到主執行緒。這裡會進行資料提取。就是通過index_queue, data_queue與主程式互動。
- 從 index_queue 獲取新的資料index;
- 如果沒有設定本worker結束,就使用 fetcher獲取資料。
- 然後把資料放入data_queue,並且通知主程式,這裡需要注意,data_queue是傳入的引數,如果設定了pin memory,則傳入的是 worker_result_queue,否則傳入 data_queue。
- _try_put_index 函式就是使用sampler獲取下一批次的資料index。這裡 _prefetch_factor 預設值是 2,主要邏輯如下。
- 當使用者迭代時,呼叫了Loader基類的
__next__
函式 ,其呼叫 _next_data 從 DataLoader 之中獲取資料。- 使用
_get_data
如何從self._data_queue
中取資料。 - 使用
_process_data
設定下一次迭代的 index,即使用_try_put_index
,_next_index
來進行下一輪設定。
- 使用
具體如下圖:
user _MultiProcessingDataLoaderIter Sampler Queue(index_queue) Queue(data_queue) _worker_loop Fetcher
+ + + + + + +
| | | | | | |
| | | | | | |
| v | | | | |
| __init__ | | | | |
| 1 _reset | | | | |
| + | | | | |
| | | | | | |
| | | | | | |
| v | | | | |
| 2 _try_put_index next | | | | |
| _next_index +------------> | | | | |
| + | | | | |
| | <-----------------+ | | | | |
| | index | | | | |
| | | | | | |
| | +------------------------------------> | | | |
| | put | | | get | |
| | | +--------------------------------------> | |
| | | | | | index |
| | | | | +------------> |
| next | | | | | <----------+ |
+---------------------> | | | | <----------------+ data |
| | | | | data | |
| + | | | | |
| _next_data | | | | |
| 3 _get_data get | | | | |
| _try_get_data +--------------------------------------------------> | | |
| + | | | | |
| | <----------------------------------------------------------+ | | |
| | data | | | | |
| + | | | | |
| _process_data | | | | |
| _try_put_index next | | | | |
| _next_index +-------------> | | | | |
| + <--------------------+ | | | |
| | index | | | | |
| +---------------------------------------> | | get | |
| <-------------------+ | put | +-------------------------------------> | index |
| data | | | | | +----------> |
| | | | | +<-----------+ |
v v v v v v data v
手機上如下:
2.5 Pipleline
至此,我們把之前的pipeline圖進一步細化,具體如下:
+------------+
+--------+ | |
| | | Process 1 |
+-----> | Data 1 +--------> | +------+
| | | | Load Data | |
| +--------+ | | |
| +------------+ |
| |
| |
| |
+----------------+ | +------------+ | +-------------------------+
|Main process | | +--------+ | | | | pin_memory_thread |
| | | | | | Process 2 | +------> +------------------------+ | | +------------+
| index_queue +----------> | Data 2 +--------> | | | | | | | |
| | | | | | Load Data +-------------> | _worker_result_queue +-----> | Write to pinned memory +--------> | data_queue |
| | | +--------+ | | | | | | | |
+----------------+ | +------------+ +-----> | | | | +------------+
| | +------------------------+ | |
| | +-------------------------+
| |
| +--------+ +------------+ |
| | | | | |
+-----> | Data 3 +--------> | Process 3 +-------+
| | | |
+--------+ | Load Data |
| |
+------------+
手機如下:
至此,PyTorch 分散式的資料載入部分分析完畢,下一篇我們迴歸到 Paracel 如何處理資料載入。
0xFF 參考
卷積神經網路的並行化模型--One weird trick for parallelizing convolutional neural networks
PyTorch 原始碼解讀之 torch.utils.data:解析資料處理全流程
pytorch(分散式)資料並行個人實踐總結——DataParallel/DistributedDataParallel