state_dict
- torch.nn.Module ๋ชจ๋ธ์ ๊ฐ ๊ณ์ธต์ ํ์ต ๊ฐ๋ฅํ ๋งค๊ฐ๋ณ์(model.parameters())๋ค์ ๋งคํํ๋ dictionary ๊ฐ์ฒด.
- torch.optim ์ตํฐ๋ง์ด์ ๊ฐ์ฒด ๋ํ ์ตํฐ๋ง์ด์ ์ ์ํ์ ์ฌ์ฉ๋ ํ์ดํผํ๋ผ๋ฏธํฐ ์ ๋ณด๋ฅผ ํฌํจํ state_dict๋ฅผ ๊ฐ์ง.
state_dict ์ ์ฅํ๊ธฐ/๋ถ๋ฌ์ค๊ธฐ
์ ์ฅํ๊ธฐ
torch.save(model.state_dict(), PATH)
๋ถ๋ฌ์ค๊ธฐ
model = TheModelClass(*args, **kwargs)
model.load_sate_dict(torch.load(PATH)
model.eval()
- inference๋ฅผ ์ํด ํ์ต๋ ๋ชจ๋ธ์ ํ์ต๋ ๋งค๊ฐ๋ณ์๋ง state_dict๋ฅผ ์ด์ฉํ์ฌ ์ ์ฅํ๋ ๋ฐฉ๋ฒ.
- ๋ชจ๋ธ ์ ์ฅ ์ .pt ๋๋ .pth ํ์ฅ์๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ์ผ๋ฐ์ ์ธ ๊ท์น์.
- inference ์คํ ์ ๋ฐ๋์ model.eval()์ ํธ์ถํ์ฌ ํ๊ฐ ๋ชจ๋๋ก ์ค์ ํ์ฌ์ผ ํจ.
์ ์ฒด ๋ชจ๋ธ ์ ์ฅํ๊ธฐ/๋ถ๋ฌ์ค๊ธฐ
์ ์ฅํ๊ธฐ
torch.save(model, PATH)
๋ถ๋ฌ์ค๊ธฐ
model = torch.load(PATH)
model.eval()
- ์ ์ฒด ๋ชจ๋ธ์ ์ ์ฅํ๊ณ ๋ถ๋ฌ์ค๋ ๋ฐฉ๋ฒ.
- ๋ชจ๋ธ ๊ทธ ์์ฒด๋ฅผ ์ ์ฅํ์ง ์๊ธฐ ๋๋ฌธ์ ์ง๋ ฌํ๋ ๋ฐ์ดํฐ๊ฐ ๋ชจ๋ธ์ ์ ์ฅํ ๋ ์ฌ์ฉํ ํน์ ํด๋์ค ๋ฐ ๋๋ ํ ๋ฆฌ ๊ฒฝ๋ก์ ์ฝ๋งค์ธ๋ค๋ ๋จ์ ์ด ์์.
- ๋ชจ๋ธ ์ ์ฅ ์ .pt ๋๋ .pth ํ์ฅ์๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ด ์ผ๋ฐ์ ์ธ ๊ท์น์.
'๐ Python & library > PyTorch' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[PyTorch] Scheduler ์๊ฐํํ๊ธฐ (Visualize scheduler) (2) | 2021.11.24 |
---|---|
[PyTorch] ReduceLROnPlateau (0) | 2021.10.26 |
[PyTorch] CosineAnnealingLR, CosineAnnealingWarmRestarts (0) | 2021.10.14 |
[PyTorch] nn.ModuleList ๊ธฐ๋ฅ๊ณผ ์ฌ์ฉ ์ด์ (1) | 2021.08.04 |
[PyTorch] Livelossplot ์ฌ์ฉ์์ (0) | 2021.04.03 |