Wenet Multi-Machine Multi-GPU Distributed Training

Posted by ByteHandler on 2023-01-09


PyTorch Distributed Training Demo

The Wenet framework is built on PyTorch, so Wenet's multi-machine multi-GPU training relies on PyTorch's distributed training support.

The following code shows how to run distributed training with PyTorch:

import contextlib
import random
import time

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def transform(batch):
    # Minimal collate_fn stand-in (the original is in the gist linked below):
    # stack the single-tensor samples into one batch tensor.
    return torch.stack([item[0] for item in batch])


def ddp_demo(rank, world_size, accum_grad=4):
    assert dist.is_gloo_available(), "Gloo is not available!"
    print(f"world_size: {world_size}, rank: {rank}, is_gloo_available: {dist.is_gloo_available()}")

    # 1. Initialize the process group
    dist.init_process_group("gloo", world_size=world_size, rank=rank)
    model = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 20))

    # 2. Wrap the model as a DDP model
    ddp_model = DistributedDataParallel(model)

    criterion = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=1e-3)

    dataset = TensorDataset(torch.randn(1000, 10))
    # 3. Distributed data parallelism (samples internally according to rank)
    sampler = DistributedSampler(dataset=dataset, num_replicas=world_size, shuffle=True)
    dataloader = DataLoader(dataset=dataset, batch_size=24, sampler=sampler, collate_fn=transform)

    for epoch in range(1):
        for step, batch in enumerate(dataloader):
            output = ddp_model(batch)
            label = torch.rand_like(output)

            if step % accum_grad == 0:
                # Synchronize gradients on this step
                context = contextlib.nullcontext
            else:
                # 4. Gradient accumulation: skip gradient synchronization
                context = ddp_model.no_sync

            with context():
                time.sleep(random.random())
                loss = criterion(output, label)
                loss.backward()

            if step % accum_grad == 0:
                optimizer.step()
                optimizer.zero_grad()
                print(f"epoch: {epoch}, step: {step}, rank: {rank} update parameters.")

    # 5. Destroy the process group (clean up process-group global state)
    dist.destroy_process_group()

Note: the local environment has no Nvidia GPU, so the gloo backend is used here instead of nccl.

Source code reference: https://gist.github.com/hotbaby/15950bbb43d052cd835b0f18c997f67c

The steps to convert a model to distributed training are:

  1. Initialize the process group: dist.init_process_group
  2. Wrap the model for distributed data parallelism: DistributedDataParallel(model)
  3. Distribute the data: split it into world_size shards and sample by rank with DistributedSampler(dataset=dataset, num_replicas=world_size, shuffle=True)
  4. Accumulate gradients during training to reduce how often the training processes synchronize gradients and improve communication efficiency [optional]
  5. Destroy the process group: dist.destroy_process_group()

Wenet Distributed Training in Practice

How do you configure multi-machine multi-GPU distributed training in Wenet?

GPU machine list:

Node name    IP address     GPUs
node1        10.10.23.9     8
node2        10.10.23.10    8

Taking the aishell dataset as an example, here is how to train a Wenet Chinese ASR model on these GPU machines:

  1. Environment initialization and data preparation

    For environment initialization, refer to the Wenet official documentation: https://github.com/wenet-e2e/wenet#installationtraining-and-developing.

    After extracting the aishell dataset, copy it to the /data/aishell/ directory on both node1 and node2.

  2. Configure the training scripts (how these settings combine into per-process ranks is sketched after this list)

    node1 training script configuration:

    wenet/examples/aishell/s0/run.sh

export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
data=/data/aishell/
num_nodes=2
node_rank=0
init_method="tcp://${node1_ip}:23456"
dist_backend="nccl"

    node2 training script configuration:

wenet/examples/aishell/s0/run.sh

export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
num_nodes=2
node_rank=1
init_method="tcp://${node1_ip}:23456"
dist_backend="nccl"
  3. Run the training scripts

    Run the run.sh training script in the background on both node1 and node2.

    # export NCCL_SOCKET_IFNAME=ens1f0
    nohup bash run.sh > train.log 2>&1 &
    

    ens1f0 is the name of the network interface; if NCCL_SOCKET_IFNAME is not set, multi-machine communication between the training processes may fail.
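
For intuition, here is a rough sketch (not taken from the Wenet scripts themselves) of how num_nodes, node_rank, and the visible GPUs combine into the world_size and per-process rank that eventually reach dist.init_process_group:

# Hypothetical illustration of the rank bookkeeping behind run.sh.
num_nodes = 2                      # total number of machines
node_rank = 0                      # 0 on node1, 1 on node2
gpus = [0, 1, 2, 3, 4, 5, 6, 7]    # CUDA_VISIBLE_DEVICES on this machine

world_size = num_nodes * len(gpus)  # 16 training processes in total

for local_gpu in gpus:
    rank = node_rank * len(gpus) + local_gpu  # globally unique process rank
    print(f"gpu {local_gpu}: rank={rank}, world_size={world_size}")
    # Each (rank, world_size) pair is handed to one training process, which
    # then calls something like:
    #   dist.init_process_group("nccl", init_method="tcp://<node1_ip>:23456",
    #                           world_size=world_size, rank=rank)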

Wenet Distributed Training Experiment Results

GPU configuration         Training time per epoch (s)   Speedup
Single machine, 4 GPUs    407.17                        -
Single machine, 8 GPUs    204.36                        99.24% faster than single machine, 4 GPUs
Multi-machine, 8 GPUs     221.75                        7.84% slower than single machine, 8 GPUs
Multi-machine, 16 GPUs    121.7                         67.92% faster than single machine, 8 GPUs
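
For reference, the percentage column appears to be computed as follows (this reconstruction is an assumption based on reproducing the numbers, not something stated in the original results):

t4, t8_single, t8_multi, t16 = 407.17, 204.36, 221.75, 121.7

# "faster" is measured as old_time / new_time - 1;
# "slower" is the extra time relative to the slower run.
print(f"{t4 / t8_single - 1:.2%}")                 # ~99.24%: single machine, 8 vs 4 GPUs
print(f"{(t8_multi - t8_single) / t8_multi:.2%}")  # ~7.84%: multi vs single machine, 8 GPUs
print(f"{t8_single / t16 - 1:.2%}")                # ~67.92%: 16 GPUs (multi) vs 8 GPUs (single)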

How is distributed training implemented in Wenet?

Much like the DDP demo above, Wenet calls the corresponding PyTorch APIs to implement distributed training.

  1. Initialize the process group

wenet/bin/train.py

def main():
    ...
    if distributed:
        logging.info('training on multiple gpus, this gpu {}'.format(args.gpu))
        dist.init_process_group(args.dist_backend,
                                init_method=args.init_method,
                                world_size=args.world_size,
                                rank=args.rank)
    ...

Wenet source code link: https://github.com/wenet-e2e/wenet/blob/main/wenet/bin/train.py#L141

  2. Wrap the model with DistributedDataParallel

wenet/bin/train.py

def main():
    ...
    if distributed:
        assert (torch.cuda.is_available())
        # cuda model is required for nn.parallel.DistributedDataParallel
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(
            model, find_unused_parameters=True)
    ...

Wenet source code link: https://github.com/wenet-e2e/wenet/blob/main/wenet/bin/train.py#L232

  3. Distributed data parallelism

wenet/dataset/dataset.py

class DistributedSampler:
    ...
    def sample(self, data):
        """ Sample data according to rank/world_size/num_workers
            Args:
                data(List): input data list
            Returns:
                List: data list after sample
        """
        data = list(range(len(data)))
        # TODO(Binbin Zhang): fix this
        # We can not handle uneven data for CV on DDP, so we don't
        # sample data by rank, that means every GPU gets the same
        # and all the CV data
        if self.partition:
            if self.shuffle:
                random.Random(self.epoch).shuffle(data)
            data = data[self.rank::self.world_size]
        # Within each rank, the data is further split across the
        # DataLoader's num_workers worker processes.
        data = data[self.worker_id::self.num_workers]
        return data
    ...

Wenet source code link: https://github.com/wenet-e2e/wenet/blob/main/wenet/dataset/dataset.py#L79
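
To make the two slicing steps concrete, here is a toy illustration (not part of the Wenet source) with assumed values world_size=2 and num_workers=2:

# Each item is first sharded by rank, then by DataLoader worker.
data = list(range(12))
world_size, num_workers = 2, 2

for rank in range(world_size):
    per_rank = data[rank::world_size]                    # shard for this training process
    for worker_id in range(num_workers):
        per_worker = per_rank[worker_id::num_workers]    # shard for this worker
        print(f"rank={rank} worker={worker_id}: {per_worker}")

# rank=0 worker=0: [0, 4, 8]
# rank=0 worker=1: [2, 6, 10]
# rank=1 worker=0: [1, 5, 9]
# rank=1 worker=1: [3, 7, 11]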

  4. Gradient accumulation to reduce the gradient-synchronization frequency between training processes

wenet/utils/executor.py

class Executor:
    def train(...):
        with model_context():
            for batch_idx, batch in enumerate(data_loader):
                if is_distributed and batch_idx % accum_grad != 0:
                    # Accumulate gradients locally without synchronizing them
                    context = model.no_sync
                # Used for single gpu training and DDP gradient synchronization
                # processes.
                else:
                    # Synchronize gradients
                    context = nullcontext
                with context():
                    # autocast context
                    # The more details about amp can be found in
                    # https://pytorch.org/docs/stable/notes/amp_examples.html
                    with torch.cuda.amp.autocast(scaler is not None):
                        loss_dict = model(feats, feats_lengths, target,
                                          target_lengths)
                        loss = loss_dict['loss'] / accum_grad
                    if use_amp:
                        scaler.scale(loss).backward()
                    else:
                        loss.backward()

Wenet source code link: https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/executor.py#L67
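
Note that the loss is divided by accum_grad before backward(). The small self-contained check below (not Wenet code; the model and batch sizes are arbitrary) shows that accumulating the scaled micro-batch gradients reproduces the gradient of a single large-batch backward pass, assuming equal micro-batch sizes and a mean-reduction loss:

import torch
import torch.nn as nn

torch.manual_seed(0)
accum_grad = 4
model = nn.Linear(10, 1)
criterion = nn.MSELoss()  # mean reduction
x, y = torch.randn(32, 10), torch.randn(32, 1)

# (a) one backward pass over the full batch
model.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# (b) accum_grad backward passes over equal micro-batches, each loss / accum_grad
model.zero_grad()
for xb, yb in zip(x.chunk(accum_grad), y.chunk(accum_grad)):
    (criterion(model(xb), yb) / accum_grad).backward()

print(torch.allclose(full_grad, model.weight.grad, atol=1e-6))  # True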

  5. Destroy the process group. The Wenet source code does not call PyTorch's destroy_process_group(): once the training processes exit, the process-group global variables and context are released anyway, so skipping this call does not affect training.

How does distributed training affect some of Wenet's hyperparameters?

With multi-machine multi-GPU training (16 GPUs), the dev-set loss converges more slowly than with single-machine multi-GPU training (4 GPUs).

[Figure: dev-set loss curves for the two configurations]

This can be fixed by adjusting the warmup_steps parameter in wenet/examples/aishell/s0/conf/train_conformer.yaml: with more GPUs there are fewer optimizer steps per epoch, so warmup_steps should be scaled down roughly in proportion.

optim_conf:
    lr: 0.002
scheduler: warmuplr     # pytorch v1.1.0+ required
scheduler_conf:
    warmup_steps: 1562
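
For intuition, Wenet's warmuplr scheduler follows the familiar Noam-style warmup, so the learning rate peaks after roughly warmup_steps optimizer steps. The sketch below uses the standard formulation and is not copied from the Wenet source:

def warmup_lr(step, base_lr=0.002, warmup_steps=1562):
    # Noam-style schedule: linear warmup up to base_lr at warmup_steps,
    # then inverse-square-root decay.
    step = max(step, 1)
    return base_lr * warmup_steps ** 0.5 * min(step ** -0.5,
                                               step * warmup_steps ** -1.5)

# The smaller warmup_steps is, the sooner the learning rate peaks, which
# compensates for the reduced number of steps per epoch when using 16 GPUs.
print(warmup_lr(1562))  # ~0.002 (peak learning rate)
print(warmup_lr(156))   # ~10x smaller, still warming up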

How do you adjust the gradient accumulation interval?

Adjust the accum_grad parameter in wenet/examples/aishell/s0/conf/train_conformer.yaml.
