Wenet分散式訓練對學習率調整的影響

ByteHandler發表於2023-01-09

Wenet分散式訓練對學習率調整的影響

背景

Wenet多機多卡分散式訓練時,發現多機多卡(16卡)開發集loss收斂速度遠遠慢於單機多卡(4卡)。

分散式訓練收斂速度和學習率變化的關係

Tensorboard視覺化分佈訓練開發集loss收斂和學習率的變化過程:

訓練學習率相關引數:

conf/train_conformer.yaml

optim: adam
optim_conf:
    lr: 0.002
scheduler: warmuplr     # pytorch v1.1.0+ required
scheduler_conf:
    warmup_steps: 25000
  • 紅色線條代表1機4卡,warmup_steps為25000
  • 紫色線條代表2機16卡,warmup_steps為25000
  • 藍色線條代表2機16卡,warmup_steps為1652

結論:隨著多機多卡分散式訓練卡數量增加,每個Epoch的step數量減少,Warmup學習率的調整變慢,進而導致收斂速度變慢。根據訓練卡的數量調整warmup_steps後,2機16卡與1機4卡的收斂速度接近。

Wenet學習率調整策略分析

Wenet Warmup學習率原始碼分析

wenet/utils/scheduler.py

class WarmupLR(_LRScheduler):
    ...
    def get_lr(self):
        step_num = self.last_epoch + 1
        if self.warmup_steps == 0:
            # 不進行學習率的warmup,學習率根據step增加而衰減,衰減函式是平方根函式的倒數
            return [
                lr * step_num ** -0.5
                for lr in self.base_lrs
            ]
        else:
            # 先進行學習率的warmup(線性增長),再進行學習率的衰減。
            return [
                lr
                * self.warmup_steps ** 0.5
                * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5)
                for lr in self.base_lrs
            ]

    def set_step(self, step: int):
        self.last_epoch = step

Wenet學習率調整公式:

\[ f(step) = \left\{ \begin{array} \\ baseLR \cdot \frac{1}{\sqrt{warmupSteps}} \cdot \frac{step}{warmupSteps^{\frac{3}{2}}} & {step <= warmupSteps}\\ baseLR \cdot \frac{1}{\sqrt{warmupSteps}} \cdot \frac{1}{\sqrt{step}} & {step > warmupSteps}\\ \end{array} \right.\]

  • step小於warmupSteps時,學習率隨著step線性增長,直到baseLR
  • step大於warmupSteps時,學習率隨著step增加而衰減,衰減函式是平方根函式的倒數。

模擬Wenet 預熱學習率調整策略

# 學習率調整函式
def f(step, lr=1e-3, warmup_steps=25000):
    next_lr = lr * warmup_steps ** 0.5
    
    if step < warmup_steps:
        return next_lr * step * warmup_steps ** -1.5
    else:
        return next_lr * step ** -0.5


x = list(range(1, 200000))
y = list(map(f, x))

# 每個Epoch有1000個Steps
epochs = list(map(lambda x: int(x / 1000), x))

fig, ax = plt.subplots()
ax.plot(epochs, y)
ax.set_xlabel("Epoch")
ax.set_ylabel("學習率")
ax.set_title("模擬WarmupLearningRate隨Epoch變化")
plt.show()

參考文獻

相關文章