tensorflow2.0 自定義類模組列印問題
明明tensorflow1x就沒有學好,現在換了2,那就再從頭學起叭
在自定義類模組列印遇到了各種bug,在此記錄一下,也希望可以幫助到有需要的夥伴。
如果你自定義類程式碼寫成如下格式
import tensorflow as tf
from tensorflow.keras.models import Model
class SAEModel(Model):
def __init__(self, input_shape, output_shape, hidden_shape=None):
# print("init")
# 隱藏層節點個數預設為輸入層的3倍
if hidden_shape == None:
hidden_shape = 3 * input_shape
# 呼叫父類__init__()方法
super(SAEModel, self).__init__()
# 初始化模型使用的layer,layer_1為前述自定義layer
self.layer_1 = SAELayer(hidden_shape)
# layer_2為全連線層,採用sigmoid啟用函式
# 每層在這裡可以不考慮輸入元素個數,但必須考慮輸出元素個數
# 輸入元素個數可以在call()函式中動態確定
self.layer_2 = layers.Dense(output_shape, activation=tf.nn.sigmoid)
def call(self, input_tensor, training=False):
# 輸入資料
hidden = self.layer_1(input_tensor)
output = self.layer_2(hidden)
return output
那麼會報錯如下,提示你需要使用model.build(input_shape=)方法。但是!我這個input_shape引數我輸入總是出現各種錯誤。
有時列印出來的模型框架也不理想,如下圖
在多方搜尋後,看到了某位大神的程式碼,然後修改自己程式碼,完美輸出。
https://blog.csdn.net/qq_40642546/article/details/106622996
class MyLstm(tf.keras.Model):
def __init__(self, voca_len, out_len, max_len, batch_size, embedding_dim, unit_num, embed_matrix = None):
#==============================================
super().__init__()
# =============================================
self.voca_len = voca_len
self.out_len = out_len
self.max_len = max_len
self.batch_size = batch_size
self.embedding_dim = embedding_dim
self.unit_num = unit_num
self.embed_matrix = None
self.input_layer = Input(max_len,
name = 'inputs')
if self.embed_matrix == None:
self.embedding_layer = Embedding(input_dim = self.voca_len,
output_dim = self.embedding_dim,
input_length = self.max_len,
trainable = True,
name = 'embedding')
else:
self.embedding_layer = Embedding(input_dim = self.voca_len,
output_dim = self.embedding_dim,
weights = [self.embed_matrix],
input_length = self.max_len,
name = 'embedding')
self.lstm_layer = LSTM(units=self.unit_num,
activation = 'relu',
name = 'LSTM')
self.out_layer = Dense(units=self.out_len,
name = 'Train_out')
self.out = self.call(self.input_layer) # !!!!!
# ================================================
super().__init__(inputs=self.input_layer,
outputs=self.out)
#=================================================
def build(self):
self.is_graph_network = True
self.__init__graph_network(inputs=self.input_layer,
outputs = self.out)
def call(self, x, from_logits=False, training=True, mask=None):
embed_matrix = self.embedding_layer(x)
# print(embed_matrix)
lstm_output= self.lstm_layer(embed_matrix)
logits = self.out_layer(lstm_output)
if from_logits:
return logits
else:
return tf.nn.sigmoid(logits)
model = MyLstm(voca_len=394, out_len=504, max_len=29, embedding_dim=23, unit_num=100, batch_size=2)
model.summary()
關鍵的地方在上面的程式碼中做了註釋,兩次初始化很重要,至於為啥這麼寫。。。
emmmmm 我先會用叭
相關文章
- ansible自定義模組
- ReactNative自定義NetworkingModule網路模組React
- Zepto自定義模組打包構建
- python - 建立一個自定義模組Python
- python如何匯入自定義模組Python
- 第十章 自定義模組
- Go 模組存在的意義與解決的問題Go
- python基礎--自定義模組、import、from......import......PythonImport
- Python學習之如何引用Python自定義模組?Python
- laravel自定義命令列印進度條Laravel命令列
- vxe-table 列印出貨單、自定義列印單據
- ??Java開發者的Python快速進修指南:自定義模組及常用模組JavaPython
- Magento 後臺 Configuration 下建立新的自定義模組
- 自定義View的硬體加速問題View
- 自定義異常類
- SpringBoot自定義註解、AOP列印日誌Spring Boot
- Python 日誌列印之自定義logger handlerPython
- java 自定義表單 掛靠流程 模組設計方案Java
- HTMLTestRunnerNew模組原始碼及呼叫自定義報告封裝HTML原始碼封裝
- vxe-table 實現列印出貨單、自定義單據列印
- 自定義RedisTemplate,解決Redis亂碼問題Redis
- flume自定義攔截器遇到的問題
- Laravel 自定義驗證規則的問題Laravel
- C++自定義貪吃蛇Snake類一系列問題的解決C++
- Flutter實戰之自定義日誌列印元件Flutter元件
- 自定義實現Complex類
- Python如何自定義元類Python
- Python3中如何做的自定義模組的引用?Python
- freeswitch修改mod_sofia模組並上報自定義頭域
- 優化自定義的Exception的日誌列印,設定自定義使用的ApiException extends Exception日誌列印不刷出堆疊資訊配置程式碼ApiException類即可優化ExceptionAPI
- 分享 vxe-table 實現列印出貨單、自定義列印單據
- 工作流 自定義表單 掛靠流程 模組設計方案
- TensorFlow2.0教程-文字分類文字分類
- Django(62)自定義認證類Django
- drozer模組的編寫及模組動態載入問題研究
- 06 ## 模組分類
- WinForm 載入自定義控制元件閃爍問題ORM控制元件
- 純 CSS 解決自定義 CheckBox 背景顏色問題CSS