TensorFlow中滑動平均模型介紹
內容移至:https://oldpan.me/archives/tensorflow-movingaverage
內容總結於《TensorFlow實戰Google深度學習框架》
不知道大家有沒有聽過一階滯後濾波法:
其中a的取值範圍[0,1],具體就是:本次濾波結果=(1-a)本次取樣值+a上次濾波結果,採用此演算法的目的是:
1、降低週期性的干擾;
2、在波動頻率較高的場合有很好的效果。
而在TensorFlow中提供了tf.train.ExponentialMovingAverage
來實現滑動平均模型,在採用隨機梯度下降演算法訓練神經網路時,使用其可以提高模型在測試資料上的健壯性(robustness)。
TensorFlow下的 tf.train.ExponentialMovingAverage
需要提供一個衰減率decay。該衰減率用於控制模型更新的速度。該衰減率用於控制模型更新的速度,ExponentialMovingAverage 對每一個待更新的變數(variable)都會維護一個影子變數(shadow variable)。影子變數的初始值就是這個變數的初始值,
上述公式與之前介紹的一階滯後濾波法的公式相比較,會發現有很多相似的地方,從名字上面也可以很好的理解這個簡約不簡單演算法的原理:平滑、濾波,即使資料平滑變化,通過調整引數來調整變化的穩定性。
在滑動平滑模型中, decay 決定了模型更新的速度,越大越趨於穩定。實際運用中,decay 一般會設定為十分接近 1 的常數(0.999或0.9999)。為了使得模型在訓練的初始階段更新得更快,ExponentialMovingAverage 還提供了 num_updates 引數來動態設定 decay 的大小:
用一段書中程式碼帶解釋如何使用滑動平均模型:
import tensorflow as tf
v1 = tf.Variable(0, dtype=tf.float32)//初始化v1變數
step = tf.Variable(0, trainable=False) //初始化step為0
ema = tf.train.ExponentialMovingAverage(0.99, step) //定義平滑類,設定引數以及step
maintain_averages_op = ema.apply([v1]) //定義更新變數平均操作
with tf.Session() as sess:
# 初始化
init_op = tf.global_variables_initializer()
sess.run(init_op)
print sess.run([v1, ema.average(v1)])
# 更新變數v1的取值
sess.run(tf.assign(v1, 5))
sess.run(maintain_averages_op)
print sess.run([v1, ema.average(v1)])
# 更新step和v1的取值
sess.run(tf.assign(step, 10000))
sess.run(tf.assign(v1, 10))
sess.run(maintain_averages_op)
print sess.run([v1, ema.average(v1)])
# 更新一次v1的滑動平均值
sess.run(maintain_averages_op)
print sess.run([v1, ema.average(v1)])
output:
[0.0, 0.0]
[5.0, 4.5]
[10.0, 4.5549998]
[10.0, 4.6094499]
相關文章
- TensorFlow筆記-07-神經網路優化-學習率,滑動平均筆記神經網路優化
- 推薦模型NeuralCF:原理介紹與TensorFlow2.0實現模型
- 推薦模型DeepCrossing: 原理介紹與TensorFlow2.0實現模型ROS
- 理解滑動平均(exponential moving average)
- Linux中的IO模型介紹Linux模型
- Tensorflow介紹和安裝
- 滑動視窗(Sliding Window)演算法介紹演算法
- Tensorflow教程(2)Tensorflow的常用函式介紹函式
- 8.3 BERT模型介紹模型
- 關於風機滑環的介紹
- 簡單介紹android實現可以滑動的平滑曲線圖Android
- Django 2.0 模型層中 QuerySet 查詢操作介紹Django模型
- 平均和最壞時間複雜度介紹時間複雜度
- [譯]寫給初學者的Tensorflow介紹[2]
- RBAC_許可權模型介紹模型
- ChatGPT-4o模型功能介紹ChatGPT模型
- 網路 IO 模型簡單介紹模型
- UI自動化學習筆記- PO模型介紹和使用UI筆記模型
- 專案管理的四大模型-瀑布模型介紹專案管理大模型
- 五種IO模型介紹和對比模型
- 決策樹模型(1)總體介紹模型
- Qt 檔案模型(QFileSystemModel)詳細介紹QT模型
- css盒子模型的屬性介紹CSS模型
- 模型預處理層介紹(1) - Discretization模型
- 第71篇 Dto與多模型介紹模型
- 資料倉儲 - 星座模型、星型模型和雪花模型的介紹模型
- 簡單介紹TensorFlow中關於tf.app.flags命令列引數解析模組APP命令列
- 編譯 TensorFlow 模型編譯模型
- 用Tensorflow2.0實現Faster-RCNN的程式碼介紹ASTCNN
- 三維點雲語義分割模型介紹模型
- 簡單的介紹 Eloquent 模型生命週期模型
- 敏捷轉型ADKAR變革管理模型介紹敏捷模型
- SAP CRM附件模型的Authorization scope原理介紹模型
- 常用的模型整合方法介紹:bagging、boosting 、stacking模型
- 大型語言模型(Large Language Models)的介紹模型
- 應用模型開發指南上新介紹模型
- Bootstrap Blazor 元件介紹 Table (一)自動生成列功能介紹bootBlazor元件
- 【Tensorflow_DL_Note12】TensorFlow中LeNet-5模型的實現程式碼模型
- 【TensorFlow】 TensorFlow-Slim影像分類模型庫模型