Error being resolved:
RuntimeError: Error(s) in loading state_dict for PyTorchBasedGPT2:
size mismatch for transformer.h.0.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768])......
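I. Cause of the error
In PyTorch, a Linear layer stores its weight with shape [out_features, in_features]. TensorFlow-style dense layers store the weight with shape [in_features, out_features], and the published GPT-2 checkpoint keeps that layout.
The difference comes from the two libraries writing the same operation in different forms (see https://www.null123.com/question/detail-2816063.html):
PyTorch: y = Wx + b
TensorFlow: y = xW + b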
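To make the two layouts concrete, here is a minimal sketch that builds a small torch.nn.Linear, prints its weight shape, and checks that a weight stored as [in_features, out_features] gives the same output once it is transposed. The sizes 4 and 6 and all variable names are arbitrary choices for illustration, not values from GPT-2.

```python
import torch

in_features, out_features = 4, 6  # illustrative sizes only

# nn.Linear stores its weight as [out_features, in_features].
linear = torch.nn.Linear(in_features, out_features)
weight_shape = tuple(linear.weight.shape)
print(weight_shape)

# A checkpoint-style weight stored as [in_features, out_features].
w_ckpt = torch.randn(in_features, out_features)
b = torch.randn(out_features)
x = torch.randn(2, in_features)

# "y = xW + b" with the [in, out] layout.
y_ckpt = x @ w_ckpt + b

# Copy the transposed weight into nn.Linear and compare outputs.
with torch.no_grad():
    linear.weight.copy_(w_ckpt.t())
    linear.bias.copy_(b)
    y_linear = linear(x)

same = torch.allclose(y_ckpt, y_linear, atol=1e-5)
print(same)
```

Copying w_ckpt without the transpose would fail here with a shape error, which is the same kind of mismatch reported for the GPT-2 weights.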
When a GPT-2 architecture implemented directly in PyTorch tries to load the pretrained GPT-2 parameters as they are, the following happens:
    PyTorchBasedGPT2.from_pretrained("openai-community/gpt2")
    RuntimeError: Error(s) in loading state_dict for PyTorchBasedGPT2:
        size mismatch for transformer.h.0.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
        size mismatch for transformer.h.0.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
        size mismatch for transformer.h.0.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
        size mismatch for transformer.h.1.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
        size mismatch for transformer.h.1.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
        size mismatch for transformer.h.1.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
        size mismatch for transformer.h.2.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
        size mismatch for transformer.h.2.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
        size mismatch for transformer.h.2.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
        size mismatch for transformer.h.3.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
        size mismatch for transformer.h.3.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
        size mismatch for transformer.h.3.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
        size mismatch for transformer.h.4.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
        size mismatch for transformer.h.4.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
        size mismatch for transformer.h.4.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
        size mismatch for transformer.h.5.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
        size mismatch for transformer.h.5.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
        size mismatch for transformer.h.5.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
        size mismatch for transformer.h.6.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
        size mismatch for transformer.h.6.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
        size mismatch for transformer.h.6.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
        size mismatch for transformer.h.7.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
        size mismatch for transformer.h.7.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
        size mismatch for transformer.h.7.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
        size mismatch for transformer.h.8.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
        size mismatch for transformer.h.8.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
        size mismatch for transformer.h.8.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
        size mismatch for transformer.h.9.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
        size mismatch for transformer.h.9.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
        size mismatch for transformer.h.9.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
        size mismatch for transformer.h.10.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
        size mismatch for transformer.h.10.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
        size mismatch for transformer.h.10.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
        size mismatch for transformer.h.11.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
        size mismatch for transformer.h.11.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
        size mismatch for transformer.h.11.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
        You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.
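Before changing any weights, it can help to confirm that every reported mismatch is a pure transpose, meaning the checkpoint shape is the model shape reversed. The sketch below parses log lines in the format shown above using only the standard library. The function names and the two-line sample are illustrative assumptions, not part of PyTorch or transformers.

```python
import re

# Matches one "size mismatch" entry in the format printed by load_state_dict.
PATTERN = re.compile(
    r"size mismatch for (?P<name>[\w.]+): copying a param with shape "
    r"torch\.Size\(\[(?P<ckpt>[\d, ]+)\]\) from checkpoint, the shape in "
    r"current model is torch\.Size\(\[(?P<model>[\d, ]+)\]\)"
)

def parse_mismatches(log):
    """Return (name, checkpoint_shape, model_shape) for each mismatch entry."""
    rows = []
    for m in PATTERN.finditer(log):
        ckpt = tuple(int(d) for d in m["ckpt"].split(","))
        model = tuple(int(d) for d in m["model"].split(","))
        rows.append((m["name"], ckpt, model))
    return rows

def is_pure_transpose(rows):
    """True if every checkpoint shape is the reversed model shape."""
    return all(ckpt == model[::-1] for _, ckpt, model in rows)

# Two entries copied from the log above, used as a small sample.
sample_log = (
    "size mismatch for transformer.h.0.attn.c_attn.weight: copying a param with shape "
    "torch.Size([768, 2304]) from checkpoint, the shape in current model is "
    "torch.Size([2304, 768]).\n"
    "size mismatch for transformer.h.0.mlp.c_proj.weight: copying a param with shape "
    "torch.Size([3072, 768]) from checkpoint, the shape in current model is "
    "torch.Size([768, 3072]).\n"
)

rows = parse_mismatches(sample_log)
for name, ckpt, model in rows:
    print(name, ckpt, "->", model)
print(is_pure_transpose(rows))
```

If the check returns False for some entry, the problem is not only a layout difference, and transposing alone would not fix it.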
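II. Solution
The original weights need to be transposed and saved first; after that, the model can be loaded with Model.from_pretrained().
1. Pull the model from Hugging Face; model_path is the Hugging Face repo name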
    import torch
    import transformers
    model_path = "openai-community/gpt2"
    model = transformers.AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16)
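2. Transpose the Linear weight matrices in the original weights
If you are not sure how to reach the matrices, print the model first and inspect it:
    print(model)
Fetch the weights and transpose them. The ones to transpose are c_attn and c_proj in attn, and c_fc and c_proj in mlp. (These layers look like convolutions, but the code implements them as plain Linear layers.) Note that attn.c_proj never shows up in the error log: its weight is square (768 x 768), so its shape matches either way, but it still has to be transposed or it will load silently with the wrong values.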
    for layer in model.transformer.h:
        # .contiguous() returns a new tensor with the same data but a contiguous memory layout
        layer.attn.c_attn.weight = torch.nn.Parameter(layer.attn.c_attn.weight.transpose(0, 1).contiguous())
        layer.attn.c_proj.weight = torch.nn.Parameter(layer.attn.c_proj.weight.transpose(0, 1).contiguous())
        layer.mlp.c_fc.weight = torch.nn.Parameter(layer.mlp.c_fc.weight.transpose(0, 1).contiguous())
        layer.mlp.c_proj.weight = torch.nn.Parameter(layer.mlp.c_proj.weight.transpose(0, 1).contiguous())
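The same transposition can also be written against a plain state dict, matching parameters by name suffix instead of walking the module tree. This is only a sketch under assumptions: the helper name transpose_linear_weights and the toy tensors are made up for illustration, and the suffix list simply mirrors the four weights named above.

```python
import torch

# Suffixes of the four per-block weights that need transposing.
TRANSPOSED_SUFFIXES = (
    "attn.c_attn.weight",
    "attn.c_proj.weight",
    "mlp.c_fc.weight",
    "mlp.c_proj.weight",
)

def transpose_linear_weights(state_dict):
    """Return a new dict with the listed 2-D weights transposed to [out, in]."""
    converted = {}
    for name, tensor in state_dict.items():
        if name.endswith(TRANSPOSED_SUFFIXES) and tensor.dim() == 2:
            # .contiguous() materializes the transposed layout in memory.
            tensor = tensor.t().contiguous()
        converted[name] = tensor
    return converted

# Toy tensors standing in for one block; the sizes are illustrative only.
toy = {
    "transformer.h.0.attn.c_attn.weight": torch.zeros(8, 24),
    "transformer.h.0.attn.c_attn.bias": torch.zeros(24),
    "transformer.h.0.attn.c_proj.weight": torch.arange(64.0).reshape(8, 8),
    "transformer.h.0.ln_1.weight": torch.ones(8),
}

converted = transpose_linear_weights(toy)
for name, tensor in converted.items():
    print(name, tuple(tensor.shape))
```

Biases and layer-norm weights pass through unchanged because their names do not end with any of the listed suffixes. The square attn.c_proj weight keeps its shape, but its values are transposed.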
3. Finally, save the model to a target path
    output_dir = "new_gpt2"
    model.save_pretrained(output_dir)
The GPT-2-like model implemented in PyTorch can now load its parameters from that path without errors:
    model = PyTorchBasedGPT2.from_pretrained("new_gpt2")
    print(model)
This gives the model: