【BERT】你儲存的BERT模型為什麼那麼大?

AIBigbull2050發表於2020-03-13


前一段時間有個朋友問我這樣一個問題:google官網給的bert-base模型的ckpt檔案大小隻有400M,為什麼我進行微調-訓練之後,儲存的ckpt模型就是1.19G呢?

我當時的回答是:因為google給的bert-base模型的ckpt檔案僅包含bert的transform每一層的引數,不包含其他引數。而你自己在微調訓練過程中有增加了其他的一些引數,所以會比較大。

現在想一想,感覺自己的回答也有些模稜兩可,也許也有其他同學會有這也的疑問吧。那麼我今天就詳細講解一些,為什麼儲存時會多出那麼多引數?

實踐是檢驗真理的唯一標準。

我們首先透過tf.train.NewCheckpointReader來直接讀取google官網給的bert-base模型的ckpt檔案,這種方法的好處是,我們不需要重新載入model模型,就可以看到儲存的所有節點。


from tensorflow.python import pywrap_tensorflow


ckpt_model_path = "chinese_L-12_H-768_A-12/bert_model.ckpt"
model_ckpt = pywrap_tensorflow.NewCheckpointReader(ckpt_model_path)
var_dict = model_ckpt.get_variable_to_shape_map()
for key in var_dict:
   print("bert_parameter:", key)

得到的結果如下(由於引數過多,因此只列出部分引數):


bert_parameter: bert/embeddings/LayerNorm/beta

bert_parameter: bert/embeddings/LayerNorm/gamma
bert_parameter: bert/encoder/layer_9/attention/output/LayerNorm/beta
bert_parameter: bert/encoder/layer_9/attention/output/dense/bias
bert_parameter: bert/encoder/layer_9/attention/output/dense/kernel
bert_parameter: bert/encoder/layer_9/attention/self/key/kernel
bert_parameter: bert/encoder/layer_9/attention/self/query/bias
bert_parameter: bert/encoder/layer_9/attention/self/query/kernel
bert_parameter: bert/encoder/layer_9/intermediate/dense/bias
bert_parameter: bert/encoder/layer_9/intermediate/dense/kernel
bert_parameter: bert/encoder/layer_9/output/LayerNorm/beta
bert_parameter: bert/encoder/layer_9/output/dense/bias
bert_parameter: bert/encoder/layer_9/output/dense/kernel
bert_parameter: bert/pooler/dense/bias
bert_parameter: bert/pooler/dense/kernel
bert_parameter: cls/predictions/transform/LayerNorm/beta
bert_parameter: cls/predictions/transform/LayerNorm/gamma
bert_parameter: cls/predictions/transform/dense/bias
bert_parameter: cls/predictions/transform/dense/kernel
bert_parameter: cls/seq_relationship/output_bias
bert_parameter: cls/seq_relationship/output_weights

我們可以發現,其實google給的bert-base模型的ckpt檔案, 不僅僅是儲存了bert-transform每一層的引數,而且也儲存了embedding和預訓練需要的預測引數(例如:在進行NPS預測時所需的全連線層引數)。因此,我之前給那位朋友的回答是存在偏差的。

接下來,我們透過tf.train.NewCheckpointReader來讀取我們fine-tune之後的模型,看一下都儲存了什麼引數。


from tensorflow.python import pywrap_tensorflow


ckpt_model_path = "my_model\\bert_model.ckpt"
model_ckpt = pywrap_tensorflow.NewCheckpointReader(ckpt_model_path)
var_dict = model_ckpt.get_variable_to_shape_map()
for key in var_dict:
   print("bert_parameter:", key)

得到的結果如下(依然只列出部分引數):


bert_parameter: bert/embeddings/LayerNorm/beta

bert_parameter: bert/embeddings/LayerNorm/beta/adam_v
bert_parameter: bert/encoder/layer_9/attention/output/LayerNorm/beta/adam_m
bert_parameter: bert/encoder/layer_9/attention/output/LayerNorm/beta/adam_v
bert_parameter: bert/encoder/layer_9/attention/output/dense/bias/adam_m
bert_parameter: bert/encoder/layer_9/attention/output/dense/bias/adam_v
bert_parameter: bert/encoder/layer_9/attention/output/dense/kernel
bert_parameter: bert/encoder/layer_9/attention/output/dense/kernel/adam_m
bert_parameter: bert/encoder/layer_9/attention/output/dense/kernel/adam_v
bert_parameter: bert/encoder/layer_9/attention/self/key/kernel
bert_parameter: bert/encoder/layer_9/attention/self/key/kernel/adam_m
bert_parameter: bert/encoder/layer_9/attention/self/key/kernel/adam_v
bert_parameter: bert/encoder/layer_9/attention/self/query/kernel/adam_m
bert_parameter: bert/encoder/layer_9/attention/self/query/kernel/adam_v
bert_parameter: bert/encoder/layer_9/attention/self/value/bias
bert_parameter: bert/encoder/layer_9/attention/self/value/bias/adam_m
bert_parameter: bert/encoder/layer_9/attention/self/value/bias/adam_v
bert_parameter: bert/encoder/layer_9/attention/self/value/kernel
bert_parameter: bert/encoder/layer_9/attention/self/value/kernel/adam_m
bert_parameter: bert/encoder/layer_9/attention/self/value/kernel/adam_v
bert_parameter: bert/encoder/layer_9/intermediate/dense/bias
bert_parameter: bert/encoder/layer_9/intermediate/dense/bias/adam_m
bert_parameter: bert/encoder/layer_9/intermediate/dense/bias/adam_v
bert_parameter: bert/encoder/layer_9/output/LayerNorm/beta/adam_v
bert_parameter: bert/encoder/layer_9/output/dense/bias
bert_parameter: bert/encoder/layer_9/output/dense/kernel/adam_m
bert_parameter: bert/encoder/layer_9/output/dense/kernel/adam_v
bert_parameter: bert/pooler/dense/bias
bert_parameter: bert/pooler/dense/bias/adam_m
bert_parameter: bert/pooler/dense/bias/adam_v
bert_parameter: bert/pooler/dense/kernel
bert_parameter: bert/pooler/dense/kernel/adam_m
bert_parameter: bert/pooler/dense/kernel/adam_v

看到這個結果時,我相信大家應該都會恍然大悟, 其實我們在做微調的時候,並沒有新增多少引數變數。導致我們儲存的ckpt檔案到達1.19G的原因,其實是多儲存了每個變數的adam_m和adam_v。這樣算來,一個變數變成了3個變數,正好是從400M到1.19G。

接下來,應該會有同學問:adam_m和adam_v是什麼呢,為什麼會儲存這些引數?

答:在模型進行訓練最佳化(誤差傳遞)時,我們通常使用Adam最佳化器進行最佳化。在最佳化的過程中,我們通常需要維護(儲存下之前時刻的滑動平均值)每個引數的一階矩(對應adam_m)和二階矩(對應adam_v)來保證梯度的順利更新。

簡單地來講,就是在模型訓練的過程中,每個引數都需要額外變數引數儲存一些資訊,用於誤差的傳遞以及梯度的更新。而這些額外的變數引數,在訓練停止之後,其實也就失去了它的作;並且在模型預測階段或者呼叫該模型為其他模型進行引數初始化階段,都是不需要使用這些引數變數的。但是一般我們在儲存模型時,都會預設將所有引數變數都進行儲存,所有才會導致我們儲存的bert模型有1.19G。

為了使我們儲存的模型變小,減緩我們硬碟的壓力,我們可以在儲存模型時,進行如下操作:

tf.train.Saver(tf.trainable_variables()).save(sess, save_model_path)

這樣儲存的模型,就是隻包含訓練引數,而額外的儲存引數是不會進行儲存的。 僅修改一行程式碼,就可以減輕硬碟2/3的壓力,何樂而不為呢!

下面是我們平時使用儲存模型的程式碼,這樣儲存是將所有引數都儲存下來。


tf.train.Saver().save(sess, save_model_path)

等於
tf.train.Saver(tf.all_variables()).save(sess, save_model_path)

實踐是檢驗真理的唯一標準。有時,你認為的僅僅是你認為的;你做出來的,才是真的。


https://mp.weixin.qq.com/s/oNMHEPiH7StcgcJOqzwShQ




來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69946223/viewspace-2679782/,如需轉載,請註明出處,否則將追究法律責任。

相關文章