๐Ÿ Python & library/PyTorch

[PyTorch] ๋ชจ๋ธ ์ €์žฅํ•˜๊ธฐ & ๋ถˆ๋Ÿฌ์˜ค๊ธฐ

๋ณต๋งŒ 2020. 2. 3. 14:06

state_dict

- torch.nn.Module ๋ชจ๋ธ์˜ ๊ฐ ๊ณ„์ธต์˜ ํ•™์Šต ๊ฐ€๋Šฅํ•œ ๋งค๊ฐœ๋ณ€์ˆ˜(model.parameters())๋“ค์„ ๋งคํ•‘ํ•˜๋Š” dictionary ๊ฐ์ฒด.

- torch.optim ์˜ตํ‹ฐ๋งˆ์ด์ € ๊ฐ์ฒด ๋˜ํ•œ ์˜ตํ‹ฐ๋งˆ์ด์ €์˜ ์ƒํƒœ์™€ ์‚ฌ์šฉ๋œ ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ ์ •๋ณด๋ฅผ ํฌํ•จํ•œ state_dict๋ฅผ ๊ฐ€์ง.

 


 

state_dict ์ €์žฅํ•˜๊ธฐ/๋ถˆ๋Ÿฌ์˜ค๊ธฐ

์ €์žฅํ•˜๊ธฐ

torch.save(model.state_dict(), PATH)

๋ถˆ๋Ÿฌ์˜ค๊ธฐ

model = TheModelClass(*args, **kwargs)
model.load_sate_dict(torch.load(PATH)
model.eval()

- inference๋ฅผ ์œ„ํ•ด ํ•™์Šต๋œ ๋ชจ๋ธ์˜ ํ•™์Šต๋œ ๋งค๊ฐœ๋ณ€์ˆ˜๋งŒ state_dict๋ฅผ ์ด์šฉํ•˜์—ฌ ์ €์žฅํ•˜๋Š” ๋ฐฉ๋ฒ•.

- ๋ชจ๋ธ ์ €์žฅ ์‹œ .pt ๋˜๋Š” .pth ํ™•์žฅ์ž๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ด ์ผ๋ฐ˜์ ์ธ ๊ทœ์น™์ž„.

- inference ์‹คํ–‰ ์‹œ ๋ฐ˜๋“œ์‹œ model.eval()์„ ํ˜ธ์ถœํ•˜์—ฌ ํ‰๊ฐ€ ๋ชจ๋“œ๋กœ ์„ค์ •ํ•˜์—ฌ์•ผ ํ•จ.

 


 

์ „์ฒด ๋ชจ๋ธ ์ €์žฅํ•˜๊ธฐ/๋ถˆ๋Ÿฌ์˜ค๊ธฐ

์ €์žฅํ•˜๊ธฐ

torch.save(model, PATH)

๋ถˆ๋Ÿฌ์˜ค๊ธฐ

model = torch.load(PATH)
model.eval()

- ์ „์ฒด ๋ชจ๋ธ์„ ์ €์žฅํ•˜๊ณ  ๋ถˆ๋Ÿฌ์˜ค๋Š” ๋ฐฉ๋ฒ•.

- ๋ชจ๋ธ ๊ทธ ์ž์ฒด๋ฅผ ์ €์žฅํ•˜์ง€ ์•Š๊ธฐ ๋•Œ๋ฌธ์— ์ง๋ ฌํ™”๋œ ๋ฐ์ดํ„ฐ๊ฐ€ ๋ชจ๋ธ์„ ์ €์žฅํ•  ๋•Œ ์‚ฌ์šฉํ•œ ํŠน์ • ํด๋ž˜์Šค ๋ฐ ๋””๋ ‰ํ† ๋ฆฌ ๊ฒฝ๋กœ์— ์–ฝ๋งค์ธ๋‹ค๋Š” ๋‹จ์ ์ด ์žˆ์Œ.

- ๋ชจ๋ธ ์ €์žฅ ์‹œ .pt ๋˜๋Š” .pth ํ™•์žฅ์ž๋ฅผ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ด ์ผ๋ฐ˜์ ์ธ ๊ทœ์น™์ž„.

๋ฐ˜์‘ํ˜•