MogaNet CA Module (ChannelAggregationFFN)

Posted by iceeci on 2024-11-09

Paper: MogaNet (Multi-order Gated Aggregation Network)
```python

import torch
import torch.nn as nn

def build_act_layer(act_type):
    """Build an activation layer from its name."""
    if act_type is None:
        return nn.Identity()
    assert act_type in ['GELU', 'ReLU', 'SiLU']
    if act_type == 'SiLU':
        return nn.SiLU()
    elif act_type == 'ReLU':
        return nn.ReLU()
    else:
        return nn.GELU()

class ElementScale(nn.Module):
    """A learnable element-wise scaler."""

    def __init__(self, embed_dims, init_value=0., requires_grad=True):
        super(ElementScale, self).__init__()
        self.scale = nn.Parameter(
            init_value * torch.ones((1, embed_dims, 1, 1)),
            requires_grad=requires_grad
        )

    def forward(self, x):
        return x * self.scale
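
# Note: ElementScale(16, init_value=1e-5) scales a (B, 16, H, W) tensor by a
# (1, 16, 1, 1) learnable gain, broadcast over the batch and spatial dims; the
# tiny init keeps the scaled branch close to zero early in training.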

class ChannelAggregationFFN(nn.Module):
    """An implementation of FFN with Channel Aggregation.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`. The hidden dimension is fixed to
            ``4 * embed_dims``.
        kernel_size (int): The kernel size of the depth-wise
            convolution. Defaults to 3.
        act_type (str): The type of activation. Defaults to 'GELU'.
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
    """

    def __init__(self,
                 embed_dims,
                 kernel_size=3,
                 act_type='GELU',
                 ffn_drop=0.):
        super(ChannelAggregationFFN, self).__init__()

        self.embed_dims = embed_dims
        self.feedforward_channels = int(embed_dims * 4)

        self.fc1 = nn.Conv2d(
            in_channels=embed_dims,
            out_channels=self.feedforward_channels,
            kernel_size=1)
        self.dwconv = nn.Conv2d(
            in_channels=self.feedforward_channels,
            out_channels=self.feedforward_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            bias=True,
            groups=self.feedforward_channels)
        self.act = build_act_layer(act_type)
        self.fc2 = nn.Conv2d(
            in_channels=self.feedforward_channels,
            out_channels=embed_dims,
            kernel_size=1)
        self.drop = nn.Dropout(ffn_drop)

        # 1x1 conv that summarizes all hidden channels into one (C -> 1)
        self.decompose = nn.Conv2d(
            in_channels=self.feedforward_channels,
            out_channels=1, kernel_size=1,
        )
        # learnable per-channel scale applied to the decomposed residual
        self.sigma = ElementScale(
            self.feedforward_channels, init_value=1e-5, requires_grad=True)
        self.decompose_act = build_act_layer(act_type)

    def feat_decompose(self, x):
        """Channel aggregation: x = x + sigma * (x - act(decompose(x)))."""
        # [B, C, H, W] -> [B, 1, H, W]: summarize all channels into one map
        t = self.decompose(x)
        # apply the activation to the single-channel summary (adds non-linearity)
        t = self.decompose_act(t)
        # subtract the summary from x: since t captures the shared/global
        # component, x - t can be read as the per-channel residual information
        t = x - t
        # re-weight the residual with a learnable per-channel scale
        t = self.sigma(t)
        # add the scaled residual back to the original features
        x = x + t
        return x

    def forward(self, x):
        # proj 1: 1x1 conv expansion + depth-wise conv
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        # proj 2: channel aggregation + 1x1 conv projection back to embed_dims
        x = self.feat_decompose(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.randn(1, 64, 32, 32).to(device)  # input: [B, C, H, W]
    block = ChannelAggregationFFN(embed_dims=64).to(device)
    output = block(x)
    print(x.size())       # torch.Size([1, 64, 32, 32])
    print(output.size())  # torch.Size([1, 64, 32, 32])

```
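
In equation form, `feat_decompose` computes `Y = X + γ ⊙ (X − φ(ψ(X)))`, where `ψ` is the channel-reducing 1×1 conv, `φ` the activation, and `γ` the learnable per-channel scale. Below is a minimal standalone sketch of that step; the function name `channel_aggregate` and the random tensors are my own illustration, not part of the MogaNet code:

```python
import torch
import torch.nn as nn

def channel_aggregate(x, reduce_conv, act, gamma):
    """Sketch: y = x + gamma * (x - act(reduce_conv(x)))."""
    summary = act(reduce_conv(x))     # [B, 1, H, W] shared channel summary
    return x + gamma * (x - summary)  # scaled residual, broadcast over channels

# Shape check with untrained weights
B, C, H, W = 2, 16, 8, 8
x = torch.randn(B, C, H, W)
reduce_conv = nn.Conv2d(C, 1, kernel_size=1)
gamma = torch.full((1, C, 1, 1), 1e-5)  # mirrors ElementScale's init_value
y = channel_aggregate(x, reduce_conv, nn.GELU(), gamma)
print(y.shape)  # torch.Size([2, 16, 8, 8])
```

Because `gamma` is initialized near zero, the module starts out behaving like a plain conv-FFN (fc1 → dwconv → act → fc2) and learns during training how much of the channel-aggregated residual to mix in.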
