[๋ฅ๋ฌ๋ ๋ ผ๋ฌธ๋ฆฌ๋ทฐ] AdamP: Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights (Naver AI Lab, ICLR 2021)
Adam์ ๋ฌธ์ ์ ์ ๊ทน๋ณตํ๋ ์๋ก์ด optimizer์ธ AdamP๋ฅผ ์ ์ํ๋ ๋ ผ๋ฌธ์ผ๋ก, Naver AI Lab & Naver Clova์์ ICLR 2021์ ๋ฐํํ์๋ค.
Paper: https://arxiv.org/pdf/2006.08217.pdf
Project page: https://clovaai.github.io/AdamP/
Code: https://github.com/clovaai/adamp
1. Adam์ ๋ฌธ์ ์
์์ฝ: Adam์ ๋น๋กฏํ momentum-based gradient descent optimzer๋ค์ ํ์ต ๋์ค weight norm์ ํฌ๊ฒ ์ฆ๊ฐ์ํจ๋ค.
๊ทธ ์์ธ์ ๋ค์๊ณผ ๊ฐ๋ค.
๋๋ถ๋ถ์ ๋ชจ๋ธ๋ค์์ Batch normalization ๋ฑ์ normalization ๊ธฐ๋ฒ๋ค์ ์ฌ์ฉํด weight๋ฅผ scale-invariantํ๊ฒ ๋ง๋ ๋ค. ๋ฐ๋ผ์ weight๋ค์ ํฌ๊ธฐ๋ ๋ชจ๋ธ์ ์ํฅ์ ๋ฏธ์น์ง ์๊ฒ ๋๋ค.
๋ฐ๋ผ์ ๋ชจ๋ธ์ ์ํฅ์ ๋ฏธ์น๋ ๊ฐ์, weight๋ค์ l2-norm์ผ๋ก ๋๋ ๊ฐ์ด๋ค. ์ด๋ค์ effective weight $\hat w=\frac{w}{||w||_2}$์ด๋ผ๊ณ ํ์.
๊ทธ๋ฌ๋ ์ค์ optimization์ด ์ผ์ด๋๋ ๊ณต๊ฐ์ effective weight์ด ์๋ ๊ณต๊ฐ์ด ์๋๋ผ, ์๋์ weight๊ฐ ๋์ฌ์๋ nominal space์ด๋ค.
์ด๋ฌํ ์ด์ ๋ก effective step size์ ์ค์ nominal step size ๊ฐ์ ์ฐจ์ด๊ฐ ๋ฐ์ํ๋ค. ์๋ ๊ทธ๋ฆผ์์ $w_t$๊ฐ $w_{t+1}$์ผ๋ก ์ ๋ฐ์ดํธ ๋๋ฉด, ์ค์ ๋ชจ๋ธ์ ์ํฅ์ ๋ฏธ์น๋ effective step์ ์ฃผํฉ์์ผ๋ก ํ์๋ ๋ถ๋ถ๊ณผ ๊ฐ๋ค.
Nominal step size์ effective step size๋ ์ฝ $\frac{1}{||w_{t+1}||_2}$๋งํผ ์ฐจ์ด๊ฐ ๋๊ฒ ๋๋ค.
์ผ๋ฐ์ ์ธ gradient descent (GD) ์๊ณ ๋ฆฌ์ฆ์ ์ฌ์ฉํ๋ฉด ํ์ต ๋์ค weight norm์ด ์ฆ๊ฐํ๋ ํ์์ด ๋ฐ์ํ๋๋ฐ, (Lemma 2.1) ๋ฐ๋ผ์ effective step size $\Delta \hat w_t$๊ฐ ์์ฐ์ค๋ฝ๊ฒ ์ค์ด๋ค๊ฒ ๋๋ฉฐ ์์ ์ ์ธ ํ์ต์ ๋๋๋ค.
๊ทธ๋ฌ๋ Momentum์ด ์ถ๊ฐ๋ Adam๊ณผ ๊ฐ์ optimizer์ ๊ฒฝ์ฐ, weight norm์ด ๋์ฑ ๋น ๋ฅด๊ฒ ์ฆ๊ฐํ๋ค (Lemma 2.2).
๊ทธ ๊ฒฐ๊ณผ, ์๋ ๊ทธ๋ฆผ๊ณผ ๊ฐ์ด, GD๋ ์๋ ด์๋๊ฐ ๋๋ฆฌ์ง๋ง weight norm์ด ํฌ๊ฒ ์ฆ๊ฐํ์ง ์๋ ํํธ,
GD + momentum์ ์ ๋ฐ์ดํธ ์๋๊ฐ ๋น ๋ฅด์ง๋ง weight norm์ด ์ญ์ ๋งค์ฐ ํฌ๊ฒ ์ฆ๊ฐํ์ฌ effective step size๋ ๊ทธ๋ฆฌ ํฌ์ง ์์ ๊ฒ์ ํ์ธํ ์ ์๋ค.
์ด๋ก ์ธํด effective convergence์ ์๋๊ฐ ๋งค์ฐ ๋๋ ค์ง๊ฑฐ๋, sub-optimalํ ํด๋ต์ ์ฐพ์ ์ ์๋ค๋ ๋ฌธ์ ๊ฐ ๋ฐ์ํ๊ฒ ๋๋ค.
2. Method (SGDP, AdamP)
์์์ Adam๊ณผ ๊ฐ์ momentum ๊ธฐ๋ฐ optimizer์ ๋ฌธ์ ๋ฅผ ๋ฐํ๋๋ค. ํ์ต๊ณผ์ ์์ weight norm์ ๋งค์ฐ ํฌ๊ฒ ์ฆ๊ฐ์ํค๋ ์ฑํฅ์ด ์๊ธฐ ๋๋ฌธ์, effective convergence์ ์๋๊ฐ ๋งค์ฐ ๋๋ ค์ง๊ณ , sub-optimalํ ํด๋ต์ ์ฐพ๊ฒ ๋๋ค๋ ๊ฒ์ด๋ค.
๋ณธ ๋ ผ๋ฌธ์์๋ ์ด๋ฅผ projection์ ์ด์ฉํ์ฌ, weight norm์ ์ฆ๊ฐ๋ฅผ ๋ง์ ์ ์๋ optimization ๋ฐฉ๋ฒ์ ์๊ฐํ๋ค. ์ด ๋ฐฉ๋ฒ์ effective space์์์ update direction์ ๊ฑด๋๋ฆฌ์ง ์๊ณ , effective step size๋ง์ ๋ณ๊ฒฝํ ์ ์๋ ๋ฐฉ๋ฒ์ด๋ค.
๊ฐ๋จํ ๋งํ๋ฉด, ๋งค update์์ weight์ parallelํ radial component๋ฅผ ์ ๊ฑฐํ๋ ์์ด๋์ด์ด๋ค.
Update vector์ ๊ธฐ์กด์ ๋ฐฉ๋ฒ์ผ๋ก ๊ณ์ฐํ ๋ค ($p_t$), weight vector์ parallelํ ์ฑ๋ถ์ projection์ผ๋ก ์ ๊ฑฐํ๋ค ($\Pi_{w_t}(p_t)$).
Weight vector์ parallelํ ์ฑ๋ถ์ loss minimization์๋ ๊ธฐ์ฌํ์ง ์๊ณ , weight norm์ ์ฆ๊ฐ์ํค๋ ๋ฐ์๋ง ๊ธฐ์ฌํ๊ธฐ ๋๋ฌธ์ด๋ค.
์ด ๋ ํ์ดํผํ๋ผ๋ฏธํฐ์ธ $\delta$๋ scale-invariantํ weight๋ฅผ detectํ๋ ์ญํ ์ ํ๋ค. ๋ ผ๋ฌธ์์๋ $\delta=0.1$๋ก ์ค์ ํ๋ค๊ณ ํ๋ค.
3. Experiments
Image classificaiton, object detection, audio classification, language modeling ๋ฑ ๋ค์ํ ๋๋ฉ์ธ๊ณผ ๋ชจ๋ธ์ ๋ํด ์คํ์ ์งํํ๋ค.
์คํ 1. Image classification - SGD, Adam๊ณผ ๋น๊ตํ์ ๋ SGDP, AdamP๊ฐ ํญ์ ๋ ์ข์ ์ฑ๋ฅ์ ๋ณด์๋ค.
์คํ 2. Object detection - Adam๋ณด๋ค AdamP๊ฐ ํญ์ ๋ ์ข์ ์ฑ๋ฅ์ ๋ณด์๋ค.
์คํ 3. Adversarial learning - model robustness๋ฅผ ํ๊ฐํ๊ธฐ ์ํด ์คํ์ ์งํํ๋ค. Adam๋ณด๋ค AdamP๊ฐ ๋ ๋์ robustness๋ฅผ ๊ฐ์ง ๊ฒ์ ํ์ธํ ์ ์๋ค.