EMCAD: Efficient Multi-scale Convolutional Attention Decoding for Medical Image Segmentation - LGAG

iceeci發表於2024-11-09

論文
程式碼
`
import torch
import torch.nn as nn
from functools import partial
from torch.nn.init import trunc_normal_
import math
from timm.models.helpers import named_apply

def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1):
    """Build the activation module selected by ``act``.

    Args:
        act: Activation name, matched case-insensitively. One of ``'relu'``,
            ``'relu6'``, ``'leakyrelu'``, ``'prelu'``, ``'gelu'``, ``'hswish'``.
        inplace: Forwarded to the activations that accept an in-place flag
            (ReLU, ReLU6, LeakyReLU, Hardswish); ignored by PReLU and GELU.
        neg_slope: Negative slope for LeakyReLU, and the initial slope value
            for PReLU.
        n_prelu: Number of learnable slopes for PReLU.

    Returns:
        A freshly constructed ``nn.Module``.

    Raises:
        NotImplementedError: If ``act`` is not one of the supported names.
    """
    act = act.lower()

    # Factories rather than instances, so only the requested layer is built.
    factories = {
        'relu': lambda: nn.ReLU(inplace),
        'relu6': lambda: nn.ReLU6(inplace),
        'leakyrelu': lambda: nn.LeakyReLU(neg_slope, inplace),
        'prelu': lambda: nn.PReLU(num_parameters=n_prelu, init=neg_slope),
        'gelu': lambda: nn.GELU(),
        'hswish': lambda: nn.Hardswish(inplace),
    }

    factory = factories.get(act)
    if factory is None:
        raise NotImplementedError('activation layer [%s] is not found' % act)
    return factory()

def init_weights(module, name, scheme=''):
    """Initialise the parameters of a single module in place.

    Convolutions (2D/3D) get their weights drawn according to ``scheme`` and
    their bias zeroed; BatchNorm (2D/3D) and LayerNorm layers are reset to
    weight 1 and bias 0. Any other module type is left untouched.

    Args:
        module: The module to initialise.
        name: Name of the module. Unused here; it is part of the signature
            because this function is applied through ``named_apply`` (see
            ``LGAG.init_weights``), which presumably passes both arguments.
        scheme: ``'normal'``, ``'trunc_normal'``, ``'xavier_normal'`` or
            ``'kaiming_normal'``. Any other value selects the
            EfficientNet-like normal initialisation with
            ``std = sqrt(2 / fan_out)``.
    """
    if isinstance(module, (nn.Conv2d, nn.Conv3d)):
        if scheme == 'normal':
            nn.init.normal_(module.weight, std=.02)
        elif scheme == 'trunc_normal':
            trunc_normal_(module.weight, std=.02)
        elif scheme == 'xavier_normal':
            nn.init.xavier_normal_(module.weight)
        elif scheme == 'kaiming_normal':
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        else:
            # efficientnet like
            # Multiply over every kernel dimension so Conv3d counts its depth
            # as well; for Conv2d this is kernel_size[0] * kernel_size[1].
            fan_out = math.prod(module.kernel_size) * module.out_channels
            fan_out //= module.groups
            nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))
        # Every scheme zeroes the bias when the convolution has one.
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm)):
        # Norm layers built without affine parameters hold None here.
        if module.weight is not None:
            nn.init.constant_(module.weight, 1)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)


# LGAG.init_weights refers to this function by its private name.
_init_weights = init_weights

# Large-kernel grouped attention gate (LGAG)

class LGAG(nn.Module):
    """Large-kernel grouped attention gate (LGAG).

    Combines feature maps with attention coefficients so that highly relevant
    features are activated: a gating signal taken from higher-level features
    controls the flow of information between different stages of the network.
    The LGAG mechanism fuses the information coming from skip connections and
    captures salient features in a larger local context with less computation.
    Use it wherever two tensors of the same shape need to be fused.

    Args:
        F_g: Number of channels of the gating input ``g``.
        F_l: Number of channels of the skip-connection input ``x``.
        F_int: Number of channels both inputs are projected to before they
            are summed.
        kernel_size: Kernel size of the two projection convolutions. Padding
            is ``kernel_size // 2``, so an odd value keeps the spatial size.
        groups: Number of groups of the two projection convolutions; forced
            to 1 when ``kernel_size`` is 1.
        activation: Name understood by ``act_layer``; applied to the sum of
            the two projections.
    """

    def __init__(self, F_g, F_l, F_int, kernel_size=3, groups=1, activation='relu'):
        super().__init__()

        # A 1x1 projection is always built ungrouped, whatever was requested.
        if kernel_size == 1:
            groups = 1
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=kernel_size, stride=1, padding=kernel_size // 2, groups=groups,
                      bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=kernel_size, stride=1, padding=kernel_size // 2, groups=groups,
                      bias=True),
            nn.BatchNorm2d(F_int)
        )
        # Collapses the fused features to one attention coefficient per pixel,
        # squashed into (0, 1) by the sigmoid.
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.activation = act_layer(activation, inplace=True)

        self.init_weights('normal')

    def init_weights(self, scheme=''):
        """Apply the module-level ``_init_weights`` to this gate's submodules."""
        named_apply(partial(_init_weights, scheme=scheme), self)

    def forward(self, g, x):
        """Gate ``x`` with attention coefficients computed from ``g`` and ``x``.

        Args:
            g: Gating tensor with ``F_g`` channels.
            x: Skip-connection tensor with ``F_l`` channels; its spatial size
                must match ``g`` so the two projections can be summed.

        Returns:
            ``x`` multiplied by the single-channel attention map, which is
            broadcast over the channels of ``x``.
        """
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.activation(g1 + x1)
        psi = self.psi(psi)

        return x * psi

if __name__ == '__main__':
    # Smoke test: push two random tensors through the gate and print shapes.
    in_dim = 128
    width = 4
    hidden_dim = in_dim // width

    # Use the GPU when there is one, otherwise run the demo on the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    block = LGAG(in_dim, in_dim, hidden_dim).to(device)
    g = torch.randn(3, 128, 64, 64).to(device)  # gating input, B C H W
    x = torch.randn(3, 128, 64, 64).to(device)  # skip-connection input, B C H W
    output = block(g, x)

    # `input` is the Python builtin here, so report the gated tensor's shape.
    print(x.size())
    print(output.size())

`

相關文章