๐Ÿ Python & library/PyTorch

PyTorch 2.0์—์„œ ๋‹ฌ๋ผ์ง€๋Š” ์  - torch.compile

๋ณต๋งŒ 2023. 5. 6. 18:45
 

PyTorch 2.0

Overview

pytorch.org

 

์ถœ์ฒ˜: ํŒŒ์ดํ† ์น˜ ๊ณต์‹ ๋ธ”๋กœ๊ทธ

 

PyTorch 2.0์€ 22๋…„ 12์›” PyTorch Conference์—์„œ ๋ฐœํ‘œ๋˜์—ˆ๊ณ , 23๋…„ 3์›” ์ •์‹ ๋ฆด๋ฆฌ์ฆˆ ๋˜์—ˆ๋‹ค. ์ด์ „์˜ PyTorch 1.x ๋ฒ„์ „๋“ค๋ณด๋‹ค ๋น ๋ฅด๊ณ , Pythonicํ•˜๊ณ  Dynamicํ•˜๋‹ค๊ณ  ํ•œ๋‹ค. ์–ด๋–ค ์ ๋“ค์ด ๋‹ฌ๋ผ์กŒ์„์ง€ ํ•œ๋ฒˆ ์•Œ์•„๋ด…์‹œ๋‹ค.

 

 

torch.compile

torch.compile์€ PyTorch 2.0์˜ ๋ฉ”์ธ API์ด๋‹ค. ๋ชจ๋ธ์„ ๋ฏธ๋ฆฌ ์ปดํŒŒ์ผํ•˜์—ฌ ์†๋„๋ฅผ ๋†’์ด๋Š” ๊ธฐ์ˆ ์ด๋‹ค. torch.compile์€ TorchDynamo, AOTAutograd, PrimTorch, TorchInductor ๋„ค ๊ฐ€์ง€์˜ ์ƒˆ๋กœ์šด ๊ธฐ์ˆ ์„ ๊ธฐ๋ฐ˜์œผ๋กœ ๋งŒ๋“ค์–ด์กŒ๋‹ค. ๊ฐ ๊ธฐ์ˆ ์— ๋Œ€ํ•œ ์ž์„ธํ•œ ์„ค๋ช…์€ ์—ฌ๊ธฐ์—์„œ ์ฐพ์•„๋ณผ ์ˆ˜ ์žˆ๋‹ค.

 

torch.compile์ด ๋‚˜์˜ค๊ฒŒ ๋œ ๋ฐฐ๊ฒฝ

 

์‚ฌ์šฉ๋ฒ•

 

torch.compile์€ ๊ธฐ์กด์˜ ๋ชจ๋ธ์— ํ•œ ์ค„๋งŒ ์ถ”๊ฐ€ํ•˜๋ฉด ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋‹ค.

 

compiled_model = torch.compile(model)

 

์ด๋ ‡๊ฒŒ ์ปดํŒŒ์ผ๋œ ๋ชจ๋ธ์€ ๊ธฐ์กด๊ณผ ๋™์ผํ•˜๊ฒŒ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋‹ค. ๋‹ค์Œ์€ ResNet18์— torch.compile์„ ์ ์šฉํ•˜๋Š” ์˜ˆ์‹œ ์ฝ”๋“œ์ด๋‹ค. torch.compile()์„ ํ•˜๋Š” ๋ถ€๋ถ„์„ ์ œ์™ธํ•˜๋ฉด ๊ธฐ์กด์˜ PyTorch ๋ชจ๋ธ๊ณผ ์‚ฌ์šฉ๋ฒ•์ด ๋™์ผํ•œ ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค.

 

import torch
import torchvision.models as models

model = models.resnet18().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
compiled_model = torch.compile(model)

x = torch.randn(16, 3, 224, 224).cuda()
optimizer.zero_grad()
out = compiled_model(x)
out.sum().backward()
optimizer.step()

 

์ฒ˜์Œ compiled_model์„ ์‹คํ–‰ํ•  ๋•Œ์—๋Š” ๋ชจ๋ธ ์ปดํŒŒ์ผ์— ์‹œ๊ฐ„์ด ๊ฑธ๋ฆฌ์ง€๋งŒ, ์ดํ›„ ์‹คํ–‰๋“ค์€ ๋” ๋นจ๋ผ์ง„๋‹ค.

 

 

ํ•œ๊ฐ€์ง€ ์ฃผ์˜ํ•ด์•ผ ํ•  ์ ์€, ์ปดํŒŒ์ผ๋œ ๋ชจ๋ธ์„ ์ €์žฅํ•  ๋•Œ state_dict๋งŒ ์ €์žฅํ•  ์ˆ˜ ์žˆ๋‹ค. ์ฆ‰,

 

torch.save(optimized_model.state_dict(), "foo.pt")

 

์ด๊ฑด ๋˜๊ณ 

 

torch.save(optimized_model, "foo.pt")

 

์ด๊ฑด ์•ˆ๋œ๋‹ค.

 

 

์†๋„ ๋น„๊ต

 

torch.compile์„ ์‚ฌ์šฉํ–ˆ์„ ๋•Œ์˜ ์†๋„๋ฅผ ๋น„๊ตํ•˜๊ธฐ ์œ„ํ•ด PyTorch๋Š” ์ด 163๊ฐ€์ง€์˜ ์˜คํ”ˆ์†Œ์Šค ๋ชจ๋ธ์— ๋Œ€ํ•œ ์‹คํ—˜์„ ์ง„ํ–‰ํ–ˆ๋‹ค. ์‹คํ—˜์— ์‚ฌ์šฉํ•œ ๋ชจ๋ธ์€ ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค.

 

๋ชจ๋ธ์„ ์ „ํ˜€ ์ˆ˜์ •ํ•˜์ง€ ์•Š๊ณ  ์ฝ”๋“œ์— torch.compile๋งŒ ์ถ”๊ฐ€ํ–ˆ์„ ๋•Œ, ํ›ˆ๋ จ ์†๋„๊ฐ€ 43% ๋นจ๋ผ์กŒ๋‹ค๊ณ  ํ•œ๋‹ค (NVIDIA A100 GPU ๊ธฐ์ค€). ํ•œํŽธ NVIDIA 3090๊ณผ ๊ฐ™์€ ๋ฐ์Šคํฌํƒ‘์šฉ GPU๋ฅผ ์‚ฌ์šฉํ–ˆ์„ ๋•Œ๋Š” ์†๋„ ํ–ฅ์ƒ์ด ๊ทธ๋ณด๋‹ค๋Š” ๋œํ•˜๋‹ค๊ณ  ํ•œ๋‹ค.

 

 

Example

 

์•„๋ž˜ ๋ธ”๋กœ๊ทธ์—์„œ torch.compile์„ ์ด์šฉํ•œ ์‹คํ—˜์„ ์ง„ํ–‰ํ•˜๊ณ , ์†๋„๋ฅผ ๋น„๊ตํ–ˆ๋‹ค.

 

A Quick PyTorch 2 Tutorial - Zero to Mastery Learn PyTorch for Deep Learning

Learn important machine learning concepts hands-on by writing PyTorch code.

www.learnpytorch.io

 

NVIDIA TITAN RTX GPU๋ฅผ ์ด์šฉํ–ˆ๊ณ , CIFAT10 ๋ฐ์ดํ„ฐ์…‹์— ResNet50 ๋ชจ๋ธ๋กœ ์‹คํ—˜ํ–ˆ๋‹ค.

 

 

 

์™ผ์ชฝ์€ epoch=5, ์˜ค๋ฅธ์ชฝ์€ epoch=15์— ๋Œ€ํ•œ ์‹คํ—˜ ๊ฒฐ๊ณผ์ด๋‹ค. (multiple run์€ 5epoch ํ›ˆ๋ จ์„ ์—ฌ๋Ÿฌ๋ฒˆ ์‹คํ–‰์‹œ์ผฐ๋‹ค๋Š” ๋œป)

 

Epoch ์ˆ˜๊ฐ€ ๋Š˜์–ด๋‚ ์ˆ˜๋ก ์ปดํŒŒ์ผ๋œ ๋ชจ๋ธ์˜ ์†๋„ ํ–ฅ์ƒ์ด ๋” ํฐ ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค. ์ด๋Š” ์ฒซ ์ปดํŒŒ์ผ์— ์‹œ๊ฐ„์ด ์˜ค๋ž˜ ๊ฑธ๋ฆฌ๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค.

์†๋„ ์ฐจ์ด๊ฐ€ ํฌ์ง€ ์•Š์€ ๊ฒƒ์€ ๋ฐ์Šคํฌํƒ‘์šฉ GPU๋ฅผ ์‚ฌ์šฉํ–ˆ๊ธฐ ๋•Œ๋ฌธ์ด๊ณ , A100๋“ฑ์˜ ์‚ฐ์—…์šฉ GPU๋ฅผ ์‚ฌ์šฉํ•˜๋ฉด ์†๋„ํ–ฅ์ƒ์ด ๋” ํฌ๋‹ค๊ณ  ํ•œ๋‹ค.

 

๋ฐ˜์‘ํ˜•