PyTorch reference implementation of an FFN with Channel Aggregation, adapted from the paper:
```python
import torch
import torch.nn as nn


def build_act_layer(act_type):
    """Build an activation layer from its type name."""
    if act_type is None:
        return nn.Identity()
    assert act_type in ['GELU', 'ReLU', 'SiLU']
    if act_type == 'SiLU':
        return nn.SiLU()
    elif act_type == 'ReLU':
        return nn.ReLU()
    else:
        return nn.GELU()

class ElementScale(nn.Module):
    """A learnable element-wise scaler."""

    def __init__(self, embed_dims, init_value=0., requires_grad=True):
        super(ElementScale, self).__init__()
        self.scale = nn.Parameter(
            init_value * torch.ones((1, embed_dims, 1, 1)),
            requires_grad=requires_grad
        )

    def forward(self, x):
        return x * self.scale

class ChannelAggregationFFN(nn.Module):
    """An implementation of FFN with Channel Aggregation.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`. The hidden dimension of the FFN is
            fixed to ``embed_dims * 4``.
        kernel_size (int): The kernel size of the depth-wise
            convolution. Defaults to 3.
        act_type (str): The type of activation. Defaults to 'GELU'.
        ffn_drop (float, optional): Probability of an element to be
            zeroed in the FFN. Defaults to 0.0.
    """
    def __init__(self,
                 embed_dims,
                 kernel_size=3,
                 act_type='GELU',
                 ffn_drop=0.):
        super(ChannelAggregationFFN, self).__init__()
        self.embed_dims = embed_dims
        self.feedforward_channels = int(embed_dims * 4)

        # point-wise expansion: C -> 4C
        self.fc1 = nn.Conv2d(
            in_channels=embed_dims,
            out_channels=self.feedforward_channels,
            kernel_size=1)
        # depth-wise convolution over the expanded features
        self.dwconv = nn.Conv2d(
            in_channels=self.feedforward_channels,
            out_channels=self.feedforward_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            bias=True,
            groups=self.feedforward_channels)
        self.act = build_act_layer(act_type)
        # point-wise projection back: 4C -> C
        self.fc2 = nn.Conv2d(
            in_channels=self.feedforward_channels,
            out_channels=embed_dims,
            kernel_size=1)
        self.drop = nn.Dropout(ffn_drop)

        # channel aggregation: summarize the 4C channels into one
        self.decompose = nn.Conv2d(
            in_channels=self.feedforward_channels,  # C -> 1
            out_channels=1,
            kernel_size=1)
        self.sigma = ElementScale(
            self.feedforward_channels, init_value=1e-5, requires_grad=True)
        self.decompose_act = build_act_layer(act_type)

    def feat_decompose(self, x):
        # decompose: [B, C, H, W] -> [B, 1, H, W]
        t = self.decompose(x)      # summarize all channels into a single channel
        t = self.decompose_act(t)  # apply the activation (GELU by default) for extra non-linearity
        # subtract the aggregated map from x: if t captures the global/dominant
        # signal, x - t can be read as the local (per-channel) residual
        t = x - t
        t = self.sigma(t)          # re-weight each channel with a learnable scale
        x = x + t                  # add the re-weighted residual back to the input
        return x

    def forward(self, x):
        # proj 1
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        # proj 2
        x = self.feat_decompose(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.randn(1, 64, 32, 32, device=device)  # input: B, C, H, W
    block = ChannelAggregationFFN(embed_dims=64).to(device)
    output = block(x)
    print(x.size())       # torch.Size([1, 64, 32, 32])
    print(output.size())  # torch.Size([1, 64, 32, 32])
```
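
For reference, `feat_decompose` computes `y = x + sigma * (x - act(decompose(x)))`: each channel is compared against a learned single-channel summary, and the per-channel difference is re-injected with a learnable weight. Since `sigma` is initialised at `1e-5`, the branch starts close to the identity and learns how much channel-wise detail to re-inject. The snippet below is a minimal sketch, not part of the original code, that checks this identity numerically; the shapes assume the default `embed_dims=64`, so the hidden width where `feat_decompose` runs is `4 * 64 = 256`.

```python
# Minimal sketch (assumes ChannelAggregationFFN from above is in scope):
# verify that feat_decompose equals x + sigma * (x - act(decompose(x))).
import torch

block = ChannelAggregationFFN(embed_dims=64).eval()
h = torch.randn(1, 256, 32, 32)  # hidden features: B, 4*embed_dims, H, W

with torch.no_grad():
    summary = block.decompose_act(block.decompose(h))  # [1, 1, 32, 32]
    manual = h + block.sigma.scale * (h - summary)     # broadcast over channels
    assert torch.allclose(manual, block.feat_decompose(h))
```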