tensorflow學習筆記keras(5)------北京大學 曹健
tf.keras搭建網路八股
1. import
import tensorflow as tf
from sklearn import datasets
import numpy as np
2. train,test
x_train = datasets.load_iris().data
y_train = datasets.load_iris().target
3. tf.keras.models.Sequential(逐層描述網路結構,走一邊前向傳播)
Sequential 函式是一個容器,描述了神經網路的網路結構,在 Sequential函式的輸入引數中描述從輸入層到輸出層的網路結構。
如:
拉直層:tf.keras.layers.Flatten()
拉直層可以變換張量的尺寸,把輸入特徵拉直為一維陣列,是不含計算引數的層。
全連線層:tf.keras.layers.Dense( 神經元個數,
activation=”啟用函式”,
kernel_regularizer=”正則化方式”)
其中:
activation(字串給出)可選
relu、softmax、sigmoid、tanh 等
kernel_regularizer 可選
tf.keras.regularizers.l1() L1正則化、
tf.keras.regularizers.l2() L2正則化
卷積層:tf.keras.layers.Conv2D( filter = 卷積核個數,kernel_size = 卷積核尺寸,strides = 卷積步長,padding = “valid” or “same”)
迴圈層:LSTM 層:tf.keras.layers.LSTM()
舉例:
model = tf.keras.models.Sequential(
[tf.keras.layers.Dense(3, activation="relu",
kernel_regularizer=tf.keras.regularizers.l2())])
採用全連線層,3表示輸入神經元個數,啟用函式用relu,第三個引數用l2正則化。(易錯點,選擇啟用函式可以用字串或tf.nn.relu; 還有l2正則化後記得加括號,否則報錯)
4. compile(訓練方法的配置,優化器、損失函式、評測指標)
Model.compile( optimizer = 優化器, loss = 損失函式, metrics = [“準確率”])
- 優化器:
‘sgd’or tf.optimizers.SGD( lr=學習率,decay=學習率衰減率,
momentum=動量引數)
‘adagrad’or tf.keras.optimizers.Adagrad(lr=學習率,
decay=學習率衰減率)
‘adadelta’or tf.keras.optimizers.Adadelta(lr=學習率,
decay=學習率衰減率)
‘adam’or tf.keras.optimizers.Adam (lr=學習率,
decay=學習率衰減率)
- 損失函式:
‘mse’or tf.keras.losses.MeanSquaredError()
‘sparse_categorical_crossentropy
or tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
- metrics:
‘accuracy’:y_和 y 都是數值,如 y_=[1] y=[1]。
‘categorical_accuracy’:y_和 y 都是以獨熱碼和概率分佈表示。
如 y_=[0, 1, 0], y=[0.256, 0.695, 0.048]。
‘sparse_ categorical_accuracy’:y_是以數值形式給出,y 是以獨熱碼形式
給出。 如 y_=[1],y=[0.256, 0.695, 0.048]。
通常用第三種。
model.compile(optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['sparse_categorical_accuracy'])
(優化器加括號,否則報錯,此外括號內可設定學習率等引數,)
5. fit(輸入特徵和標籤,batch,迭代次數)
model.fit(訓練集的輸入特徵, 訓練集的標籤, batch_size, epochs,
validation_data = (測試集的輸入特徵,測試集的標籤),
validataion_split = 從測試集劃分多少比例給訓練集,
validation_freq = 測試的 epoch 間隔次數)
model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
6. summary(列印)
7. 總得
import tensorflow as tf
from sklearn import datasets
import numpy as np
x_train = datasets.load_iris().data
y_train = datasets.load_iris().target
np.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)
model = tf.keras.models.Sequential(
[tf.keras.layers.Dense(3, activation=tf.nn.relu,
kernel_regularizer=tf.keras.regularizers.l2())])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.1),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['sparse_categorical_accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary()
相關文章
- 深度學習keras筆記深度學習Keras筆記
- TensorFlow 學習筆記筆記
- 【北京大學】人工智慧實踐:Tensorflow筆記(一)人工智慧筆記
- TensorFlow學習筆記(二)筆記
- tensorflow學習筆記3筆記
- tensorflow學習筆記——DenseNet筆記SENet
- TensorFlow Java API 學習筆記JavaAPI筆記
- Tensorflow學習筆記No.7筆記
- Tensorflow學習筆記No.8筆記
- Tensorflow學習筆記No.10筆記
- Tensorflow學習筆記No.11筆記
- tensorflow學習筆記--embedding_lookup()用法筆記
- Vue學習筆記5Vue筆記
- 強化學習-學習筆記5 | AlphaGo強化學習筆記Go
- spring-5學習筆記Spring筆記
- HTML5學習筆記HTML筆記
- [學習筆記 #5] 雜湊筆記
- Tensorflow學習筆記: 變數及共享變數筆記變數
- H5學習筆記(一)H5筆記
- linux學習筆記-day5Linux筆記
- 比特幣學習筆記——————5、 交易比特幣筆記
- swift學習筆記《5》- 實用Swift筆記
- G01學習筆記-5筆記
- tensorflow學習筆記1——mac開發環境配置筆記Mac開發環境
- AI學習筆記——Tensorflow中的Optimizer(優化器)AI筆記優化
- TensorFlow、Keras、CNTK...到底哪種深度學習框架更好用?Keras深度學習框架
- 《深度學習案例精粹:基於TensorFlow與Keras》案例集用於深度學習訓練深度學習Keras
- numpy的學習筆記\pandas學習筆記筆記
- 機器學習框架ML.NET學習筆記【6】TensorFlow圖片分類機器學習框架筆記
- Camera KMD ISP學習筆記(5)-DRQ筆記
- OpenCV學習筆記(5)——normalize函式OpenCV筆記ORM函式
- Flutter學習筆記(5)--Dart運算子Flutter筆記Dart
- 深度學習筆記8:利用Tensorflow搭建神經網路深度學習筆記神經網路
- 【Redis學習筆記】2018-07-11 Redis指令學習5Redis筆記
- 讀書筆記(四):深度學習基於Keras的Python實踐筆記深度學習KerasPython
- Python機器學習筆記:使用Keras進行迴歸預測Python機器學習筆記Keras
- Scikit-Learn 與 TensorFlow 機器學習實用指南學習筆記 5 —— 如何為機器學習演算法準備資料?機器學習筆記演算法
- 【學習筆記】數學筆記