PyTorch之分散式操作Barrier

lart 發表於 2022-01-23
PyTorch

PyTorch之分散式操作Barrier

原始文件:https://www.yuque.com/lart/ug...

關於 barrier 的概念

關於 barrier 這個概念可以參考 Wiki 中的介紹:同步屏障(Barrier)是平行計算中的一種同步方法。對於一群程式或執行緒,程式中的一個同步屏障意味著任何執行緒/程式執行到此後必須等待,直到所有執行緒/程式都到達此點才可繼續執行下文。

這裡要注意,barrier 這一方法並不是 pytorch 獨有的,這是平行計算中的一個基本概念,其他的平行計算的場景下也可能會涉及這一概念和操作。本文主要討論 pytorch 中的情況。

torch.distributed.barrier(group=None, async_op=False, device_ids=None)

Synchronizes all processes.

This collective blocks processes until the whole group enters this function, if async_op is False, or if async work handle is called on wait().

Parameters
group (ProcessGroup, optional) – The process group to work on. If None, the default process group will be used.
async_op (bool, optional) – Whether this op should be an async op
device_ids ([int], optional) – List of device/GPU ids. Valid only for NCCL backend.

Returns
Async work handle, if async_op is set to True. None, if not async_op or if not part of the group

在多卡訓練的時候,由於不同的 GPU 往往被設定在不同的程式中,有時候為了在單獨的程式中執行一些任務,但是又同時希望限制其他程式的執行進度,就有了使用barrier的需求。
一個實際的場景是準備資料集:我們只需要在 0 號程式處理,其他程式沒必要也執行這一任務,但是其他程式的後續工作卻依賴準備好的資料。於是就需要在 0 號程式執行過程中阻塞其他的程式,使其進入等待狀態。等到處理好之後,再一起放行。

這種需求下,一個典型的基於上下文管理器形式的構造如下:

# https://github.com/ultralytics/yolov5/blob/7d56d451241e94cd9dbe4fcb9bfba0e92c6e0e23/utils/torch_utils.py#L29-L38

@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Decorator to make all processes in distributed training
    wait for each local_master to do something.
    """
    if local_rank not in [-1, 0]:
        dist.barrier(device_ids=[local_rank])
    yield
    if local_rank == 0:
        dist.barrier(device_ids=[0])

關於 barrier 的細節

# -*- coding: utf-8 -*-

import os
import time

import torch.distributed as dist
import torch.multiprocessing as mp


def ddp_test_v0(local_rank, word_size):
    # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
    dist.init_process_group(backend="nccl", world_size=word_size, rank=local_rank)

    print("first before barrier{}\n".format(local_rank))
    if local_rank != 0:
        dist.barrier()
    print("first after barrier{}\n".format(local_rank))

    print("inter {}".format(local_rank))

    print("second before barrier{}\n".format(local_rank))
    if local_rank == 0:
        dist.barrier()
    print("second after barrier{}\n".format(local_rank))

    print("{} exit".format(local_rank))


def ddp_test_v1(local_rank, word_size):
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    dist.init_process_group(backend="nccl", world_size=word_size, rank=local_rank)

    if local_rank != 0:
        print("1 before barrier{}\n".format(local_rank))
        start = time.time()
        time.sleep(5)
        dist.barrier()
        print(time.time() - start)
        print("1 after barrier{}\n".format(local_rank))
        dist.barrier()
        print("1 after barrier{}\n".format(local_rank))
    else:
        print("0 before barrier{}\n".format(local_rank))
        start = time.time()
        dist.barrier()
        print(time.time() - start)
        print("0 after barrier{}\n".format(local_rank))
        print("0 after barrier{}\n".format(local_rank))
        dist.barrier()
        print("0 after barrier{}\n".format(local_rank))

    print("{} exit".format(local_rank))


def main():
    world_size = 2
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    mp.spawn(ddp_test_v0, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()

這裡展示了兩個例子,實際上在官方展示的 dist.barrier  之外顯示了該方法的一個重要特性,就是其操作實際上是每一個程式內部都需要對應的執行同樣的次數,才會對應的由阻塞變為正常執行。
先看第一個例子:

def ddp_test(local_rank, word_size):
    # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
    dist.init_process_group(backend="nccl", world_size=word_size, rank=local_rank)

    print("first before barrier{}\n".format(local_rank))
    if local_rank != 0:
        dist.barrier()
    print("first after barrier{}\n".format(local_rank))

    print("inter {}".format(local_rank))

    print("second before barrier{}\n".format(local_rank))
    if local_rank == 0:
        dist.barrier()
    print("second after barrier{}\n".format(local_rank))

    print("{} exit".format(local_rank))

其輸出是:

first before barrier1
first before barrier0


first after barrier0

inter 0
second before barrier0

second after barrier0

0 exit
first after barrier1

inter 1
second before barrier1

second after barrier1

1 exit

Process finished with exit code 0

可以看到,有幾個細節:

  • barrier  之前,所有的操作都是各 GPU 程式自己輸出自己的。

    • 由於 local_rank=0  執行到自己可見的 barrier  中間會輸出多個,而 local_rank=1  則只有一條 first before barrier1 。
  • second before barrier0  之後,0 號執行到了屬於自己的 barrier ,這回讓使得其他程式不再阻塞,開始正常執行。由於中間操作的時間,所以先是 0 號輸出自己的 second after barrier0  並隨之退出,之後 1 號也接著開始輸出自己的結果。

這裡有一點值得注意,不同程式的 barrier  實際上是互相對應的,必須所有程式都執行一次barrier,才會重新放行正常前進。
對於第二段程式碼:

def ddp_test_v1(local_rank, word_size):
    # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
    dist.init_process_group(backend="nccl", world_size=word_size, rank=local_rank)

    if local_rank != 0:
        print("1 before barrier{}\n".format(local_rank))
        start = time.time()
        time.sleep(5)
        dist.barrier()
        print(time.time() - start)
        print("1 after barrier{}\n".format(local_rank))
        dist.barrier()
        print("1 after barrier{}\n".format(local_rank))
    else:
        print("0 before barrier{}\n".format(local_rank))
        start = time.time()
        dist.barrier()
        print(time.time() - start)
        print("0 after barrier{}\n".format(local_rank))
        print("0 after barrier{}\n".format(local_rank))
        dist.barrier()
        print("0 after barrier{}\n".format(local_rank))

    print("{} exit".format(local_rank))

則是有輸出:

1 before barrier1
0 before barrier0


5.002117395401001
5.0021262168884281 after barrier1


0 after barrier0

0 after barrier0

0 after barrier0

0 exit
1 after barrier1

1 exit

Process finished with exit code 0

可以看到一個重要的點,就是這兩處 print(time.time() - start)  的輸出是基本一樣的,不管前面延時多少, barrier  後面的時間都是按照最長到達並執行 barrier  的間隔時間來的。這個更體現了不同程式 barrier  之間的互相限制關係。而 0 到達自己的第二個 barrier  之後,會使得 1 號再次執行。但是此時 0 是先結束的。
另外,可以驗證,如果某個編號對應的程式碼中的兩個 barrier  之中的一個,那麼另一個就會陷入無限等待之中。
例如:


def ddp_test_v1(local_rank, word_size):
    # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
    dist.init_process_group(backend="nccl", world_size=word_size, rank=local_rank)

    if local_rank != 0:
        print("1 before barrier{}\n".format(local_rank))
        start = time.time()
        time.sleep(5)
        dist.barrier()
        print(time.time() - start)
        print("1 after barrier{}\n".format(local_rank))
        # dist.barrier()
        print("1 after barrier{}\n".format(local_rank))
    else:
        print("0 before barrier{}\n".format(local_rank))
        start = time.time()
        time.sleep(3)
        dist.barrier()
        print(time.time() - start)
        print("0 after barrier{}\n".format(local_rank))
        print("0 after barrier{}\n".format(local_rank))
        dist.barrier()
        print("0 after barrier{}\n".format(local_rank))

    print("{} exit".format(local_rank))

輸出:

0 before barrier0
1 before barrier1


5.002458572387695
1 after barrier1

1 after barrier1

1 exit
5.002473831176758
0 after barrier0

0 after barrier0

Traceback (most recent call last):
  File "/home/lart/Coding/SODBetterProj/tools/dist_experiment_test.py", line 67, in <module>
    main()
  File "/home/lart/Coding/SODBetterProj/tools/dist_experiment_test.py", line 63, in main
    mp.spawn(ddp_test_v1, args=(world_size,), nprocs=world_size, join=True)
  File "/home/lart/miniconda3/envs/pt17/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 199, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/lart/miniconda3/envs/pt17/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 157, in start_processes
    while not context.join():
  File "/home/lart/miniconda3/envs/pt17/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 75, in join
    ready = multiprocessing.connection.wait(
  File "/home/lart/miniconda3/envs/pt17/lib/python3.8/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/home/lart/miniconda3/envs/pt17/lib/python3.8/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt

Process finished with exit code 137 (interrupted by signal 9: SIGKILL)

會在第二個 barrier  處無限等待下去。
這一特點在這個回答中也被提到了:

when a process encounters a barrier it will block the position of the barrier is not important (not all processes have to enter the same if-statement, for instance) a process is blocked by a barrier until all processes have encountered a barrier, upon which the barrier is lifted for all processes

https://stackoverflow.com/a/59766443

重要的參考資料