【BERT】你儲存的BERT模型為什麼那麼大?
前一段時間有個朋友問我這樣一個問題:google官網給的bert-base模型的ckpt檔案大小隻有400M,為什麼我進行微調-訓練之後,儲存的ckpt模型就是1.19G呢?
我當時的回答是:因為google給的bert-base模型的ckpt檔案僅包含bert的transform每一層的引數,不包含其他引數。而你自己在微調訓練過程中有增加了其他的一些引數,所以會比較大。
現在想一想,感覺自己的回答也有些模稜兩可,也許也有其他同學會有這也的疑問吧。那麼我今天就詳細講解一些,為什麼儲存時會多出那麼多引數?
實踐是檢驗真理的唯一標準。
我們首先透過tf.train.NewCheckpointReader來直接讀取google官網給的bert-base模型的ckpt檔案,這種方法的好處是,我們不需要重新載入model模型,就可以看到儲存的所有節點。
from tensorflow.python import pywrap_tensorflow
ckpt_model_path = "chinese_L-12_H-768_A-12/bert_model.ckpt"
model_ckpt = pywrap_tensorflow.NewCheckpointReader(ckpt_model_path)
var_dict = model_ckpt.get_variable_to_shape_map()
for key in var_dict:
print("bert_parameter:", key)
得到的結果如下(由於引數過多,因此只列出部分引數):
bert_parameter: bert/embeddings/LayerNorm/beta
bert_parameter: bert/embeddings/LayerNorm/gamma
bert_parameter: bert/encoder/layer_9/attention/output/LayerNorm/beta
bert_parameter: bert/encoder/layer_9/attention/output/dense/bias
bert_parameter: bert/encoder/layer_9/attention/output/dense/kernel
bert_parameter: bert/encoder/layer_9/attention/self/key/kernel
bert_parameter: bert/encoder/layer_9/attention/self/query/bias
bert_parameter: bert/encoder/layer_9/attention/self/query/kernel
bert_parameter: bert/encoder/layer_9/intermediate/dense/bias
bert_parameter: bert/encoder/layer_9/intermediate/dense/kernel
bert_parameter: bert/encoder/layer_9/output/LayerNorm/beta
bert_parameter: bert/encoder/layer_9/output/dense/bias
bert_parameter: bert/encoder/layer_9/output/dense/kernel
bert_parameter: bert/pooler/dense/bias
bert_parameter: bert/pooler/dense/kernel
bert_parameter: cls/predictions/transform/LayerNorm/beta
bert_parameter: cls/predictions/transform/LayerNorm/gamma
bert_parameter: cls/predictions/transform/dense/bias
bert_parameter: cls/predictions/transform/dense/kernel
bert_parameter: cls/seq_relationship/output_bias
bert_parameter: cls/seq_relationship/output_weights
我們可以發現,其實google給的bert-base模型的ckpt檔案, 不僅僅是儲存了bert-transform每一層的引數,而且也儲存了embedding和預訓練需要的預測引數(例如:在進行NPS預測時所需的全連線層引數)。因此,我之前給那位朋友的回答是存在偏差的。
接下來,我們透過tf.train.NewCheckpointReader來讀取我們fine-tune之後的模型,看一下都儲存了什麼引數。
from tensorflow.python import pywrap_tensorflow
ckpt_model_path = "my_model\\bert_model.ckpt"
model_ckpt = pywrap_tensorflow.NewCheckpointReader(ckpt_model_path)
var_dict = model_ckpt.get_variable_to_shape_map()
for key in var_dict:
print("bert_parameter:", key)
得到的結果如下(依然只列出部分引數):
bert_parameter: bert/embeddings/LayerNorm/beta
bert_parameter: bert/embeddings/LayerNorm/beta/adam_v
bert_parameter: bert/encoder/layer_9/attention/output/LayerNorm/beta/adam_m
bert_parameter: bert/encoder/layer_9/attention/output/LayerNorm/beta/adam_v
bert_parameter: bert/encoder/layer_9/attention/output/dense/bias/adam_m
bert_parameter: bert/encoder/layer_9/attention/output/dense/bias/adam_v
bert_parameter: bert/encoder/layer_9/attention/output/dense/kernel
bert_parameter: bert/encoder/layer_9/attention/output/dense/kernel/adam_m
bert_parameter: bert/encoder/layer_9/attention/output/dense/kernel/adam_v
bert_parameter: bert/encoder/layer_9/attention/self/key/kernel
bert_parameter: bert/encoder/layer_9/attention/self/key/kernel/adam_m
bert_parameter: bert/encoder/layer_9/attention/self/key/kernel/adam_v
bert_parameter: bert/encoder/layer_9/attention/self/query/kernel/adam_m
bert_parameter: bert/encoder/layer_9/attention/self/query/kernel/adam_v
bert_parameter: bert/encoder/layer_9/attention/self/value/bias
bert_parameter: bert/encoder/layer_9/attention/self/value/bias/adam_m
bert_parameter: bert/encoder/layer_9/attention/self/value/bias/adam_v
bert_parameter: bert/encoder/layer_9/attention/self/value/kernel
bert_parameter: bert/encoder/layer_9/attention/self/value/kernel/adam_m
bert_parameter: bert/encoder/layer_9/attention/self/value/kernel/adam_v
bert_parameter: bert/encoder/layer_9/intermediate/dense/bias
bert_parameter: bert/encoder/layer_9/intermediate/dense/bias/adam_m
bert_parameter: bert/encoder/layer_9/intermediate/dense/bias/adam_v
bert_parameter: bert/encoder/layer_9/output/LayerNorm/beta/adam_v
bert_parameter: bert/encoder/layer_9/output/dense/bias
bert_parameter: bert/encoder/layer_9/output/dense/kernel/adam_m
bert_parameter: bert/encoder/layer_9/output/dense/kernel/adam_v
bert_parameter: bert/pooler/dense/bias
bert_parameter: bert/pooler/dense/bias/adam_m
bert_parameter: bert/pooler/dense/bias/adam_v
bert_parameter: bert/pooler/dense/kernel
bert_parameter: bert/pooler/dense/kernel/adam_m
bert_parameter: bert/pooler/dense/kernel/adam_v
看到這個結果時,我相信大家應該都會恍然大悟, 其實我們在做微調的時候,並沒有新增多少引數變數。導致我們儲存的ckpt檔案到達1.19G的原因,其實是多儲存了每個變數的adam_m和adam_v。這樣算來,一個變數變成了3個變數,正好是從400M到1.19G。
接下來,應該會有同學問:adam_m和adam_v是什麼呢,為什麼會儲存這些引數?
答:在模型進行訓練最佳化(誤差傳遞)時,我們通常使用Adam最佳化器進行最佳化。在最佳化的過程中,我們通常需要維護(儲存下之前時刻的滑動平均值)每個引數的一階矩(對應adam_m)和二階矩(對應adam_v)來保證梯度的順利更新。
簡單地來講,就是在模型訓練的過程中,每個引數都需要額外變數引數儲存一些資訊,用於誤差的傳遞以及梯度的更新。而這些額外的變數引數,在訓練停止之後,其實也就失去了它的作;並且在模型預測階段或者呼叫該模型為其他模型進行引數初始化階段,都是不需要使用這些引數變數的。但是一般我們在儲存模型時,都會預設將所有引數變數都進行儲存,所有才會導致我們儲存的bert模型有1.19G。
為了使我們儲存的模型變小,減緩我們硬碟的壓力,我們可以在儲存模型時,進行如下操作:
tf.train.Saver(tf.trainable_variables()).save(sess, save_model_path)
這樣儲存的模型,就是隻包含訓練引數,而額外的儲存引數是不會進行儲存的。 僅修改一行程式碼,就可以減輕硬碟2/3的壓力,何樂而不為呢!
下面是我們平時使用儲存模型的程式碼,這樣儲存是將所有引數都儲存下來。
tf.train.Saver().save(sess, save_model_path)
等於
tf.train.Saver(tf.all_variables()).save(sess, save_model_path)
實踐是檢驗真理的唯一標準。有時,你認為的僅僅是你認為的;你做出來的,才是真的。
https://mp.weixin.qq.com/s/oNMHEPiH7StcgcJOqzwShQ
來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69946223/viewspace-2679782/,如需轉載,請註明出處,否則將追究法律責任。
相關文章
- 為什麼說 Bert 大力出奇跡?
- BERT微調進行命名實體識別並將模型儲存為pb形式模型
- 8.3 BERT模型介紹模型
- BERT 模型壓縮方法模型
- 【BERT】詳解BERT
- XLNet 第一作者楊植麟:為什麼預處理模型XLNet比BERT、RoBERTa更加優越模型
- PyTorch預訓練Bert模型PyTorch模型
- Bert下載和使用(以bert-base-uncased為例)
- 什麼是YottaChain儲存,為什麼說是未來資料儲存的趨勢?AI
- 為什麼你存不下錢?
- 為什麼Kubernetes的儲存如此艱難?
- XLM — 基於BERT的跨語言模型模型
- 從字到詞,大詞典中文BERT模型的探索之旅模型
- 張俊林:BERT和Transformer到底學到了什麼 | AI ProCon 2019ORMAI
- Redis為什麼那麼快?Redis
- NER為什麼那麼難
- 哪有那麼多為什麼?
- NLP與深度學習(六)BERT模型的使用深度學習模型
- 塊儲存是做什麼用的,你知道嗎?
- 物件儲存的優勢有哪些?為什麼要選擇物件儲存?物件
- 效能媲美BERT,引數量僅為1/300,谷歌最新的NLP模型谷歌模型
- 為什麼 python 那麼熱門Python
- Kafka為什麼速度那麼快?Kafka
- 什麼是物件儲存?物件
- 資料庫mysql儲存是什麼?可以存什麼?資料庫MySql
- transformers(1) 、bertORM
- 程式設計為什麼那麼難:從儲值卡扣款說起程式設計
- 真實案例:使用LLM大模型及BERT模型實現合同審查系統大模型
- 為什麼不用資料庫儲存圖片?資料庫
- 你構建的程式碼為什麼這麼大
- 開啟NLP新時代的BERT模型,是怎麼一步步封神的?模型
- 為什麼京東上的商品評論視訊不能直接儲存?怎麼樣可以快速儲存
- 替代 VMware ,為什麼需要重新考慮您的儲存?
- 我的BERT!改改字典,讓BERT安全提速不掉分(已開源)
- 推薦那麼準,除了模型,還有什麼。。。模型
- 深入淺出騰訊BERT推理模型--TurboTransformers模型ORM
- 為什麼遊戲DLC的精品那麼少?遊戲
- 遊戲的留存為什麼那麼難調?遊戲