十分鐘弄懂深度殘差收縮網路

蒼海一慄發表於2020-01-28

深度殘差網路獲得了2016年CVPR會議的最佳論文獎,截至目前,在谷歌學術上的引用量已經達到了驚人的38295次。

深度殘差收縮網路是深度殘差網路的一種新穎的改進版本,其實是深度殘差網路、注意力機制和軟閾值函式的深度整合。

在一定程度上,深度殘差收縮網路的工作模式,可以理解為:透過注意力機制注意到不重要的特徵,透過軟閾值函式將它們置為零;或者說,透過注意力機制注意到重要的特徵,將它們保留下來,從而加強深度神經網路從含噪聲訊號中提取有用特徵的能力。

1.提出深度殘差收縮網路的動機是什麼呢?

首先,在對樣本進行分類的時候,樣本中不可避免地會有一些噪聲,就像高斯噪聲、粉色噪聲、拉普拉斯噪聲等。更廣義地講,樣本中很可能包含著與當前分類任務無關的資訊,這些資訊也可以理解為噪聲。這些噪聲可能會對分類效果產生不利的影響。(軟閾值化是許多訊號降噪演算法中的一個關鍵步驟)

舉例來說,在馬路邊聊天的時候,聊天的聲音裡就可能會混雜車輛的鳴笛聲、車輪聲等等。當對這些聲音訊號進行語音識別的時候,識別效果不可避免地會受到鳴笛聲、車輪聲的影響。從深度學習的角度來講,這些鳴笛聲、車輪聲所對應的特徵,就應該在深度神經網路內部被刪除掉,以避免對語音識別的效果造成影響。

其次,即使是同一個樣本集,各個樣本的噪聲量也往往是不同的。(這和注意力機制有相通之處;以一個影像樣本集為例,各張圖片中目標物體所在的位置可能是不同的;注意力機制可以針對每一張圖片,注意到目標物體所在的位置)

例如,當訓練貓狗分類器的時候,對於標籤為“狗”的5張影像,第1張影像可能同時包含著狗和老鼠,第2張影像可能同時包含著狗和鵝,第3張影像可能同時包含著狗和雞,第4張影像可能同時包含著狗和驢,第5張影像可能同時包含著狗和鴨子。我們在訓練貓狗分類器的時候,就不可避免地會受到老鼠、鵝、雞、驢和鴨子等無關物體的干擾,造成分類準確率下降。如果我們能夠注意到這些無關的老鼠、鵝、雞、驢和鴨子,將它們所對應的特徵刪除掉,就有可能提高貓狗分類器的準確率。

2.軟閾值化是很多降噪演算法的核心步驟

軟閾值化,是很多訊號降噪演算法的核心步驟,將絕對值小於某個閾值的特徵刪除掉,將絕對值大於這個閾值的特徵朝著零的方向進行收縮。軟閾值化的導數要麼是1,要麼是0。這個性質是和ReLU啟用函式是相同的。因此,軟閾值化也能夠減小深度學習演算法遭遇梯度彌散和梯度爆炸的風險。

在軟閾值化函式中,閾值的設定必須符合兩個的條件: 第一,閾值是正數;第二,閾值不能大於輸入訊號的最大值,否則輸出會全部為零。

同時,閾值最好還能符合第三個條件:每個樣本應該根據自身的噪聲含量,有著自己獨立的閾值。

這是因為,很多樣本的噪聲含量經常是不同的。例如經常會有這種情況,在同一個樣本集裡面,樣本A所含噪聲較少,樣本B所含噪聲較多。那麼,如果是在降噪演算法裡進行軟閾值化的時候,樣本A就應該採用較大的閾值,樣本B就應該採用較小的閾值。在深度神經網路中,雖然這些特徵和閾值失去了明確的物理意義,但是基本的道理還是相通的。也就是說,每個樣本應該根據自身的噪聲含量,有著自己獨立的閾值。

3.注意力機制

注意力機制在計算機視覺領域是比較容易理解的。動物的視覺系統可以快速掃描全部區域,發現目標物體,進而將注意力集中在目標物體上,以提取更多的細節,同時抑制無關資訊。具體請參照注意力機制方面的文章。

Squeeze-and-Excitation Network(SENet)是一種較新的注意力機制下的深度學習方法。 在不同的樣本中,不同的特徵通道,在分類任務中的貢獻大小,往往是不同的。SENet採用一個小型的子網路,獲得一組權重,進而將這組權重與各個通道的特徵分別相乘,以調整各個通道特徵的大小。這個過程,就可以認為是在施加不同大小的注意力在各個特徵通道上。

up-9cfe5dbcf98cf930ed6bfa63e3e053dac37.png

在這種方式下,每一個樣本,都會有自己獨立的一組權重。換言之,任意的兩個樣本,它們的權重,都是不一樣的。在SENet中,獲得權重的具體路徑是,“全域性池化→全連線層→ReLU函式→全連線層→Sigmoid函式”。

up-7d31fdae7e1b7ce709d18b28e3d0826f0e1.png

4.深度注意力機制下的軟閾值化

深度殘差收縮網路借鑑了上述SENet的子網路結構,以實現注意力機制下的軟閾值化。透過藍色框內的子網路,就可以學習得到一組閾值,對各個特徵通道進行軟閾值化。

up-355606b62c2b218f2457c2ac2d2e110fd67.png

在這個子網路中,首先對輸入特徵圖的所有特徵,求它們的絕對值。然後經過全域性均值池化和平均,獲得一個特徵,記為A。在另一條路徑中,全域性均值池化之後的特徵圖,被輸入到一個小型的全連線網路。這個全連線網路以Sigmoid函式作為最後一層,將輸出歸一化到0和1之間,獲得一個係數,記為α。最終的閾值可以表示為α×A。因此,閾值就是,一個0和1之間的數字×特徵圖的絕對值的平均。透過這種方式,保證了閾值為正,而且不會太大。

而且,不同的樣本就有了不同的閾值。因此,在一定程度上,可以理解成一種特殊的注意力機制:注意到與當前任務無關的特徵,透過軟閾值化,將它們置為零;或者說,注意到與當前任務有關的特徵,將它們保留下來。

最後,堆疊一定數量的基本模組以及卷積層、批標準化、啟用函式、全域性均值池化以及全連線輸出層等,就得到了完整的深度殘差收縮網路。

up-e4b3d67328e4eb03b4c69f40fdad245003f.png

5.深度殘差收縮網路或許有更廣泛的通用性

深度殘差收縮網路事實上是一種通用的特徵學習方法。這是因為很多特徵學習的任務中,樣本中或多或少都會包含一些噪聲,以及不相關的資訊。這些噪聲和不相關的資訊,有可能會對特徵學習的效果造成影響。例如說:

在圖片分類的時候,如果圖片同時包含著很多其他的物體,那麼這些物體就可以被理解成“噪聲”;深度殘差收縮網路或許能夠藉助注意力機制,注意到這些“噪聲”,然後藉助軟閾值化,將這些“噪聲”所對應的特徵置為零,就有可能提高影像分類的準確率。

在語音識別的時候,如果在聲音較為嘈雜的環境裡,比如在馬路邊、工廠車間裡聊天的時候,深度殘差收縮網路也許可以提高語音識別的準確率,或者給出了一種能夠提高語音識別準確率的思路。

6.Keras和TFLearn程式簡介

本程式以影像分類為例,構建了小型的深度殘差收縮網路,超引數也未進行最佳化。為追求高準確率的話,可以適當增加深度,增加訓練迭代次數,以及適當調整超引數。下面是Keras程式:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sat Dec 28 23:24:05 2019
Implemented using TensorFlow 1.0.1 and Keras 2.2.1
 
M. Zhao, S. Zhong, X. Fu, et al., Deep Residual Shrinkage Networks for Fault Diagnosis, 
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
@author: super_9527
"""
from __future__ import print_function
import keras
import numpy as np
from keras.datasets import mnist
from keras.layers import Dense, Conv2D, BatchNormalization, Activation
from keras.layers import AveragePooling2D, Input, GlobalAveragePooling2D
from keras.optimizers import Adam
from keras.regularizers import l2
from keras import backend as K
from keras.models import Model
from keras.layers.core import Lambda
K.set_learning_phase(1)
# Input image dimensions
img_rows, img_cols = 28, 28
# The data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)
# Noised data
x_train = x_train.astype('float32') / 255. + 0.5*np.random.random([x_train.shape[0], img_rows, img_cols, 1])
x_test = x_test.astype('float32') / 255. + 0.5*np.random.random([x_test.shape[0], img_rows, img_cols, 1])
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)
def abs_backend(inputs):
    return K.abs(inputs)
def expand_dim_backend(inputs):
    return K.expand_dims(K.expand_dims(inputs,1),1)
def sign_backend(inputs):
    return K.sign(inputs)
def pad_backend(inputs, in_channels, out_channels):
    pad_dim = (out_channels - in_channels)//2
    return K.spatial_3d_padding(inputs, padding = ((0,0),(0,0),(pad_dim,pad_dim)))
# Residual Shrinakge Block
def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                             downsample_strides=2):
    
    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]
    
    for i in range(nb_blocks):
        
        identity = residual
        
        if not downsample:
            downsample_strides = 1
        
        residual = BatchNormalization()(residual)
        residual = Activation('relu')(residual)
        residual = Conv2D(out_channels, 3, strides=(downsample_strides, downsample_strides), 
                          padding='same', kernel_initializer='he_normal', 
                          kernel_regularizer=l2(1e-4))(residual)
        
        residual = BatchNormalization()(residual)
        residual = Activation('relu')(residual)
        residual = Conv2D(out_channels, 3, padding='same', kernel_initializer='he_normal', 
                          kernel_regularizer=l2(1e-4))(residual)
        
        # Calculate global means
        residual_abs = Lambda(abs_backend)(residual)
        abs_mean = GlobalAveragePooling2D()(residual_abs)
        
        # Calculate scaling coefficients
        scales = Dense(out_channels, activation=None, kernel_initializer='he_normal', 
                       kernel_regularizer=l2(1e-4))(abs_mean)
        scales = BatchNormalization()(scales)
        scales = Activation('relu')(scales)
        scales = Dense(out_channels, activation='sigmoid', kernel_regularizer=l2(1e-4))(scales)
        scales = Lambda(expand_dim_backend)(scales)
        
        # Calculate thresholds
        thres = keras.layers.multiply([abs_mean, scales])
        
        # Soft thresholding
        sub = keras.layers.subtract([residual_abs, thres])
        zeros = keras.layers.subtract([sub, sub])
        n_sub = keras.layers.maximum([sub, zeros])
        residual = keras.layers.multiply([Lambda(sign_backend)(residual), n_sub])
        
        # Downsampling (it is important to use the pooL-size of (1, 1))
        if downsample_strides > 1:
            identity = AveragePooling2D(pool_size=(1,1), strides=(2,2))(identity)
            
        # Zero_padding to match channels (it is important to use zero padding rather than 1by1 convolution)
        if in_channels != out_channels:
            identity = Lambda(pad_backend)(identity, in_channels, out_channels)
        
        residual = keras.layers.add([residual, identity])
    
    return residual
# define and train a model
inputs = Input(shape=input_shape)
net = Conv2D(8, 3, padding='same', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(inputs)
net = residual_shrinkage_block(net, 1, 8, downsample=True)
net = BatchNormalization()(net)
net = Activation('relu')(net)
net = GlobalAveragePooling2D()(net)
outputs = Dense(10, activation='softmax', kernel_initializer='he_normal', kernel_regularizer=l2(1e-4))(net)
model = Model(inputs=inputs, outputs=outputs)
model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=100, epochs=5, verbose=1, validation_data=(x_test, y_test))
# get results
K.set_learning_phase(0)
DRSN_train_score = model.evaluate(x_train, y_train, batch_size=100, verbose=0)
print('Train loss:', DRSN_train_score[0])
print('Train accuracy:', DRSN_train_score[1])
DRSN_test_score = model.evaluate(x_test, y_test, batch_size=100, verbose=0)
print('Test loss:', DRSN_test_score[0])
print('Test accuracy:', DRSN_test_score[1])

下面是TFLearn程式:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 23 21:23:09 2019
Implemented using TensorFlow 1.0 and TFLearn 0.3.2
 
M. Zhao, S. Zhong, X. Fu, B. Tang, M. Pecht, Deep Residual Shrinkage Networks for Fault Diagnosis, 
IEEE Transactions on Industrial Informatics, 2019, DOI: 10.1109/TII.2019.2943898
 
@author: super_9527
"""
  
from __future__ import division, print_function, absolute_import
  
import tflearn
import numpy as np
import tensorflow as tf
from tflearn.layers.conv import conv_2d
  
# Data loading
from tflearn.datasets import cifar10
(X, Y), (testX, testY) = cifar10.load_data()
  
# Add noise
X = X + np.random.random((50000, 32, 32, 3))*0.1
testX = testX + np.random.random((10000, 32, 32, 3))*0.1
  
# Transform labels to one-hot format
Y = tflearn.data_utils.to_categorical(Y,10)
testY = tflearn.data_utils.to_categorical(testY,10)
  
def residual_shrinkage_block(incoming, nb_blocks, out_channels, downsample=False,
                   downsample_strides=2, activation='relu', batch_norm=True,
                   bias=True, weights_init='variance_scaling',
                   bias_init='zeros', regularizer='L2', weight_decay=0.0001,
                   trainable=True, restore=True, reuse=False, scope=None,
                   name="ResidualBlock"):
      
    # residual shrinkage blocks with channel-wise thresholds
  
    residual = incoming
    in_channels = incoming.get_shape().as_list()[-1]
  
    # Variable Scope fix for older TF
    try:
        vscope = tf.variable_scope(scope, default_name=name, values=[incoming],
                                   reuse=reuse)
    except Exception:
        vscope = tf.variable_op_scope([incoming], scope, name, reuse=reuse)
  
    with vscope as scope:
        name = scope.name #TODO
  
        for i in range(nb_blocks):
  
            identity = residual
  
            if not downsample:
                downsample_strides = 1
  
            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3,
                             downsample_strides, 'same', 'linear',
                             bias, weights_init, bias_init,
                             regularizer, weight_decay, trainable,
                             restore)
  
            if batch_norm:
                residual = tflearn.batch_normalization(residual)
            residual = tflearn.activation(residual, activation)
            residual = conv_2d(residual, out_channels, 3, 1, 'same',
                             'linear', bias, weights_init,
                             bias_init, regularizer, weight_decay,
                             trainable, restore)
              
            # get thresholds and apply thresholding
            abs_mean = tf.reduce_mean(tf.reduce_mean(tf.abs(residual),axis=2,keep_dims=True),axis=1,keep_dims=True)
            scales = tflearn.fully_connected(abs_mean, out_channels//4, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling')
            scales = tflearn.batch_normalization(scales)
            scales = tflearn.activation(scales, 'relu')
            scales = tflearn.fully_connected(scales, out_channels, activation='linear',regularizer='L2',weight_decay=0.0001,weights_init='variance_scaling')
            scales = tf.expand_dims(tf.expand_dims(scales,axis=1),axis=1)
            thres = tf.multiply(abs_mean,tflearn.activations.sigmoid(scales))
            # soft thresholding
            residual = tf.multiply(tf.sign(residual), tf.maximum(tf.abs(residual)-thres,0))
              
  
            # Downsampling
            if downsample_strides > 1:
                identity = tflearn.avg_pool_2d(identity, 1,
                                               downsample_strides)
  
            # Projection to new dimension
            if in_channels != out_channels:
                if (out_channels - in_channels) % 2 == 0:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch]])
                else:
                    ch = (out_channels - in_channels)//2
                    identity = tf.pad(identity,
                                      [[0, 0], [0, 0], [0, 0], [ch, ch+1]])
                in_channels = out_channels
  
            residual = residual + identity
  
    return residual
  
  
# Real-time data preprocessing
img_prep = tflearn.ImagePreprocessing()
img_prep.add_featurewise_zero_center(per_channel=True)
  
# Real-time data augmentation
img_aug = tflearn.ImageAugmentation()
img_aug.add_random_flip_leftright()
img_aug.add_random_crop([32, 32], padding=4)
  
# Build a Deep Residual Shrinkage Network with 3 blocks
net = tflearn.input_data(shape=[None, 32, 32, 3],
                         data_preprocessing=img_prep,
                         data_augmentation=img_aug)
net = tflearn.conv_2d(net, 16, 3, regularizer='L2', weight_decay=0.0001)
net = residual_shrinkage_block(net, 1, 16)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = residual_shrinkage_block(net, 1, 32, downsample=True)
net = tflearn.batch_normalization(net)
net = tflearn.activation(net, 'relu')
net = tflearn.global_avg_pool(net)
# Regression
net = tflearn.fully_connected(net, 10, activation='softmax')
mom = tflearn.Momentum(0.1, lr_decay=0.1, decay_step=20000, staircase=True)
net = tflearn.regression(net, optimizer=mom, loss='categorical_crossentropy')
# Training
model = tflearn.DNN(net, checkpoint_path='model_cifar10',
                    max_checkpoints=10, tensorboard_verbose=0,
                    clip_gradients=0.)
  
model.fit(X, Y, n_epoch=100, snapshot_epoch=False, snapshot_step=500,
          show_metric=True, batch_size=100, shuffle=True, run_id='model_cifar10')
  
training_acc = model.evaluate(X, Y)[0]
validation_acc = model.evaluate(testX, testY)[0]

論文網址

M. Zhao, S. Zhong, X. Fu, et al., Deep residual shrinkage networks for fault diagnosis, IEEE Transactions on Industrial Informatics, DOI: 10.1109/TII.2019.2943898

https://ieeexplore.ieee.org/document/8850096


來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69960177/viewspace-2674313/,如需轉載,請註明出處,否則將追究法律責任。

相關文章