NVIDIA์ Baidu์์ ์ฐ๊ตฌํ๊ณ ICLR 2018์ ๋ฐํ๋ ๋ ผ๋ฌธ์ธ
Mixed Precision Training์ ๋ฐํ์ผ๋ก ์ ๋ฆฌํ ๊ธ์ ๋๋ค.
๋ฅ๋ฌ๋ ํ์ต ๊ณผ์ ์์ Mixed Precision์ ์ด์ฉํ์ฌ GPU resource๋ฅผ ํจ์จ์ ์ผ๋ก ์ฌ์ฉํ ์ ์๋ ๋ฐฉ๋ฒ์ ๋๋ค.
(NVIDIA ๋ธ๋ก๊ทธ ์ ๋ฆฌ๊ธ: developer.nvidia.com/blog/mixed-precision-training-deep-neural-networks/)
Floating Point Format
์ค์๋ฅผ ์ปดํจํฐ๋ก ๋ํ๋ด๋ ๋ฐฉ๋ฒ์ ๊ณ ์ ์์์ (Fixed Pint) ๋ฐฉ์๊ณผ ๋ถ๋์์์ (Floating Point) ๋ฐฉ์์ด ์กด์ฌํฉ๋๋ค.
(๋ถ๋์์์ ๋ฐฉ์์ ๋ ๋์ด ์์์ ๋ฐฉ์์ด๋ผ๊ณ ๋ ํ๋ค๊ณ ํฉ๋๋ค. ๊ท์ฝ๋ค์)
๊ณ ์ ์์์ ๋ฐฉ์์ ์ ์๋ถ์ ์์๋ถ๋ฅผ ๋ด์ ๋นํธ์ ์๋ฅผ ๊ณ ์ ํด์ ์ฌ์ฉํ๋ ๋ฐฉ์.
์ ํํ๊ณ ์ฐ์ฐ์ด ๋น ๋ฅด์ง๋ง ํํ ๊ฐ๋ฅํ ๋ฒ์๊ฐ ๋ถ๋์์์ ๋ฐฉ์์ ๋นํด ์ข์ต๋๋ค.
๋ถ๋์์์ ๋ฐฉ์์ ํํํ๊ณ ์ ํ๋ ์๋ฅผ ์ ๊ทํํ์ฌ ๊ฐ์๋ถ(exponent)์ ์ง์๋ถ(fraction/mantissa)๋ฅผ ๋ฐ๋ก ์ ์ฅํ๋ ๋ฐฉ์์ ๋๋ค.
์๋ฅผ ๋ค์๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
5.6875๋ฅผ 2์ง๋ฒ์ผ๋ก ๋ํ๋ด๋ฉด 101.1011
์ด๋ฅผ 1.011011 * 2^2๋ก ๋ํ๋ด๋ ๊ฒ์ ์ ๊ทํ๋ผ๊ณ ํฉ๋๋ค. (์ ์๋ถ์ ํ ์๋ฆฌ๋ง ๋จ๊ธฐ๋ ๊ฒ)
์ด ๋ 1.011011์ ๊ฐ์๋ถ, 2์ ์ง์์ธ 2๋ฅผ ์ ์๋ถ๋ผ๊ณ ํ๊ณ ,
๋ถ๋์์์ ๋ฐฉ์์ ์ด ๊ฐ์๋ถ์ ์ ์๋ถ๋ฅผ ๊ฐ๊ฐ ์ ์ฅํ๋ ๋ฐฉ์์
๋๋ค.
๋ถ๋์์์ ๋ฐฉ์์ IEEE754 ํ์ค์ด ๊ฐ์ฅ ๋๋ฆฌ ์ฐ์ด๊ณ ์๋๋ฐ, ๊ทธ ์ข ๋ฅ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
- FP32 (Single Precision, ๋จ์ ๋ฐ๋)
- FP64 (Double Precision)
- FP128 (Quadruple Precision)
- FP16 (Half Precision)
FP ๋ค์ ์ซ์๋ ๋ช bit๋ฅผ ์ด์ฉํ๋์ง๋ฅผ ์๋ฏธํฉ๋๋ค. FP32๋ 32bit๋ฅผ ์ด์ฉํ์ฌ ์ค์๋ฅผ ์ ์ฅํ๋ ๊ฒ.
๋น์ฐํ ์ด์ฉํ๋ ๋นํธ ์๊ฐ ๋ง์์๋ก ๋ ๋์ ์ ๋ฐ๋(Precision)๋ก ์ค์๋ฅผ ์ ์ฅํ ์ ์์ต๋๋ค.
(FP64๋ FP32์ ๋ ๋ฐฐ์ ์ ๋ฐ๋๋ฅผ ๊ฐ๋๋ค๋ ๋ป์์ Double Precision์ด๋ผ๊ณ ๋ถ๋ฅด๋ ๊ฒ์ด๊ฒ์ฃ ?)
ํ๋ ๋ฅ๋ฌ๋ ํ์ต ๊ณผ์ ์์๋ Single Precision(FP32) ํฌ๋งท์ ์ฌ์ฉํฉ๋๋ค.
(weight ์ ์ฅ, gradient ๊ณ์ฐ ๋ฑ)
ํ์ง๋ง Single Precision(FP32) ๋ฐฉ์์ด ์๋, Half Precision(FP16) ๋ฐฉ์์ผ๋ก ํ์ต์ ์งํํ๋ค๋ฉด,
ํ์ ๋ GPU์ resource๋ฅผ ์๋ ์ ์์ง ์์๊น์?
Mixed Precision
Half Precision(FP16) ๋ฐฉ์์ ์ด์ฉํด ํ์ต์ ์งํํ๋ฉด ๋น์ฐํ ์ ์ฅ๊ณต๊ฐ๋ ์๋ผ๊ณ , ์ฐ์ฐ ์๋๋ ๋นจ๋ผ์ง๋๋ค.
ํ์ง๋ง Half Precision ๋ฐฉ์์ Single Precision ๋ฐฉ์๋ณด๋ค ์ ๋ฐ๋๊ฐ ํ์ ํ ๋จ์ด์ง์ฃ .
๋ฐ๋ผ์ gradient๊ฐ ๋๋ฌด ํฐ ๊ฒฝ์ฐ, ํน์ ๋๋ฌด ์์ ๊ฒฝ์ฐ ์ค์ฐจ๊ฐ ๋ฐ์ํ๊ฒ ๋๊ณ , ์ด ์ค์ฐจ๋ ๋์ ๋์ด ๊ฒฐ๊ตญ ํ์ต์ด ์ ์งํ๋์ง ์์ต๋๋ค.
์ ๊ทธ๋ฆผ์์ ๊ฒ์ ์ ์ ์ FP32๋ฅผ ์ด์ฉํด ํ์ต์ํจ ๊ฒฐ๊ณผ, ํ์ ์ ์ FP16์ ์ด์ฉํด ํ์ต์ํจ ๊ฒฐ๊ณผ์ ๋๋ค.
Y ์ถ์ training loss์ธ๋ฐ, FP16์ผ๋ก ํ์ต์ํค๋ ๊ฒฝ์ฐ loss๊ฐ ์ค์ด๋ค๋ค๊ฐ ์๋ ดํ์ง ๋ชปํ๊ณ ๋ค์ ์ปค์ง๋ ๊ฒ์ ํ์ธํ ์ ์์ฃ .
์ ๊ทธ๋ฆผ์ ์ค์ ๋ก FP32๋ฅผ ์ด์ฉํด ๋คํธ์ํฌ๋ฅผ ํ์ต์ํค๊ณ , ์์์ gradient ๊ฐ๋ค์ sampling ํ ๊ฒฐ๊ณผ์ ๋๋ค.
๋นจ๊ฐ ์ ์ผ์ชฝ์ gradient๋ค์ FP16์์ ํํ ๋ถ๊ฐ๋ฅํ ์ ๋ฐ๋์ ๊ฐ๋ค์ด๊ธฐ ๋๋ฌธ์, FP16์์๋ 0์ผ๋ก ํํ๋๊ณ , ์ด๋ฌํ ์ค์ฐจ๋ค์ด ๋์ ๋์ด ๋ชจ๋ธ ํ์ต์ ์ด๋ ค์์ด ๋ฐ์ํฉ๋๋ค.
๋ณธ ๋ ผ๋ฌธ์์ ์ ์ํ Mixed Precision Training์ FP32์ FP16์ ํจ๊ป ์ฌ์ฉํ์ฌ ์ด๋ฅผ ๊ทน๋ณตํฉ๋๋ค.
= Mixed Precision
Implementation
๋ฐฉ๋ฒ์ ๊ต์ฅํ ๊ฐ๋จํฉ๋๋ค. ์ค์ ๋ก gradient ๊ฐ๋ค์ด ๋งค์ฐ ์์ ๊ฐ์ ๋ชฐ๋ ค ์์ด์ FP16์ผ๋ก casting ์ 0์ด ๋์ด ๋ฒ๋ฆฝ๋๋ค.
์ฆ, FP16์ ํํ ๊ฐ๋ฅ ๋ฒ์ ๋ฐ์ gradient๊ฐ ๋ถํฌํด์ ์๊ธด ๋ฌธ์ ์ธ๋ฐ,
๊ทธ๋ ๋ค๋ฉด ๋จ์ํ Scaling์ ํตํด gradient๋ฅผ FP16์ ํํ ๊ฐ๋ฅ ๋ฒ์ ์์ผ๋ก ์ด๋์์ผ ์ฃผ๋ฉด ๋์ง ์์๊น์?
์ข ๋ ์์ธํ ๋ํ๋ด๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
Step 1. FP32 weight์ ๋ํ FP16 copy weight์ ๋ง๋ ๋ค.
(์ด FP16 copy weight์ forward pass, backward pass์ ์ด์ฉ๋๋ค.)
Step 2. FP16 copy weight์ ์ด์ฉํด forward pass๋ฅผ ์งํํ๋ค.
Step 3. forward pass๋ก ๊ณ์ฐ๋ FP16 prediction ๊ฐ์ FP32๋ก castingํ๋ค.
Step 4. FP32 prediction์ ์ด์ฉํด FP32 loss๋ฅผ ๊ณ์ฐํ๊ณ , ์ฌ๊ธฐ์ scaling factor S๋ฅผ ๊ณฑํ๋ค.
Step 5. scaled FP32 loss๋ฅผ FP16์ผ๋ก castingํ๋ค.
Step 6. scaled FP16 loss๋ฅผ ์ด์ฉํ์ฌ backward propagation์ ์งํํ๊ณ , gradient๋ฅผ ๊ณ์ฐํ๋ค.
Step 7. FP16 gradient๋ฅผ FP32๋ก castingํ๊ณ , ์ด๋ฅผ scaling factor S๋ก ๋ค์ ๋๋๋ค.
(chain rule์ ์ํด ๋ชจ๋ gradient๋ ๊ฐ์ ํฌ๊ธฐ๋ก scaling๋ ์ํ์)
Step 8. FP32 gradient๋ฅผ ์ด์ฉํด FP32 weight๋ฅผ updateํ๋ค.
์ ๋ฆฌํ์๋ฉด, FP32 weight์ ๊ณ์ ์ ์ฅํด ๋๊ณ ,
FP16 copy weight๋ฅผ ๋ง๋ค์ด ์ด๋ฅผ ์ด์ฉํด forward/backward pass๋ฅผ ์งํํ๋ ๊ฒ์ ๋๋ค.
FP16 copy weight์ผ๋ก ์ป์ gradient๋ฅผ ์ด์ฉํด FP32 weight๋ฅผ updateํฉ๋๋ค.
* ์ด ๋ Scaling Factor์ ์ด๋ป๊ฒ ์ ํ ๊น์?
๋ ผ๋ฌธ์์๋ ๋จ์ํ ๊ฒฝํ์ ์ธ ๊ฐ์ ์ ํํ๊ฑฐ๋,
gradient์ ํต๊ณํ๊ฐ ๊ฐ๋ฅํ ๊ฒฝ์ฐ gradient์ maximum absolute value๊ฐ 65,504(FP16์ด ํํ๊ฐ๋ฅํ ์ต๋๊ฐ)๊ฐ ๋๋๋ก ๋ง์ถฐ ์ฃผ๋ฉด ๋๋ค๊ณ ํฉ๋๋ค.
Scaling factor์ด ํฌ๋ค๊ณ ํด์ ๋์ ์ ์ ์์ง๋ง, overflow๊ฐ ์ผ์ด๋์ง ์๋๋ก ์ฃผ์ํด์ผ ํฉ๋๋ค!
Experiment & Result
Classification, detection ๋ฑ ๊ฐ๋จํ task๋ถํฐ ์์ํด์ GAN ๊น์ง ์์ฃผ ๋ค์ํ ์คํ์ ์งํํ์ต๋๋ค. (Method๊ฐ ๋๋ฌด ๋จ์ํ๊ธฐ ๋๋ฌธ์ผ๊น์?)
๋ช ๊ฐ์ง ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ๋๋ฆฌ๊ฒ ์ต๋๋ค.
* Baseline: FP32 / MP: Mixed Precision(FP32+FP16)
Mixed Precision์์ ์ฑ๋ฅ์ด ์คํ๋ ค ์ค๋ฅธ ๊ฒ๋ ์๊ณ , ์ ๋ฐ์ ์ผ๋ก FP32์ ๋ค์ง์ง ์๋ ์ฑ๋ฅ์ ๋ณด์ฌ์ค ๊ฒ ๊ฐ์ฃ ?
๋ ผ๋ฌธ์์ ์์ธํ ์คํ setting๊ณผ ๋ ๋ง์ ๊ฒฐ๊ณผ๋ฅผ ํ์ธํ์ค ์ ์์ต๋๋ค.
PyTorch Implementation
PyTorch์์ ๊ณต์์ ์ผ๋ก Mixed Precision Training์ ์ง์ํฉ๋๋ค.
Automatic Mixed Precision(AMP) ๋ผ๋ ์ด๋ฆ์ผ๋ก, ๋ช ์ค์ ์ฝ๋๋ง ์ถ๊ฐํ๋ฉด ์์ฝ๊ฒ ์ฌ์ฉ ๊ฐ๋ฅํฉ๋๋ค.
๊ณต์ ๋ฌธ์: pytorch.org/docs/stable/amp.html
Github์ ์ ์ ๋ฆฌ๋ ์ฝ๋๊ฐ ์์ด ๊ฐ์ ธ์์ต๋๋ค.
์ถ์ฒ: github.com/hoya012/automatic-mixed-precision-tutorials-pytorch
์ผ๋ฐ์ ์ธ ํ์ต ์ฝ๋
for batch_idx, (inputs, labels) in enumerate(data_loader):
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
AMP๋ฅผ ์ ์ฉํ ์ฝ๋
""" define loss scaler for automatic mixed precision """
# Creates a GradScaler once at the beginning of training.
scaler = torch.cuda.amp.GradScaler()
for batch_idx, (inputs, labels) in enumerate(data_loader):
optimizer.zero_grad()
with torch.cuda.amp.autocast():
# Casts operations to mixed precision
outputs = model(inputs)
loss = criterion(outputs, labels)
# Scales the loss, and calls backward()
# to create scaled gradients
scaler.scale(loss).backward()
# Unscales gradients and calls
# or skips optimizer.step()
scaler.step(self.optimizer)
# Updates the scale for next iteration
scaler.update()
ํ์ต์ ์์ํ๊ธฐ ์ scaler์ ์ ์ธํด์ฃผ๊ณ ,amp.autocast()๋ฅผ ์ด์ฉํ์ฌ casting ๊ณผ์ ์ ๊ฑฐ์น๋ฉฐ foward pass๋ฅผ ์งํํฉ๋๋ค.backward pass, optimization, weight update ๋ฑ์ ๊ณผ์ ์ด ๋ชจ๋ scaler์ ํตํด ์งํ๋๋ ํํ์ธ ๊ฒ ๊ฐ์ฃ ?
์ Github์์ ์ค์ ๋ก Torch์ AMP๋ฅผ ์ด์ฉํด ์คํ์ ์งํํด ๋ณด์ จ์ด์.
GTX 1080 Ti์ RTX 2080 Ti๋ฅผ ์ด์ฉํ์ฌ ์คํ์ ์งํํด ๋ณด์ จ๋ค๊ณ ํฉ๋๋ค.
1080 Ti๋ฅผ ์ด์ฉํ์ ๋, 2080 Ti๋ฅผ ์ด์ฉํ์ ๋ ๋ชจ๋ GPU ๋ฉ๋ชจ๋ฆฌ๋ ๋น์ฐ ์ ๊ฒ ์ฌ์ฉํ ๊ฒ์ ๋ณผ ์ ์๊ณ ,
Training Time์ 2080 Ti๋ฅผ ์ด์ฉํ์ ๋์๋ง ์ค์ด๋ค์๋ค์.
Test Accuracy๋ ๋ ๊ฒฝ์ฐ ๋ชจ๋ ์ค์ด๋ค์ง ์์ ์ฑ๋ฅ์ ํ๋ ์์์์ ์ ์ ์์ต๋๋ค.
RTX 2080 Ti์ Tensor Core์ด ํ์ฌ๋์ด FP16์ ๊ณ์ฐ์ด ํ๊ธฐ์ ์ผ๋ก ๋น ๋ฅด๋ค๊ณ ํฉ๋๋ค.
๋๋ฌธ์ Tensor Core์ด ํ์ฌ๋ GPU๋ฅผ ์ฌ์ฉํ์ ๋ Torch AMP๊ฐ ์๊ฐ ์ธก๋ฉด์์๋ ๋น์ ๋ฐํ ๊ฒ ๊ฐ๋ค์.
(Tensor Core์ ๋ํ ์์ธํ ์ค๋ช ์ www.nvidia.com/ko-kr/data-center/tensor-cores/)
Tensor Core์ TF32๋ผ๋ ์์ฒด ์ ๋ฐ๋๋ฅผ ์ด์ฉํด์ FP32๋ณด๋ค ์ต๋ 20๋ฐฐ๊น์ง ๊ฐ์์ด ๊ฐ๋ฅํ๋ค๊ณ ํ๋๋ฐ
์ ์คํ ๊ฒฐ๊ณผ๋ฅผ ๋ณด๋ฉด FP32๋ฅผ ์ด์ฉํ Baseline์์๋ ์๋ ๋ฉด์์ ํฐ ์ฐจ์ด๊ฐ ์๋ค์.. ์์ผ๊น์?
Conclusion
GPU์ resource๋ฅผ ์๋ ์ ์๊ณ , ํ์ต ์๊ฐ๊น์ง ๋จ์ถ์ํฌ ์ ์๋ Mixed Precision Training์ ๋ํด ๋ฆฌ๋ทฐํด ๋ณด์์ต๋๋ค.
PyTorch์์ ๋งค์ฐ ๊ฐ๋จํ๊ฒ ๊ตฌํ๋ ๊ฐ๋ฅํด์, ์ ๋ง ์ ์ธ ์ด์ ๊ฐ ์์ ๊ฒ์ฒ๋ผ ๋๊ปด์ง๋ค์.
FP32์ ์ ๋ฐ๋๋ ์ ์งํ๋ฉด์ FP16์ ์ด์ฉํด ์ ์ฅ๊ณต๊ฐ์ ์๋ผ๋ ๋ฐฉ๋ฒ์ด์๋๋ฐ,
๊ทธ๋ ๋ค๋ฉด FP32๋ฅผ ์ด์ฉํด์ FP64์ ์ ๋ฐ๋๋ฅผ ๊ตฌํํ๋ ๋ฐฉ๋ฒ๋ ๊ฐ๋ฅํ์ง ์์๊น์?
์ ๋ฐ๋๊ฐ ์ค์ํ task์์๋ ํ๋ฒ์ฏค ์๋ํด ๋ณด๊ณ ์ถ์ต๋๋ค.