學習筆記16:殘差網路

有何m不可發表於2024-06-04

轉自:https://www.cnblogs.com/miraclepbc/p/14368116.html

產生背景

隨著網路深度的增加,會出現網路退化的現象。
網路退化現象形象化解釋是在訓練集上的loss不增反降。
這說明,淺層網路的訓練效果要好於深層網路
一個想法就是,如果將淺層網路的特徵傳到深層網路,那麼深層網路的訓練效果不會比淺層網路差
舉個例子,就是假設總共有50層,20層的訓練結果就比50層的好了,因此可以將18層與98層之間連線一個直接對映
這樣隨著網路的加深,訓練效果就不會降低了

殘差塊

殘差塊的數學表示:
xl+1=xl+F(xl,Wl)𝑥𝑙+1=𝑥𝑙+𝐹(𝑥𝑙,𝑊𝑙)
xl𝑥𝑙相當於是一個直接對映,F(xl,Wl)𝐹(𝑥𝑙,𝑊𝑙)是殘差部分
學習筆記16:殘差網路

在這個網路結構中,右側指的就是殘差部分,左側是直接對映

程式碼實現

class ResnetbasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channnels, out_channels, kernel_size = 3, padding = 1, bias = False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size = 3, padding = 1, bias = False)
        self.bn2 = nn.BatchNorm2d(out_channels)
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = F.relu(self.bn1(out), inplace = True)
        out = self.conv2(out)
        out = F.relu(self.bn2(out), inplace = True)
        out = out + residual
        return F.relu(out)

相關文章