1. 背景
根據本qiang~最新的趨勢觀察,基於MoE架構的開源大模型越來越多,比如馬斯克的Grok-1(314B), Qwen1.5-MoE-A2.7B等,因此想探究一下MoE裡面的部分細節。
此文是本qiang~針對大語言模型的MoE的整理,包括原理、流程及部分原始碼。
2. MoE原理
MoE的流行源於”歐洲的OpenAI” Mistral AI釋出的論文及模型《Mixtral of Experts》,評測集上的效果吊打眾多開源模型,如Llama 2 70B和GPT3.5。
《Mixtral of Experts》基礎模型使用的是Mistral AI自研的Mistral 7B,該模型的特點包括:滑窗注意力(Sliding Window Aattention), 滾動緩衝區快取(Rolling Buffer Cache)以及預填充-分塊(Pre-fill and Chunking),具體細節可以查閱文末的論文地址。
本文以《Mixtral of Experts》為引子,探究MoE的相關細節,MoE的原理如下圖所示:
圖2.1 MoE的原理
(1) Transformers架構中的每一層中的FFN網路均替換為了8個FFN(專家),且由一個閘道器路由(gate router)進行控制
(2) 針對每一個token,每一層的閘道器路由僅選擇其中的2個FFN(專家)來處理當前狀態並進行加權輸出
(3) 結果就是,每一個token訪問了47B引數,但是在推理階段僅僅使用了13B的啟用引數(即,只使用2個專家,凍結其他6個專家)。
(4) 與Dropout機制對比,Dropout讓部分神經元失活,而MoE是讓部分專家失活。
3. 原始碼
本qiang~研讀並嘗試執行了Mistral官網的github推理程式碼,該程式碼框架非常適合新手,無他,只因其幾乎只是在torch上層做的封裝,很少引擎其他第三方庫,不像transformers,功能強大,但不適合新手研讀程式碼…
為了普適性,下面的程式碼擷取了transformers框架中的程式碼。
首先看下通用Transformers中FFN中的程式碼模組,程式碼位置在transformers.models.mistral.modeling_mistral, 主要流程是:
(1) 先經過gate_proj和up_proj的2個[hidden_size, intermediate_size]的線性轉換
(2) 使用啟用函式對gate_proj進行啟用
(3) 二者的內積再經過down_proj線性轉換。
1 class MistralMLP(nn.Module): 2 def __init__(self, config): 3 super().__init__() 4 self.config = config 5 self.hidden_size = config.hidden_size 6 self.intermediate_size = config.intermediate_size 7 self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 8 self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 9 self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 10 self.act_fn = ACT2FN[config.hidden_act] 11 12 def forward(self, x): 13 return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
再來看下MoE中的專家模組,程式碼位置在transformers.models.mixtral.modeling_mixtral,主要流程是:
(1) 首先經過閘道器路由self.gate
(2) 然後選擇其中2個專家,並歸一化
(3) 之後遍歷每個專家網路,並按照expert_mask進行篩選
(4) 如果expert_mask有值,則選擇指定部分的隱藏層進行FFN操作,且輸出結果進行加權
(5) 最後原地增加先前初始化的最終結果變數final_hidden_states
class MixtralSparseMoeBlock(nn.Module): def __init__(self, config): super().__init__() self.hidden_dim = config.hidden_size self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts self.top_k = config.num_experts_per_tok # gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ """ batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device ) # One hot encode the selected experts to create an expert mask # this will be used to easily index which expert is going to be sollicitated expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) # Loop over all available experts in the model and perform the computation on each expert for expert_idx in range(self.num_experts): expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx]) if top_x.shape[0] == 0: continue # in torch it is faster to index using lists than torch tensors top_x_list = top_x.tolist() idx_list = idx.tolist() # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits
其中MixtralBlockSparseTop2MLP程式碼如下,可以看到和傳統MistralMLP內容完全一致。
class MixtralBlockSparseTop2MLP(nn.Module): def __init__(self, config: MixtralConfig): super().__init__() self.ffn_dim = config.intermediate_size self.hidden_dim = config.hidden_size self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states): current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) current_hidden_states = self.w2(current_hidden_states) return current_hidden_states
4. MoE微調
由於MoE只是將每一層的FFN改變為了每一層的gate閘道器路由+8個FFN專家,且gate閘道器路由和8個專家內部均為線性運算,所以可以無縫地結合LoRA、QLoRA進行指令微調。
可以參考開源專案:https://github.com/yangjianxin1/Firefly
5. 答疑解惑
(1) 問:MoE 8*7B的模型是56B引數?
答:MoE 8*7B的引數量是47B,而不是56B,原因是每一層除了8個專家網路外,其他層均是複用的。
(2) 問:MoE的基礎模型是Mistral 7B?
答:不是,MoE的模型架構與Mistral 7B相同,但其中的FFN替換為了8個FFN,且MoE是基於多語言資料集預訓練而來的。
(3) MoE的稀疏性(sparse)體現在哪裡?
答:在訓練和推理時,同時只有兩個專家網路會被啟用,進行前向計算,其它專家網路處於失活狀態。
6. 總結
一句話足矣~
本文主要針對大語言模型的MoE,包括原理及部分原始碼。
此外,建議大家可以針對原始碼進行執行,關於原始碼,歡迎大家一塊交流。
7. 參考
(1) Mistral 7B:https://arxiv.org/pdf/2310.06825v1.pdf
(2) MoE: https://arxiv.org/pdf/2401.04088v1.pdf
(3) MoE開源指令微調框架Firefly: https://github.com/yangjianxin1/Firefly