๐Ÿ Python & library/PyTorch

[PyTorch] nn.Conv์˜ padding๊ณผ padding_mode

๋ณต๋งŒ 2022. 3. 24. 14:16

PyTorch ์—์„œ ์ œ๊ณตํ•˜๋Š” convolution ํ•จ์ˆ˜์— ์„ค์ • ๊ฐ€๋Šฅํ•œ parameter ์ค‘padding๊ณผ padding_mode๋ผ๋Š” ๊ฒƒ์ด ์žˆ๋‹ค.

 

 

padding์˜ ๊ฒฝ์šฐ padding์˜ ํฌ๊ธฐ๋ฅผ ์ง€์ •ํ•  ์ˆ˜ ์žˆ๋Š” parameter์ธ๋ฐ (int ํ˜น์€ tuple), PyTorch 1.9.0๋ถ€ํ„ฐ string์œผ๋กœ ์ง€์ •ํ•  ์ˆ˜ ์žˆ๋Š” ์˜ต์…˜์ด ์ถ”๊ฐ€๋˜์—ˆ๋‹ค.

์ด๋Š” Tensorflow์—์„œ๋Š” ์›๋ž˜ ์žˆ๋˜ ์˜ต์…˜์ธ๋ฐ, padding์˜ ํฌ๊ธฐ๋ฅผ ์ง์ ‘ ์ง€์ •ํ•˜๋Š” ๋Œ€์‹  same ํ˜น์€ valid ์˜ต์…˜์„ ์ฃผ๋ฉด input size์— ๋งž๊ฒŒ ์ž๋™์œผ๋กœ padding ํฌ๊ธฐ๊ฐ€ ์„ค์ •๋œ๋‹ค.

  • valid๋Š” padding์„ ๋”ฐ๋กœ ์ฃผ์ง€ ์•Š๊ณ  input image ์ž์ฒด๋งŒ์„ ์ด์šฉํ•ด convolution ์—ฐ์‚ฐ์„ ์ˆ˜ํ–‰ํ•œ๋‹ค.
  • same์€ output size๊ฐ€ input size์™€ ๋™์ผํ•˜๊ฒŒ ๋˜๋„๋ก padding์„ ์กฐ์ ˆํ•œ๋‹ค. ๋งŒ์•ฝ stride=1, dilation=1์ธ ๊ฒฝ์šฐ padding=(kernel_size-1)/2๋กœ ์„ค์ •๋œ๋‹ค.

 


 

padding_mode๋Š” padding์„ ๋ญ˜๋กœ ์ฑ„์šธ ์ง€ ์„ค์ •ํ•  ์ˆ˜ ์žˆ๋‹ค. zeros, reflect, replicate, circular์ด ์žˆ์œผ๋ฉฐ, ๊ธฐ๋ณธ ๊ฐ’์€ zeros๋กœ, padding์„ ๋ชจ๋‘ 0์œผ๋กœ ์ฑ„์šด๋‹ค. ๋Œ€๋ถ€๋ถ„ zero-filling์„ ์‚ฌ์šฉํ•˜์ง€๋งŒ ๋‹ค์–‘ํ•œ ์„ ํƒ์ง€๊ฐ€ ์žˆ์–ด ์†Œ๊ฐœํ•ด ๋ณด๊ณ ์ž ํ•œ๋‹ค.

 

* ์˜ˆ์‹œ๋ฅผ ์œ„ํ•ด [1, 2, 3, 4, 5] ๋ชจ์–‘์˜ 1D tensor์™€ padding์„ 4๋กœ ์„ค์ •ํ•œ 1D identity Conv ์—ฐ์‚ฐ์„ ์ด์šฉํ–ˆ๋‹ค.

 

  • zeros: zero-filling์„ ์ด์šฉํ•œ๋‹ค.
x = torch.tensor([[[1, 2, 3, 4, 5]]]).float()

conv = nn.Conv1d(1, 1, 3, padding=4, padding_mode='zeros', bias=False)
conv.weight = torch.nn.Parameter(torch.tensor([[[0., 1., 0.]]]))

y = conv(x)
print(y)
>> tensor([[[0., 0., 0., 1., 2., 3., 4., 5., 0., 0., 0.]]])

 

  • reflect: ์–‘ ๋์— ๊ฑฐ์šธ์ฒ˜๋Ÿผ ๋ฐ˜์‚ฌ๋œ ๊ฐ’์„ ์‚ฌ์šฉํ•œ๋‹ค. ๋‹จ, ์ด ๊ฒฝ์šฐ input ํฌ๊ธฐ๋ณด๋‹ค ๋” ํฐ ๊ฐ’์˜ padding์„ ์‚ฌ์šฉํ•  ์ˆ˜ ์—†๋‹ค.
x = torch.tensor([[[1, 2, 3, 4, 5]]]).float()

conv = nn.Conv1d(1, 1, 3, padding=4, padding_mode='reflect', bias=False)
conv.weight = torch.nn.Parameter(torch.tensor([[[0., 1., 0.]]]))

y = conv(x)
print(y)
>> tensor([[[4., 3., 2., 1., 2., 3., 4., 5., 4., 3., 2.]]])

 

  • replicate: ์–‘ ๋๋‹จ์˜ ๊ฐ’์„ padding ๊ฐ’์œผ๋กœ ์ด์šฉํ•œ๋‹ค.
import torch
import torch.nn as nn

x = torch.tensor([[[1, 2, 3, 4, 5]]]).float()

conv = nn.Conv1d(1, 1, 3, padding=4, padding_mode='replicate', bias=False)
conv.weight = torch.nn.Parameter(torch.tensor([[[0., 1., 0.]]]))

y = conv(x)
print(y)
>> tensor([[[1., 1., 1., 1., 2., 3., 4., 5., 5., 5., 5.]]])

 

  • circular: input ๊ฐ’์„ ์ˆœํ™˜ํ•˜์—ฌ ์‚ฌ์šฉํ•œ๋‹ค.
import torch
import torch.nn as nn

x = torch.tensor([[[1, 2, 3, 4, 5]]]).float()

conv = nn.Conv1d(1, 1, 3, padding=4, padding_mode='circular', bias=False)
conv.weight = torch.nn.Parameter(torch.tensor([[[0., 1., 0.]]]))

y = conv(x)
print(y)
>> tensor([[[3., 4., 5., 1., 2., 3., 4., 5., 1., 2., 3.]]])

 

๋ฐ˜์‘ํ˜•