獲取和生成基於TensorFlow的MobilNet預訓練模型
從TensorFlow-Slim image classification model library中找到我們要使用的MobileNet_v2_1.0_224^*(1.0-阿爾法值,244-輸入圖片shape),下載mobilenet_v2_1.0_224.tgz檔案,其中內含:
- mobilenet_v2_1.0_224.ckpt.data-00000-of-00001
- mobilenet_v2_1.0_224.ckpt.index
- mobilenet_v2_1.0_224.ckpt.meta
- mobilenet_v2_1.0_224.tflite
- mobilenet_v2_1.0_224_eval.pbtxt
- mobilenet_v2_1.0_224_frozen.pb
- mobilenet_v2_1.0_224_info.txt
mobilenet_v2_1.0_224.ckpt.data-00000-of-00001儲存模型權重,mobilenet_v2_1.0_224.ckpt.meta儲存模型圖的流程,mobilenet_v2_1.0_224.ckpt.index儲存模型結構的變數與引數間的索引對應關係,mobilenet_v2_1.0_224_frozen.pb凍結的PB檔案,mobilenet_v2_1.0_224.tflite凍結的tflite檔案。
在Test6_mobilenet下新建目錄pretain_model,將以下三個檔案放入其中
- mobilenet_v2_1.0_224.ckpt.data-00000-of-00001
- mobilenet_v2_1.0_224.ckpt.meta
- mobilenet_v2_1.0_224.ckpt.index
生成預訓練模型程式碼,執行後在當前路徑下生成兩個檔案:
- pretrain_weights.ckpt.data-00000-of-00001
- pretrain_weights.ckpt.index
import tensorflow as tf
def rename_var(ckpt_path, new_ckpt_path, num_classes=5):
with tf.Graph().as_default(), tf.compat.v1.Session().as_default() as sess:
var_list = tf.train.list_variables(ckpt_path)
new_var_list = []
for var_name, shape in var_list:
# print(var_name)
#filter 不需要的層結構
if var_name in except_list:
continue
#filter所有有關優化器info
if "RMSProp" in var_name or "Exponential" in var_name:
continue
#層結構的名稱轉換
var = tf.train.load_variable(ckpt_path, var_name)
new_var_name = var_name.replace('MobilenetV2/', "")
new_var_name = new_var_name.replace("/expand/weights", "/expand/Conv2d/weights")
new_var_name = new_var_name.replace("Conv/weights", "Conv/Conv2d/kernel")
new_var_name = new_var_name.replace("Conv_1/weights", "Conv_1/Conv2d/kernel")
new_var_name = new_var_name.replace("weights", "kernel")
new_var_name = new_var_name.replace("biases", "bias")
first_word = new_var_name.split('/')[0]
if "expanded_conv" in first_word:
last_word = first_word.split('expanded_conv')[-1]
if len(last_word) > 0:
new_word = "inverted_residual" + last_word + "/expanded_conv/"
else:
new_word = "inverted_residual/expanded_conv/"
new_var_name = new_word + new_var_name.split('/', maxsplit=1)[-1]
print(new_var_name)
re_var = tf.Variable(var, name=new_var_name)
new_var_list.append(re_var)
re_var = tf.Variable(tf.keras.initializers.he_uniform()([1280, num_classes]), name="Logits/kernel")
new_var_list.append(re_var)
re_var = tf.Variable(tf.keras.initializers.he_uniform()([num_classes]), name="Logits/bias")
new_var_list.append(re_var)
tf.keras.initializers.he_uniform()
saver = tf.compat.v1.train.Saver(new_var_list)
sess.run(tf.compat.v1.global_variables_initializer())
saver.save(sess, save_path=new_ckpt_path, write_meta_graph=False, write_state=False)
# 不需要的層結構
# 'MobilenetV2/Logits/Conv2d_1c_1x1/biases', 'MobilenetV2/Logits/Conv2d_1c_1x1/weights'
# MobilenetV2的全連線層偏置和權重
except_list = ['global_step', 'MobilenetV2/Logits/Conv2d_1c_1x1/biases', 'MobilenetV2/Logits/Conv2d_1c_1x1/weights']
ckpt_path = './pretain_model/mobilenet_v2_1.0_224.ckpt'
new_ckpt_path = './pretrain_weights.ckpt'
num_classes = 5
rename_var(ckpt_path, new_ckpt_path, num_classes)
相關文章
- TensorFlow 呼叫預訓練好的模型—— Python 實現模型Python
- 使用PaddleFluid和TensorFlow訓練序列標註模型UI模型
- 基於Mindspore2.0的GPT2預訓練模型遷移教程GPT模型
- 飛槳帶你瞭解:基於百科類資料訓練的 ELMo 中文預訓練模型模型
- 使用LSTM模型做股票預測【基於Tensorflow】模型
- TensorFlow2.0教程-使用keras訓練模型Keras模型
- 如何將keras訓練的模型轉換成tensorflow lite模型Keras模型
- PyTorch預訓練Bert模型PyTorch模型
- 人工智慧的預訓練基礎模型的分類人工智慧模型
- 模型訓練:資料預處理和預載入模型
- 自訓練 + 預訓練 = 更好的自然語言理解模型模型
- 火山引擎釋出大模型訓練影片預處理方案,已應用於豆包影片生成模型大模型
- 預約直播 | 基於預訓練模型的自然語言處理及EasyNLP演算法框架模型自然語言處理演算法框架
- 預訓練模型 & Fine-tuning模型
- 【AI】Pytorch_預訓練模型AIPyTorch模型
- 【預訓練語言模型】 使用Transformers庫進行BERT預訓練模型ORM
- MxNet預訓練模型到Pytorch模型的轉換模型PyTorch
- 新型大語言模型的預訓練與後訓練正規化,蘋果的AFM基礎語言模型模型蘋果
- ML2021 | (騰訊)PatrickStar:通過基於塊的記憶體管理實現預訓練模型的並行訓練記憶體模型並行
- 基於飛槳PaddlePaddle的多種影像分類預訓練模型強勢釋出模型
- 基於 Fluid+JindoCache 加速大模型訓練的實踐UI大模型
- 在 C/C++ 中使用 TensorFlow 預訓練好的模型—— 間接呼叫 Python 實現C++模型Python
- 【預訓練語言模型】使用Transformers庫進行GPT2預訓練模型ORMGPT
- 如何評估一個回答的好壞——BERTScore 基於預訓練模型的相似度度量方式模型
- 基於Python和TensorFlow實現BERT模型應用Python模型
- 在 C/C++ 中使用 TensorFlow 預訓練好的模型—— 直接呼叫 C++ 介面實現C++模型
- NLP生成任務超越BERT、GPT!微軟提出通用預訓練模型MASSGPT微軟模型
- tensorflow:一個簡單的python訓練儲存模型,java還原模型方法Python模型Java
- 知識增強的預訓練語言模型系列之ERNIE:如何為預訓練語言模型注入知識模型
- TorchVision 預訓練模型進行推斷模型
- keras中VGG19預訓練模型的使用Keras模型
- 基於Theano的深度學習框架keras及配合SVM訓練模型深度學習框架Keras模型
- 使用Tensorflow Object Detection進行訓練和推理Object
- 新型大語言模型的預訓練與後訓練正規化,谷歌的Gemma 2語言模型模型谷歌Gemma
- 新型大語言模型的預訓練與後訓練正規化,Meta的Llama 3.1語言模型模型
- 新型大語言模型的預訓練與後訓練正規化,阿里Qwen模型阿里
- 生成式預訓練語言模型能否視作閉卷問答的知識庫?模型
- 《深度學習案例精粹:基於TensorFlow與Keras》案例集用於深度學習訓練深度學習Keras