【新聞】:機器學習煉丹術的粉絲的人工智慧交流群已經建立,目前有目標檢測、醫學影像、時間序列等多個目標為技術學習的分群和水群嘮嗑的總群,歡迎大家加煉丹兄為好友,加入煉丹協會。微信:cyx645016617.
參考目錄:
本文主要講述TF2.0的模型檔案的儲存和載入的多種方法。主要分成兩型別:模型結構和引數一起載入,模型的結構載入。
1 模型的構建
import tensorflow.keras as keras
class CBR(keras.layers.Layer):
def __init__(self,output_dim):
super(CBR,self).__init__()
self.conv = keras.layers.Conv2D(filters=output_dim, kernel_size=4, padding='same', strides=1)
self.bn = keras.layers.BatchNormalization(axis=3)
self.ReLU = keras.layers.ReLU()
def call(self, inputs):
inputs = self.conv(inputs)
inputs = self.ReLU(self.bn(inputs))
return inputs
class MyNet(keras.Model):
def __init__ (self):
super(MyNet,self).__init__()
self.cbr1 = CBR(16)
self.maxpool1 = keras.layers.MaxPool2D(pool_size=(2,2))
self.cbr2 = CBR(32)
self.maxpool2 = keras.layers.MaxPool2D(pool_size=(2,2))
def call(self, inputs):
inputs = self.maxpool1(self.cbr1(inputs))
inputs = self.maxpool2(self.cbr2(inputs))
return inputs
model = MyNet()
部分朋友可以發現,上面的程式碼就是上一次課程所構建的一個自定義的網路。
我們現在需要展示這個模型的框架:
model.build((16,224,224,3))
print(model.summary())
執行結果為:
這裡需要對網路執行一個構建.build()
函式,之後才能生成model.summary()
這樣的模型的描述。 這是因為模型的引數量是需要知道輸入資料的通道數的,假如我們輸入的是單通道的圖片,那麼就是:
model.build((16,224,224,1))
print(model.summary())
輸出結果為:
2 結構引數的儲存與載入
model.save('save_model.h5')
new_model = keras.models.load_model('save_model.h5')
這裡並不能儲存成功,出現這樣的錯誤:
大概的意思就是:因為你的模型不是官方的模型,是自定義的,所以並不能同時儲存結構和引數。只有官方的模型可以時候上面的儲存的方法,同時儲存引數和權重;自定義的模型建議只儲存引數
3 引數的儲存與載入
model.save_weights('model_weight')
new_model = MyNet()
new_model.load_weights('model_weight')
這樣子就可以儲存自定義的模型了。在對應的目錄下會出現這幾個檔案:
我們來看一下原來的模型和載入的模型對於同一個樣本給出的結果是否相同:
# 看一下原來的模型和載入的模型預測相同的樣本的輸出
test = tf.ones((1,8,8,3))
prediction = model.predict(test)
new_prediction = new_model.predict(test)
print(prediction,new_prediction)
>>> [[[[0.02559286]]]] [[[[0.02559286]]]]
結果相同,載入的沒有問題~
4 結構的儲存與載入
結構的儲存有兩種方法:
model.get_config()
model.to_json()
需要注意的是,上面的兩個方法和save的問題一樣,是不能用在自定義的模型中的,如果你在其中使用了自定義的Layer類,那麼只能!只能用save_weights的方式進行儲存
下面依然給出這兩種方法的程式碼,對於簡單的、已經封裝好的一些網路層構成的網路,是可以使用這些的。我個人還是常用save_weights啦
# 第一種方法
config = model.get_config()
reinitialized_model = keras.Model.from_config(config)
# 第二種方法
json_config = model.to_json()
# 把json寫的檔案中
with open('model_config.json', 'w') as json_file:
json_file.write(json_config)
# 讀取本地json檔案
with open('model_config.json') as json_file:
json_config = json_file.read()
reinitialized_model = keras.models.model_from_json(json_config)
今天的內容就是這麼多,雖然提供了四種方法,但是對於自定義程度較高的模型,還是要使用save_weights哦~