深度殘差收縮網路:(六)程式碼實現
深度殘差收縮網路其實是一種通用的特徵學習方法,是深度殘差網路ResNet、注意力機制和軟閾值化的整合,可以用於影像分類。本文采用TensorFlow 1.0和TFLearn 0.3.2,編寫了影像分類的程式,採用的影像資料為 CIFAR-10 。 CIFAR-10是一個非常常用的影像資料集,包含10個類別的影像。可以在這個網址找到具體介紹:
TFLearn是TensorFlow的一個高層API,很方便初學者使用。本文參照了TFLearn的ResNet案例( https://github.com/tflearn/tflearn/blob/master/examples/images/residual_network_cifar10.py ),所編寫的深度殘差收縮網路的程式碼如下:
#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Mon Dec 23 21:23:09 2019 M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep Residual Shrinkage Networks for Fault Diagnosis, IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898 @author: super_9527 """ from __future__ import division, print_function, absolute_import import tflearn import numpy as np import tensorflow as tf from tflearn.layers.conv import conv_2d # Data loading from tflearn.datasets import cifar10 (X, Y), (testX, testY) = cifar10.load_data() # Add noise X = X + np.random.random((50000, 32, 32, 3))*0.1 testX = testX + np.random.random((10000, 32, 32, 3))*0.1 # Transform labels to one-hot format Y = tflearn.data_utils.to_categorical(Y,10) testY = tflearn.data_utils.to_categorical(testY,10) def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False, downsample_strides=2, activation='relu', batch_norm=True, bias=True, weights_init='variance_scaling', bias_init='zeros', regularizer='L2', weight_decay=0.0001, trainable=True, restore=True, reuse=False, scope=None, name="ResidualBlock"): # residual shrinkage blocks with channel-wise thresholds residual = incoming in_channels = incoming.get_shape().as_list()[-1] # Variable Scope fix for older TF try: vscope = tf.variable_scope(scope, default_name=name, values=[incoming], reuse=reuse) except Exception: vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse) with vscope as scope: name = scope.name #TODO for i in range(nb_blocks): identity = residual if not downsample: downsample_strides = 1 if batch_norm: residual = tflearn.batch_normalization(residual) residual = tflearn.activation(residual, activation) residual = conv_2d(residual, out_channels, 3, downsample_strides, 'same', 'linear', bias, weights_init, bias_init, regularizer, weight_decay, trainable, restore) if batch_norm: residual = tflearn.batch_normalization(residual) residual = tflearn.activation(residual, activation) residual = conv_2d(residual, out_channels, 3, 1, 'same', 'linear', bias, weights_init, bias_init, regularizer, weight_decay, trainable, restore) # get thresholds and apply thresholding abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual),axis=2,keep_dims=True),axis=1,keep_dims=True) scales = tflearn.fully_connected(abs_mean, out_channels//4, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling') scales = tflearn.batch_normalization(scales) scales = tflearn.activation(scales, 'relu') scales = tflearn.fully_connected(scales, out_channels, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling') scales = tf.expand_dims(tf.expand_dims(scales,axis=1),axis=1) thres = tf.multiply(abs_mean,tflearn.activations.sigmoid(scales)) # soft thresholding residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual)-thres,0)) # Downsampling if downsample_strides > 1: identity = tflearn.avg_pool_2d(identity, 1, downsample_strides) # Projection to new dimension if in_channels != out_channels: if (out_channels - in_channels) % 2 == 0: ch = (out_channels - in_channels)//2 identity = tf.pad(identity, [[0, 0], [0, 0], [0, 0], [ch, ch]]) else: ch = (out_channels - in_channels)//2 identity = tf.pad(identity, [[0, 0], [0, 0], [0, 0], [ch, ch+1]]) in_channels = out_channels residual = residual + identity return residual # Real-time data preprocessing img_prep = tflearn.ImagePreprocessing() img_prep.add_featurewise_zero_center(per_channel=True) # Real-time data augmentation img_aug = tflearn.ImageAugmentation() img_aug.add_random_flip_leftright() img_aug.add_random_crop([32, 32], padding=4) # Building Deep Residual Shrinkage Network net = tflearn.input_data(shape=[None, 32, 32, 3], data_preprocessing=img_prep, data_augmentation=img_aug) net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001) net = residual_shrinkage_block(net, 1, 16) net = residual_shrinkage_block(net, 1, 32, downsample=True) net = residual_shrinkage_block(net, 1, 32, downsample=True) net = tflearn.batch_normalization(net) net = tflearn.activation(net, 'relu') net = tflearn.global_avg_pool(net) # Regression net = tflearn.fully_connected(net, 10, activation='softmax') mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True) net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy') # Training model = tflearn.DNN(net, checkpoint_path='model_cifar10', max_checkpoints=10, tensorboard_verbose=0, clip_gradients=0.) model.fit(X, Y, n_epoch=100, snapshot_epoch=False, snapshot_step=500, show_metric=True, batch_size=100, shuffle=True, run_id='model_cifar10') training_acc = model.evaluate(X, Y)[0] validation_acc = model.evaluate(testX, testY)[0]
上面的程式碼構建了一個小型的深度殘差收縮網路,只含有3個基本殘差收縮模組,其他的超引數也未進行最佳化。如果為了追求更高的準確率的話,可以適當增加深度,增加訓練迭代次數,以及適當調整超引數。
前四篇的內容:
深度殘差收縮網路:(一)背景知識 https://www.cnblogs.com/yc-9527/p/11598844.html
深度殘差收縮網路:(二)整體思路 https://www.cnblogs.com/yc-9527/p/11601322.html
深度殘差收縮網路:(三)網路結構 https://www.cnblogs.com/yc-9527/p/11603320.html
深度殘差收縮網路:(四)注意力機制下的閾值設定 https://www.cnblogs.com/yc-9527/p/11604082.html
深度殘差收縮網路:(五)實驗驗證 https://www.cnblogs.com/yc-9527/p/11610073.html
原文的連結:
M. Zhao, S. Zhong, X. Fu, B. Tang, and M. Pecht, “Deep Residual Shrinkage Networks for Fault Diagnosis,” IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
https://ieeexplore.ieee.org/document/8850096
來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69957925/viewspace-2670367/,如需轉載,請註明出處,否則將追究法律責任。
相關文章
- 殘差網路再升級之深度殘差收縮網路(附Keras程式碼)Keras
- 深度殘差收縮網路:(五)實驗驗證
- 深度殘差收縮網路:(三)網路結構
- 深度殘差收縮網路:(一)背景知識
- 深度殘差收縮網路:(二)整體思路
- 深度學習故障診斷——深度殘差收縮網路深度學習
- 十分鐘弄懂深度殘差收縮網路
- 深度殘差網路(ResNet)
- 深度殘差收縮網路:(四)注意力機制下的閾值設定
- 深度學習之殘差網路深度學習
- 深度三維殘差神經網路:視訊理解新突破神經網路
- 殘差神經網路-ResNet神經網路
- 深度學習——手動實現殘差網路ResNet 辛普森一家人物識別深度學習
- 學習筆記16:殘差網路筆記
- jquery實現的下拉和收縮程式碼例項jQuery
- 殘差網路(Residual Networks, ResNets)
- PyTorch入門-殘差卷積神經網路PyTorch卷積神經網路
- 深度殘差網路+自適應引數化ReLU啟用函式(調參記錄1)函式
- 深度殘差網路+自適應引數化ReLU啟用函式(調參記錄2)函式
- 深度殘差網路+自適應引數化ReLU啟用函式(調參記錄3)函式
- 深度殘差網路+自適應引數化ReLU啟用函式(調參記錄4)函式
- 深度殘差網路+自適應引數化ReLU啟用函式(調參記錄5)函式
- 深度殘差網路+自適應引數化ReLU啟用函式(調參記錄6)函式
- 深度殘差網路+自適應引數化ReLU啟用函式(調參記錄7)函式
- 深度殘差網路+自適應引數化ReLU啟用函式(調參記錄8)函式
- 深度殘差網路+自適應引數化ReLU啟用函式(調參記錄9)函式
- 深度殘差網路+自適應引數化ReLU啟用函式(調參記錄10)函式
- 深度殘差網路+自適應引數化ReLU啟用函式(調參記錄11)函式
- 深度殘差網路+自適應引數化ReLU啟用函式(調參記錄12)函式
- 深度殘差網路+自適應引數化ReLU啟用函式(調參記錄13)函式
- 深度殘差網路+自適應引數化ReLU啟用函式(調參記錄14)函式
- 深度殘差網路+自適應引數化ReLU啟用函式(調參記錄15)函式
- 深度殘差網路+自適應引數化ReLU啟用函式(調參記錄16)函式
- 深度殘差網路+自適應引數化ReLU啟用函式(調參記錄17)函式
- 點選標題實現內容元素伸展和收縮程式碼例項
- 從AlexNet到殘差網路,理解卷積神經網路的不同架構卷積神經網路架構
- 影像分割論文 | DRN膨脹殘差網路 | CVPR2017
- 深度殘差網路+自適應引數化ReLU啟用函式(調參記錄26)Cifar10~95.92%函式