tensorflow2.0 自定義類模組列印問題
明明tensorflow1x就沒有學好,現在換了2,那就再從頭學起叭
在自定義類模組列印遇到了各種bug,在此記錄一下,也希望可以幫助到有需要的夥伴。
如果你自定義類程式碼寫成如下格式
import tensorflow as tf
from tensorflow.keras.models import Model
class SAEModel(Model):
def __init__(self, input_shape, output_shape, hidden_shape=None):
# print("init")
# 隱藏層節點個數預設為輸入層的3倍
if hidden_shape == None:
hidden_shape = 3 * input_shape
# 呼叫父類__init__()方法
super(SAEModel, self).__init__()
# 初始化模型使用的layer,layer_1為前述自定義layer
self.layer_1 = SAELayer(hidden_shape)
# layer_2為全連線層,採用sigmoid啟用函式
# 每層在這裡可以不考慮輸入元素個數,但必須考慮輸出元素個數
# 輸入元素個數可以在call()函式中動態確定
self.layer_2 = layers.Dense(output_shape, activation=tf.nn.sigmoid)
def call(self, input_tensor, training=False):
# 輸入資料
hidden = self.layer_1(input_tensor)
output = self.layer_2(hidden)
return output
那麼會報錯如下,提示你需要使用model.build(input_shape=)方法。但是!我這個input_shape引數我輸入總是出現各種錯誤。
有時列印出來的模型框架也不理想,如下圖
在多方搜尋後,看到了某位大神的程式碼,然後修改自己程式碼,完美輸出。
https://blog.csdn.net/qq_40642546/article/details/106622996
class MyLstm(tf.keras.Model):
def __init__(self, voca_len, out_len, max_len, batch_size, embedding_dim, unit_num, embed_matrix = None):
#==============================================
super().__init__()
# =============================================
self.voca_len = voca_len
self.out_len = out_len
self.max_len = max_len
self.batch_size = batch_size
self.embedding_dim = embedding_dim
self.unit_num = unit_num
self.embed_matrix = None
self.input_layer = Input(max_len,
name = 'inputs')
if self.embed_matrix == None:
self.embedding_layer = Embedding(input_dim = self.voca_len,
output_dim = self.embedding_dim,
input_length = self.max_len,
trainable = True,
name = 'embedding')
else:
self.embedding_layer = Embedding(input_dim = self.voca_len,
output_dim = self.embedding_dim,
weights = [self.embed_matrix],
input_length = self.max_len,
name = 'embedding')
self.lstm_layer = LSTM(units=self.unit_num,
activation = 'relu',
name = 'LSTM')
self.out_layer = Dense(units=self.out_len,
name = 'Train_out')
self.out = self.call(self.input_layer) # !!!!!
# ================================================
super().__init__(inputs=self.input_layer,
outputs=self.out)
#=================================================
def build(self):
self.is_graph_network = True
self.__init__graph_network(inputs=self.input_layer,
outputs = self.out)
def call(self, x, from_logits=False, training=True, mask=None):
embed_matrix = self.embedding_layer(x)
# print(embed_matrix)
lstm_output= self.lstm_layer(embed_matrix)
logits = self.out_layer(lstm_output)
if from_logits:
return logits
else:
return tf.nn.sigmoid(logits)
model = MyLstm(voca_len=394, out_len=504, max_len=29, embedding_dim=23, unit_num=100, batch_size=2)
model.summary()
關鍵的地方在上面的程式碼中做了註釋,兩次初始化很重要,至於為啥這麼寫。。。
emmmmm 我先會用叭
相關文章
- struts 自定義validate 問題
- ansible自定義模組
- laravel自定義命令列印進度條Laravel命令列
- 中文自定義字型列印解決!(轉)自定義字型
- iOS 自定義字型出問題啦!iOS自定義字型
- 自定義標籤出現問題
- 自定義異常類
- 自定義View的硬體加速問題View
- Zepto自定義模組打包構建
- 優化自定義的Exception的日誌列印,設定自定義使用的ApiException extends Exception日誌列印不刷出堆疊資訊配置程式碼ApiException類即可優化ExceptionAPI
- Flutter實戰之自定義日誌列印元件Flutter元件
- Laravel 自定義驗證規則的問題Laravel
- flume自定義攔截器遇到的問題
- 在自定義View時碰到的奇怪問題View
- 自定義RedisTemplate,解決Redis亂碼問題Redis
- js模組化之自定義模組(頁面模組化載入)JS
- 分享 vxe-table 實現列印出貨單、自定義列印單據
- C++自定義貪吃蛇Snake類一系列問題的解決C++
- python - 建立一個自定義模組Python
- python如何匯入自定義模組Python
- 在Python中新增自定義模組Python
- freeswitch自定義模組的wiki地址
- 第十章 自定義模組
- TensorFlow2.0教程-文字分類文字分類
- SpringBoot自定義註解、AOP列印日誌Spring Boot
- MySQL自定義變數處理行號問題MySql變數
- POWERBUILDER KODIGO 框架 自定義透明圖片問題UIGo框架
- Django(62)自定義認證類Django
- 自定義實現Complex類
- Java的自定義異常類Java
- 工具類——自定義Collections集合方法
- Go 模組存在的意義與解決的問題Go
- Python 日誌列印之自定義logger handlerPython
- WinForm 載入自定義控制元件閃爍問題ORM控制元件
- Spring Boot(3)---自定義spring boot starter 問題Spring Boot
- 急急急急!Struts自定義標籤html:text 問題HTML
- Laravel自定義Make命令生成Service類Laravel
- Java自定義一個字典類(Dictionary)Java