CBAM: Convolutional Block Attention Module is a channel & spatial attention module presented at ECCV 2018. The code below is adapted, with minor modifications, from the official GitHub repository.
Paper: https://arxiv.org/pdf/1807.06521.pdf
Code: https://github.com/Jongchan/attention-module/
Author's blog: https://blog.lunit.io/2018/08/30/bam-and-cbam-self-attention-modules-for-cnn/
BAM
CBAM is a follow-up to the authors' earlier paper, BAM: Bottleneck Attention Module. BAM applies attention at the bottlenecks of a network, computing channel attention and spatial attention in parallel.
In BAM, the two attentions are computed separately, and their outputs are summed and passed through a sigmoid to produce an attention map with the same size as the input. This can be viewed as decomposing the 3D attention map along two separate axes.
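As a rough illustration, here is a minimal sketch of this parallel combination. The two branches are hypothetical simplifications, not BAM's exact blocks (BAM's spatial branch uses dilated 3x3 convolutions); the point is the broadcast-sum of a channel map and a spatial map followed by a single sigmoid, and the residual application F' = F * (1 + M(F)) described in the BAM paper.

import torch
import torch.nn as nn

class ParallelAttentionSketch(nn.Module):
    """Minimal sketch of BAM-style parallel attention combination.
    channel_att and spatial_att are simplified stand-ins for BAM's branches."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        # Channel branch: global average pool -> bottleneck -> (B, C, 1, 1)
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // reduction, 1),
            nn.ReLU(),
            nn.Conv2d(channels // reduction, channels, 1),
        )
        # Spatial branch: reduce channels, then collapse to a (B, 1, H, W) map
        # (BAM actually uses dilated 3x3 convolutions here)
        self.spatial_att = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, 1),
            nn.ReLU(),
            nn.Conv2d(channels // reduction, 1, 1),
        )
    def forward(self, x):
        # Broadcast-sum the two branches, then one sigmoid -> 3D attention map
        att = torch.sigmoid(self.channel_att(x) + self.spatial_att(x))
        # BAM applies the attention residually: F' = F * (1 + M(F))
        return x * (1 + att)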
CBAM
Whereas BAM computes channel and spatial attention separately and merges them into one final attention map, CBAM applies channel attention first and then spatial attention, in a sequential manner.
Given an input feature map, the channel attention module and the spatial attention module compute complementary attention, focusing on 'what' and 'where' respectively.
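In the paper's notation, with σ the sigmoid and ⊗ element-wise multiplication, the two sequential steps are:

F' = Mc(F) ⊗ F
F'' = Ms(F') ⊗ F'

The module below implements this two-step scheme.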
import torch
import torch.nn as nn

class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, channel_attention=True, spatial_attention=True):
        super(CBAM, self).__init__()
        self.channel_attention, self.spatial_attention = channel_attention, spatial_attention
        if channel_attention:
            self.ChannelGate = ChannelGate(gate_channels, reduction_ratio)
        if spatial_attention:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        # Sequential ordering: channel attention first, then spatial attention
        if self.channel_attention:
            x = self.ChannelGate(x)
        if self.spatial_attention:
            x = self.SpatialGate(x)
        return x
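Once the ChannelGate and SpatialGate modules defined below are in scope, a quick usage sketch (shapes are illustrative) looks like this:

cbam = CBAM(gate_channels=64, reduction_ratio=16)
feat = torch.randn(2, 64, 32, 32)   # (B, C, H, W)
out = cbam(feat)
print(out.shape)  # torch.Size([2, 64, 32, 32]) — same shape as the input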
Channel attention
While BAM uses only average pooling, CBAM combines average pooling and max pooling. Because the two pooled features carry the same kind of semantic information, a single shared MLP can process both (fewer parameters).
The two MLP outputs are summed and passed through a sigmoid to produce the channel attention map.
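In equation form, as given in the paper:

Mc(F) = σ( MLP(AvgPool(F)) + MLP(MaxPool(F)) )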
class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        # Shared MLP: both pooled features pass through the same bottleneck
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.maxpool = nn.AdaptiveMaxPool2d((1, 1))

    def forward(self, x):
        x_avg_pool = self.mlp(self.avgpool(x))  # (B, C)
        x_max_pool = self.mlp(self.maxpool(x))  # (B, C)
        # Sigmoid is applied after summing the two MLP outputs,
        # matching Mc(F) = σ(MLP(AvgPool(F)) + MLP(MaxPool(F)))
        attention = torch.sigmoid(x_avg_pool + x_max_pool)
        attention = attention.unsqueeze(2).unsqueeze(3).expand_as(x)  # (B, C, H, W)
        return x * attention
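A minimal shape check for the channel gate on its own (values are illustrative):

gate = ChannelGate(gate_channels=64)
feat = torch.randn(2, 64, 32, 32)
out = gate(feat)
print(out.shape)  # torch.Size([2, 64, 32, 32]) — channels reweighted, shape unchanged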
Spatial attention
Spatial attention, like channel attention, uses both max pooling and average pooling, but applies them along the channel axis, producing two 1xHxW feature maps. These are concatenated, and a 7x7 convolution (followed by a sigmoid) produces the spatial attention map.
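In equation form, with f^{7x7} denoting a 7x7 convolution and [ ; ] channel-wise concatenation:

Ms(F) = σ( f^{7x7}([AvgPool(F); MaxPool(F)]) )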
class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        # 2 input channels (avg-pooled + max-pooled maps) -> 1-channel attention map
        self.conv = nn.Sequential(
            nn.Conv2d(2, 1, 7, padding=3),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Pool along the channel axis: each produces a (B, 1, H, W) map
        x_avg_pool = torch.mean(x, dim=1, keepdim=True)
        x_max_pool = torch.max(x, dim=1, keepdim=True)[0]
        attention = torch.cat((x_avg_pool, x_max_pool), dim=1)  # (B, 2, H, W)
        attention = self.conv(attention)                        # (B, 1, H, W)
        return x * attention  # broadcasts over the channel axis
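In the paper, CBAM is placed inside each residual block, applied to the convolution output before the residual addition. A minimal sketch under that assumption (this BasicBlock is simplified relative to torchvision's, with fixed channel count and stride 1):

class BasicBlockWithCBAM(nn.Module):
    """Simplified residual block; CBAM refines the conv output
    before the shortcut is added, as in the paper's ResNet experiments."""
    def __init__(self, channels, reduction_ratio=16):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.cbam = CBAM(channels, reduction_ratio)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.cbam(out)       # attention applied before the residual addition
        return self.relu(out + x)  # shortcut connection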
Results
These are the experimental results reported in the paper.
Classification (ImageNet-1K): result tables in the paper.
Detection (MS COCO and VOC 2007): result tables in the paper.
Ablation studies
Pooling methods of channel attention: using average and max pooling together outperforms either alone.
Pooling methods of spatial attention: pooling along the channel axis with a 7x7 convolution kernel performs best.
How to combine both channel and spatial attention modules:
- Since each module plays a different role, the order in which they are applied affects overall performance; the paper finds that the sequential channel-first arrangement works best.