Fixing the shape mismatch (transposed weights) when loading GPT-2 (TensorFlow-pretrained) Linear weights into PyTorch Linear layers

Posted by 惋奈 on 2024-04-17

Error addressed:

RuntimeError: Error(s) in loading state_dict for PyTorchBasedGPT2:

size mismatch for transformer.h.0.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768])......

I. Cause of the Error

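In PyTorch, a Linear layer stores its weight with shape [out_features, in_features], whereas a TensorFlow Linear (Dense) layer stores it as [in_features, out_features].

This is because the two libraries write the same operation differently (see https://www.null123.com/question/detail-2816063.html):

PyTorch: y = Wx + b, with W of shape [out_features, in_features] (for a batch of row vectors, nn.Linear computes y = xW^T + b)

TensorFlow: y = xW + b, with W of shape [in_features, out_features]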
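A minimal PyTorch sketch of the two conventions: `nn.Linear(768, 2304)` stores its weight as [2304, 768], and multiplying by the transposed matrix in TensorFlow's x·W style gives the same output. The sizes simply mirror GPT-2's c_attn; any sizes would do.

```python
import torch

torch.manual_seed(0)

in_features, out_features = 768, 2304  # same sizes as GPT-2's c_attn
linear = torch.nn.Linear(in_features, out_features)

# PyTorch layout: [out_features, in_features]
print(linear.weight.shape)  # torch.Size([2304, 768])

# The TensorFlow-style layout is the transpose: [in_features, out_features]
w_tf = linear.weight.detach().T.contiguous()
print(w_tf.shape)  # torch.Size([768, 2304])

x = torch.randn(4, in_features)         # a batch of 4 row vectors
y_torch = linear(x)                     # nn.Linear computes x @ W.T + b
y_tf_style = x @ w_tf + linear.bias     # TensorFlow convention: x @ W + b
print(torch.allclose(y_torch, y_tf_style, atol=1e-5))  # True
```

The weights hold the same numbers in both layouts; only their orientation differs, which is why a transpose is all the conversion needs.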

Loading GPT-2's pretrained parameters directly into a GPT-2-architecture model implemented in PyTorch produces the following:

PyTorchBasedGPT2.from_pretrained("openai-community/gpt2")
RuntimeError: Error(s) in loading state_dict for PyTorchBasedGPT2:
    size mismatch for transformer.h.0.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.0.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.0.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.1.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.1.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.1.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.2.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.2.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.2.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.3.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.3.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.3.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.4.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.4.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.4.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.5.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.5.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.5.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.6.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.6.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.6.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.7.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.7.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.7.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.8.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.8.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.8.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.9.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.9.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.9.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.10.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.10.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.10.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.11.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.11.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.11.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.
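The same class of error can be reproduced with a single toy layer and no download. In this sketch a plain `nn.Linear(768, 2304)` stands in for one c_attn projection, and the dictionary `checkpoint` is a hypothetical checkpoint whose weight uses the [in_features, out_features] layout.

```python
import torch

# A single Linear layer standing in for one c_attn projection (toy model)
model = torch.nn.Linear(768, 2304)

# Hypothetical checkpoint storing the weight as [in_features, out_features]
checkpoint = {
    "weight": torch.randn(768, 2304),
    "bias": torch.zeros(2304),
}

error_message = ""
try:
    model.load_state_dict(checkpoint)
except RuntimeError as err:
    # The message reports a size mismatch for "weight", as in the log above
    error_message = str(err)
    print(error_message)

# Transposing the weight first makes the same load succeed
fixed = {"weight": checkpoint["weight"].T.contiguous(), "bias": checkpoint["bias"]}
model.load_state_dict(fixed)
```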

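II. Solution

The fix is to transpose the original weights first and then load the model with Model.from_pretrained(). (Do not follow the `ignore_mismatched_sizes=True` hint in the error: it skips the mismatched weights and leaves them randomly initialized.)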
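The core of the fix can also be sketched on a plain state_dict, independent of any model class. The helper `transpose_conv1d_weights` and the suffix list below are hypothetical, chosen to match the GPT-2 key names in the error log; check them against your own model's keys before relying on them.

```python
import torch

# Weight names (suffixes) stored in [in_features, out_features] layout.
# Assumption: these match the Hugging Face GPT-2 key names seen in the error log.
TRANSPOSE_SUFFIXES = (
    "attn.c_attn.weight",
    "attn.c_proj.weight",
    "mlp.c_fc.weight",
    "mlp.c_proj.weight",
)


def transpose_conv1d_weights(state_dict):
    """Return a new dict with the listed weights transposed (hypothetical helper)."""
    converted = {}
    for name, tensor in state_dict.items():
        if name.endswith(TRANSPOSE_SUFFIXES):
            converted[name] = tensor.T.contiguous()
        else:
            converted[name] = tensor  # biases and other tensors pass through unchanged
    return converted


# Usage with fake tensors shaped like one GPT-2 block
sd = {
    "transformer.h.0.attn.c_attn.weight": torch.randn(768, 2304),
    "transformer.h.0.attn.c_attn.bias": torch.randn(2304),
    "transformer.h.0.attn.c_proj.weight": torch.randn(768, 768),
    "transformer.h.0.mlp.c_fc.weight": torch.randn(768, 3072),
    "transformer.h.0.mlp.c_proj.weight": torch.randn(3072, 768),
}
new_sd = transpose_conv1d_weights(sd)
for name, tensor in new_sd.items():
    print(name, tuple(tensor.shape))
```

If a PyTorch model uses the same key names, a dict converted this way could be passed to its `load_state_dict`; the square attn.c_proj weight is transposed as well, even though its shape does not change.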

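1. Pull the model from Hugging Face; model_path is the Hugging Face repo name:

import torch
import transformers
model_path = "openai-community/gpt2"  # Hugging Face repo name
model = transformers.AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16)  # loaded (and later saved) in bfloat16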

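2. Transpose the Linear weight matrices in the original weights

If you are not sure how to reach these matrices, print the model first to inspect its structure:

print(model)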
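To list only the layers that need transposing, a sketch like the following filters the modules by type. It assumes a recent transformers version where `Conv1D` is importable from `transformers.pytorch_utils`, and it builds a tiny randomly initialized GPT-2 from a config (the sizes are arbitrary assumptions) so it runs without downloading weights.

```python
from transformers import GPT2Config, GPT2LMHeadModel
from transformers.pytorch_utils import Conv1D

# Tiny, randomly initialized GPT-2 so the sketch runs offline.
# Assumption: same module layout as openai-community/gpt2, just smaller.
config = GPT2Config(n_layer=2, n_embd=64, n_head=4, vocab_size=100, n_positions=32)
model = GPT2LMHeadModel(config)

conv_shapes = {}
for name, module in model.named_modules():
    if isinstance(module, Conv1D):
        # Conv1D keeps its weight as [in_features, out_features]
        conv_shapes[name] = tuple(module.weight.shape)
        print(name, conv_shapes[name])
# first line printed: transformer.h.0.attn.c_attn (64, 192)
```

Each block contributes four Conv1D modules (attn.c_attn, attn.c_proj, mlp.c_fc, mlp.c_proj), which are exactly the four weights transposed in the next step.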

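Get each weight and transpose it. The layers to transpose are c_attn and c_proj under attn, and c_fc and c_proj under mlp. (The names suggest convolutions, and in transformers these modules are of class Conv1D, but the computation is really a Linear layer whose weight is stored as [in_features, out_features].) Note that attn.c_proj is 768×768, so its shape matches either way and it does not appear in the error message; it still has to be transposed, otherwise it loads silently in the wrong orientation.

for layer in model.transformer.h:
    # .contiguous() returns a new tensor with the same data but a contiguous memory layout
    layer.attn.c_attn.weight = torch.nn.Parameter(layer.attn.c_attn.weight.transpose(0, 1).contiguous())
    layer.attn.c_proj.weight = torch.nn.Parameter(layer.attn.c_proj.weight.transpose(0, 1).contiguous())
    layer.mlp.c_fc.weight = torch.nn.Parameter(layer.mlp.c_fc.weight.transpose(0, 1).contiguous())
    layer.mlp.c_proj.weight = torch.nn.Parameter(layer.mlp.c_proj.weight.transpose(0, 1).contiguous())
# After this loop the Hugging Face model itself no longer computes correctly; it is only used for saving.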
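A quick way to convince yourself the transposition is the right conversion: a `Conv1D` layer (the module type behind c_attn, c_proj and c_fc) and an `nn.Linear` built from its transposed weight produce the same output. This sketch assumes `Conv1D` is importable from `transformers.pytorch_utils` and takes its arguments as `Conv1D(nf, nx)`, i.e. output size first.

```python
import torch
from transformers.pytorch_utils import Conv1D

torch.manual_seed(0)

# Conv1D(nf, nx): nf = out_features, nx = in_features; weight shape is [nx, nf]
conv = Conv1D(2304, 768)
torch.nn.init.normal_(conv.weight, std=0.02)
torch.nn.init.normal_(conv.bias, std=0.02)

# An nn.Linear holding the transposed weight, as the loop above produces
linear = torch.nn.Linear(768, 2304)
with torch.no_grad():
    linear.weight.copy_(conv.weight.T)
    linear.bias.copy_(conv.bias)

x = torch.randn(2, 5, 768)  # [batch, seq_len, hidden]
print(torch.allclose(conv(x), linear(x), atol=1e-5))  # True
```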

3. Finally, save the model to the target path:

output_dir = "new_gpt2"
model.save_pretrained(output_dir)

The PyTorch GPT-2-style model can now load its parameters from that path without errors:

model = PyTorchBasedGPT2.from_pretrained("new_gpt2")
print(model)

print(model) then shows the loaded model structure.
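Matching shapes only show that loading succeeds; to check that the transposed weights also give the same outputs, here is a minimal sketch with a hypothetical `nn.Linear`-based MLP block. `LinearMLP` is illustrative, not the author's PyTorchBasedGPT2, and the checkpoint tensors are random stand-ins in the [in_features, out_features] layout; GPT-2's tanh-approximated GELU is assumed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearMLP(nn.Module):
    """Hypothetical nn.Linear-based GPT-2 MLP block (illustration only)."""

    def __init__(self, n_embd=768):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, 4 * n_embd)
        self.c_proj = nn.Linear(4 * n_embd, n_embd)
        self.act = nn.GELU(approximate="tanh")

    def forward(self, x):
        return self.c_proj(self.act(self.c_fc(x)))


torch.manual_seed(0)
# Fake checkpoint in the [in_features, out_features] layout
ckpt = {
    "c_fc.weight": torch.randn(768, 3072) * 0.02,
    "c_fc.bias": torch.zeros(3072),
    "c_proj.weight": torch.randn(3072, 768) * 0.02,
    "c_proj.bias": torch.zeros(768),
}

# Reference output computed the checkpoint's way: x @ W + b
x = torch.randn(3, 768)
h = F.gelu(x @ ckpt["c_fc.weight"] + ckpt["c_fc.bias"], approximate="tanh")
expected = h @ ckpt["c_proj.weight"] + ckpt["c_proj.bias"]

# Transpose the weights, load into the Linear-based block, compare outputs
converted = {k: (v.T.contiguous() if k.endswith(".weight") else v) for k, v in ckpt.items()}
mlp = LinearMLP()
mlp.load_state_dict(converted)
print(torch.allclose(mlp(x), expected, atol=1e-5))  # True
```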
