PyTorch 2.0์ 22๋ 12์ PyTorch Conference์์ ๋ฐํ๋์๊ณ , 23๋ 3์ ์ ์ ๋ฆด๋ฆฌ์ฆ ๋์๋ค. ์ด์ ์ PyTorch 1.x ๋ฒ์ ๋ค๋ณด๋ค ๋น ๋ฅด๊ณ , Pythonicํ๊ณ Dynamicํ๋ค๊ณ ํ๋ค. ์ด๋ค ์ ๋ค์ด ๋ฌ๋ผ์ก์์ง ํ๋ฒ ์์๋ด ์๋ค.
torch.compile
torch.compile์ PyTorch 2.0์ ๋ฉ์ธ API์ด๋ค. ๋ชจ๋ธ์ ๋ฏธ๋ฆฌ ์ปดํ์ผํ์ฌ ์๋๋ฅผ ๋์ด๋ ๊ธฐ์ ์ด๋ค. torch.compile์ TorchDynamo, AOTAutograd, PrimTorch, TorchInductor ๋ค ๊ฐ์ง์ ์๋ก์ด ๊ธฐ์ ์ ๊ธฐ๋ฐ์ผ๋ก ๋ง๋ค์ด์ก๋ค. ๊ฐ ๊ธฐ์ ์ ๋ํ ์์ธํ ์ค๋ช ์ ์ฌ๊ธฐ์์ ์ฐพ์๋ณผ ์ ์๋ค.
์ฌ์ฉ๋ฒ
torch.compile์ ๊ธฐ์กด์ ๋ชจ๋ธ์ ํ ์ค๋ง ์ถ๊ฐํ๋ฉด ์ฌ์ฉํ ์ ์๋ค.
compiled_model = torch.compile(model)
์ด๋ ๊ฒ ์ปดํ์ผ๋ ๋ชจ๋ธ์ ๊ธฐ์กด๊ณผ ๋์ผํ๊ฒ ์ฌ์ฉํ ์ ์๋ค. ๋ค์์ ResNet18์ torch.compile์ ์ ์ฉํ๋ ์์ ์ฝ๋์ด๋ค. torch.compile()์ ํ๋ ๋ถ๋ถ์ ์ ์ธํ๋ฉด ๊ธฐ์กด์ PyTorch ๋ชจ๋ธ๊ณผ ์ฌ์ฉ๋ฒ์ด ๋์ผํ ๊ฒ์ ํ์ธํ ์ ์๋ค.
import torch
import torchvision.models as models
model = models.resnet18().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
compiled_model = torch.compile(model)
x = torch.randn(16, 3, 224, 224).cuda()
optimizer.zero_grad()
out = compiled_model(x)
out.sum().backward()
optimizer.step()
์ฒ์ compiled_model์ ์คํํ ๋์๋ ๋ชจ๋ธ ์ปดํ์ผ์ ์๊ฐ์ด ๊ฑธ๋ฆฌ์ง๋ง, ์ดํ ์คํ๋ค์ ๋ ๋นจ๋ผ์ง๋ค.
ํ๊ฐ์ง ์ฃผ์ํด์ผ ํ ์ ์, ์ปดํ์ผ๋ ๋ชจ๋ธ์ ์ ์ฅํ ๋ state_dict๋ง ์ ์ฅํ ์ ์๋ค. ์ฆ,
torch.save(optimized_model.state_dict(), "foo.pt")
์ด๊ฑด ๋๊ณ
torch.save(optimized_model, "foo.pt")
์ด๊ฑด ์๋๋ค.
์๋ ๋น๊ต
torch.compile์ ์ฌ์ฉํ์ ๋์ ์๋๋ฅผ ๋น๊ตํ๊ธฐ ์ํด PyTorch๋ ์ด 163๊ฐ์ง์ ์คํ์์ค ๋ชจ๋ธ์ ๋ํ ์คํ์ ์งํํ๋ค. ์คํ์ ์ฌ์ฉํ ๋ชจ๋ธ์ ๋ค์๊ณผ ๊ฐ๋ค.
- HuggingFace Transformer ๋ชจ๋ธ 46๊ฐ์ง
- TIMM ๋ชจ๋ธ 61๊ฐ์ง
- TorchBench ๋ชจ๋ธ 56๊ฐ์ง
๋ชจ๋ธ์ ์ ํ ์์ ํ์ง ์๊ณ ์ฝ๋์ torch.compile๋ง ์ถ๊ฐํ์ ๋, ํ๋ จ ์๋๊ฐ 43% ๋นจ๋ผ์ก๋ค๊ณ ํ๋ค (NVIDIA A100 GPU ๊ธฐ์ค). ํํธ NVIDIA 3090๊ณผ ๊ฐ์ ๋ฐ์คํฌํ์ฉ GPU๋ฅผ ์ฌ์ฉํ์ ๋๋ ์๋ ํฅ์์ด ๊ทธ๋ณด๋ค๋ ๋ํ๋ค๊ณ ํ๋ค.
Example
์๋ ๋ธ๋ก๊ทธ์์ torch.compile์ ์ด์ฉํ ์คํ์ ์งํํ๊ณ , ์๋๋ฅผ ๋น๊ตํ๋ค.
NVIDIA TITAN RTX GPU๋ฅผ ์ด์ฉํ๊ณ , CIFAT10 ๋ฐ์ดํฐ์ ์ ResNet50 ๋ชจ๋ธ๋ก ์คํํ๋ค.
์ผ์ชฝ์ epoch=5, ์ค๋ฅธ์ชฝ์ epoch=15์ ๋ํ ์คํ ๊ฒฐ๊ณผ์ด๋ค. (multiple run์ 5epoch ํ๋ จ์ ์ฌ๋ฌ๋ฒ ์คํ์์ผฐ๋ค๋ ๋ป)
Epoch ์๊ฐ ๋์ด๋ ์๋ก ์ปดํ์ผ๋ ๋ชจ๋ธ์ ์๋ ํฅ์์ด ๋ ํฐ ๊ฒ์ ํ์ธํ ์ ์๋ค. ์ด๋ ์ฒซ ์ปดํ์ผ์ ์๊ฐ์ด ์ค๋ ๊ฑธ๋ฆฌ๊ธฐ ๋๋ฌธ์ด๋ค.
์๋ ์ฐจ์ด๊ฐ ํฌ์ง ์์ ๊ฒ์ ๋ฐ์คํฌํ์ฉ GPU๋ฅผ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ด๊ณ , A100๋ฑ์ ์ฐ์ ์ฉ GPU๋ฅผ ์ฌ์ฉํ๋ฉด ์๋ํฅ์์ด ๋ ํฌ๋ค๊ณ ํ๋ค.
'๐ Python & library > PyTorch' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[PyTorch] Autograd ์๋๋ฐฉ์ ์์๋ณด๊ธฐ (1) | 2023.12.02 |
---|---|
[PyTorch] tensor.detach()์ ๊ธฐ๋ฅ๊ณผ ์์ ์ฝ๋ (0) | 2022.10.30 |
[PyTorch] nn.Embedding ์ด๊ธฐํํ๊ธฐ (initialization) (0) | 2022.10.27 |
Numpy & PyTorch๋ก 2D fourier transform, inverse fourier transformํ๊ธฐ (1) | 2022.08.27 |
[PyTorch] make_grid๋ก ์ฌ๋ฌ ๊ฐ์ ์ด๋ฏธ์ง ํ๋ฒ์ plotํ๊ธฐ (0) | 2022.07.29 |