
[PyTorch Implementation] CBAM: Convolutional Block Attention Module (Explanation + Code)

복만 2022. 4. 22. 15:40

CBAM: Convolutional Block Attention Module is a channel & spatial attention module presented at ECCV 2018. The code below is a slightly modified version of the official GitHub implementation.

 

 

Paper: https://arxiv.org/pdf/1807.06521.pdf

Code: https://github.com/Jongchan/attention-module/

Author's blog: https://blog.lunit.io/2018/08/30/bam-and-cbam-self-attention-modules-for-cnn/

 

 

 

BAM

 

CBAM์€ BAM: Bottleneck Attention Module์˜ ํ›„์† ๋…ผ๋ฌธ์ด๋‹ค. BAM์€ ๋ชจ๋ธ์˜ bottleneck ๋ถ€๋ถ„์—์„œ attention์„ ์ง„ํ–‰ํ•˜๋Š” ๋ฐฉ์‹์œผ๋กœ, channel attention๊ณผ spatial attention์„ ๋ณ‘๋ ฌ์ ์œผ๋กœ ๊ณ„์‚ฐํ•œ๋‹ค. 

 

 

BAM computes channel attention and spatial attention separately, adds the two outputs, and applies a sigmoid to produce an attention map of the same size as the input. This can be viewed as decomposing a 3D attention map along two separate axes.
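
In the notation of the BAM paper, the two branches are combined as

	M(F) = σ(M_c(F) + M_s(F)),   F' = F + F ⊗ M(F)

where σ is the sigmoid, ⊗ is element-wise multiplication, and the attention map M(F) refines the input feature F in a residual fashion.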

 

 

 

CBAM

While BAM computes channel and spatial attention separately and merges them into the final attention map, CBAM applies them sequentially: channel attention first, followed by spatial attention.

Given an input feature, the channel attention module and the spatial attention module compute complementary attention, focusing on 'what' and 'where' respectively.
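
In equation form, with M_c and M_s denoting the channel and spatial attention maps, the sequential refinement is

	F' = M_c(F) ⊗ F
	F'' = M_s(F') ⊗ F'

where ⊗ is broadcast element-wise multiplication and F'' is the final refined feature.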

 

 

import torch
import torch.nn as nn


class CBAM(nn.Module):
	def __init__(self, gate_channels, reduction_ratio=16, channel_attention=True, spatial_attention=True):
		super(CBAM, self).__init__()
		self.channel_attention, self.spatial_attention = channel_attention, spatial_attention
		if channel_attention:
			self.ChannelGate = ChannelGate(gate_channels, reduction_ratio)
		if spatial_attention:
			self.SpatialGate = SpatialGate()

	def forward(self, x):
		# sequential order: channel attention first, then spatial attention
		if self.channel_attention:
			x = self.ChannelGate(x)
		if self.spatial_attention:
			x = self.SpatialGate(x)
		return x
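
In the paper's ResNet experiments, CBAM is placed inside each residual block and applied to the output of the convolutional branch before the shortcut addition. The block below is a minimal sketch of that placement, assuming the CBAM module defined above; the BasicBlockCBAM name and the exact layer hyperparameters are illustrative, not part of the original code.

class BasicBlockCBAM(nn.Module):
	def __init__(self, in_channels, out_channels, stride=1, reduction_ratio=16):
		super(BasicBlockCBAM, self).__init__()
		self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1, bias=False)
		self.bn1 = nn.BatchNorm2d(out_channels)
		self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False)
		self.bn2 = nn.BatchNorm2d(out_channels)
		self.relu = nn.ReLU(inplace=True)
		self.cbam = CBAM(out_channels, reduction_ratio)
		self.shortcut = nn.Sequential()
		if stride != 1 or in_channels != out_channels:
			self.shortcut = nn.Sequential(
				nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
				nn.BatchNorm2d(out_channels))

	def forward(self, x):
		out = self.relu(self.bn1(self.conv1(x)))
		out = self.bn2(self.conv2(out))
		out = self.cbam(out)          # refine the residual branch with CBAM
		out = out + self.shortcut(x)  # shortcut connection
		return self.relu(out)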

 

 

Channel attention

BAM uses only average pooling, whereas CBAM combines average pooling and max pooling. Because the two pooled features share the same semantics, a single shared MLP can process both of them (fewer parameters).

 

The two resulting attention vectors are added, and a sigmoid is applied to produce the channel attention map.
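
In the paper's notation, the channel attention map is

	M_c(F) = σ( MLP(AvgPool(F)) + MLP(MaxPool(F)) )

where the same MLP weights are used for both pooled descriptors.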

class Flatten(nn.Module):
	def forward(self, x):
		return x.view(x.size(0), -1)

class ChannelGate(nn.Module):
	def __init__(self, gate_channels, reduction_ratio=16):
		super(ChannelGate, self).__init__()
		self.gate_channels = gate_channels
		# shared MLP applied to both the avg-pooled and max-pooled descriptors
		self.mlp = nn.Sequential(
			Flatten(),
			nn.Linear(gate_channels, gate_channels // reduction_ratio),
			nn.ReLU(),
			nn.Linear(gate_channels // reduction_ratio, gate_channels)
			)
		self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
		self.maxpool = nn.AdaptiveMaxPool2d((1, 1))

	def forward(self, x):
		x_avg_pool = self.mlp(self.avgpool(x))  # (B, C)
		x_max_pool = self.mlp(self.maxpool(x))  # (B, C)
		# sum the two descriptors, then apply sigmoid to get the channel attention map
		attention = torch.sigmoid(x_avg_pool + x_max_pool)
		attention = attention.unsqueeze(2).unsqueeze(3).expand_as(x)  # (B, C, H, W)
		return x * attention

 

Spatial attention

Spatial attention, like channel attention, applies max pooling and average pooling, this time along the channel axis, producing two 1xHxW feature maps. These are concatenated, and a 7x7 convolution (followed by a sigmoid) produces the spatial attention map.
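
In the paper's notation, the spatial attention map is

	M_s(F) = σ( f^{7x7}([AvgPool(F); MaxPool(F)]) )

where f^{7x7} is a 7x7 convolution and [· ; ·] denotes concatenation along the channel axis.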

class SpatialGate(nn.Module):
	def __init__(self):
		super(SpatialGate, self).__init__()
		# 7x7 convolution over the 2-channel (avg, max) map, followed by sigmoid
		self.conv = nn.Sequential(
			nn.Conv2d(2, 1, 7, padding=3),
			nn.BatchNorm2d(1),
			nn.Sigmoid()
			)

	def forward(self, x):
		x_avg_pool = torch.mean(x, 1).unsqueeze(1)    # (B, 1, H, W): average over channels
		x_max_pool = torch.max(x, 1)[0].unsqueeze(1)  # (B, 1, H, W): max over channels
		attention = torch.cat((x_avg_pool, x_max_pool), dim=1)  # (B, 2, H, W)
		attention = self.conv(attention)              # (B, 1, H, W) spatial attention map
		return x * attention
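
A quick sanity check (illustrative, not from the original post): both gates return a tensor of the same shape as their input, so the full CBAM block can be dropped into an existing network without changing feature-map sizes.

if __name__ == "__main__":
	x = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)
	cbam = CBAM(gate_channels=64, reduction_ratio=16)
	out = cbam(x)
	print(out.shape)                # torch.Size([2, 64, 32, 32])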

 

 

Results

The experiments reported in the paper (see the paper for the full result tables):

 

Classification (ImageNet-1K):

 

Detection (MS COCO and VOC 2007):

 

Ablation studies

Pooling methods of channel attention:

 

Pooling methods of spatial attention:

 

How to combine both channel and spatial attention modules:

 

- Since the two modules perform different roles, the order in which they are combined can affect overall performance; the ablation shows the sequential channel-first arrangement works best.
