PyTorch ๋ชจ๋ธ์ ์๊ฐํํ ์ ์๋ ํด ์ธ๊ฐ์ง๋ฅผ ์๊ฐํ๋ค.
์ถ์ฒ: https://stackoverflow.com/questions/52468956/how-do-i-visualize-a-net-in-pytorch
How do I visualize a net in Pytorch?
import torch import torch.nn as nn import torch.optim as optim import torch.utils.data as data import torchvision.models as models import torchvision.datasets as dset import torchvision.transforms as
stackoverflow.com
1. Torchviz
https://github.com/szagoruyko/pytorchviz
backward pass๋ฅผ ์ด์ฉํด ๋ชจ๋ธ์ ์๊ฐํํ๋ค. ์ค๊ฐ์ ์ด๋ป๊ฒ gradient๋ฅผ ๊ณ์ฐํ๊ณ backward ์ฐ์ฐ์ ์งํํ๋์ง ๋ชจ๋ ๋ณด์ฌ์ค๋ค.
Install:
pip install torchviz
Usage:
from torchviz import make_dot
make_dot(y_pred, params=dict(model.named_parameters())
Example:
2. HiddenLayer
https://github.com/waleedka/hiddenlayer
Install:
pip install hiddenlayer
Usage:
import hiddenlayer as hl
hl.build_graph(model, input)
Example:
๊ธฐ๋ณธ์ ์ธ ์ฌ์ฉ๋ฒ์ ์์ ๊ฐ์๋ฐ ๊ทธ๋ํ ์๊น์ ๋ฐ๊พธ๊ฑฐ๋, ์ฌ๋ฌ node๋ค์ ํฉ์ณ ํ๋์ block์ผ๋ก ๋ง๋ ๋ค๊ฑฐ๋ ํ๋ customizing option๋ค์ด ์๋ค.
๊ทธ๋ฐ๋ฐ ์ฐ์ฐ๊ทธ๋ํ์ ๋๋ฌด ๋ถํ์ํ ๋ถ๋ถ์ด ๋ง๊ธด ํ๋ค. BN์ constant-cast-add-div ์ ๊ณผ์ ์ผ๋ก ๋ํ๋ธ๋ค๋์ง..
Hiddenlayers๋ ๋ชจ๋ธ ์๊ฐํ ์ด์ธ์๋ training ๊ณผ์ ์์ ์ฌ์ฉํ ์ ์๋ ๋ค์ํ ์๊ฐํ ๊ธฐ๋ฅ๋ค์ ์ ๊ณตํ๋ค.
3. Netron
https://github.com/lutzroeder/netron
desktop application์ผ๋ก ๋ค์ด๋ก๋ ๋ฐ์ ์๋ ์๊ณ , browser version์ ์ด์ฉํ ์๋ ์๋ค.
PyTorch ๋ชจ๋ธ๋ experimental support๋ฅผ ์งํ ์ค์ด๋ผ๊ณ ํด์ torch.save๋ก ์ ์ฅํ ๋ชจ๋ธ๊ณผ state_dict๋ฅผ ๊ฐ๊ฐ ๋ฃ์ด๋ดค๋๋ฐ ๋๋ค ์ข.. ์ ์๋ฏธํ ์ ๋ณด๊ฐ ์๋๋ค.
ONNX format์ผ๋ก ๋ชจ๋ธ์ ์ ์ฅํ ํ ๋ถ๋ฌ์ค๋ฉด ์ ๋ณด์ธ๋ค๋ ๊ฒ ๊ฐ๋ค. ๋ด ๋ชจ๋ธ์ complex number์ input์ผ๋ก ๋ฐ๋๋ฐ ONNX format์ด complex format์ ์ง์์ ์ํ๋ค๊ณ ํด์ ๊ทธ๋ฅ ์ํด๋ดค์.
๊ทธ๋ฌ๋ ์ผ๋จ browser๋ก๋ ์ฌ์ฉํ ์๊ฐ ์๊ณ , zoom์ด๋ ์ด๋ ๋ฑ UI๊ฐ ์ ๋์ด ์์ด ์ฌ์ฉํ๊ธฐ ๋งค์ฐ ํธํ ๊ฒ ๊ฐ๋ค.
๋ด๊ฐ ์ฌ์ฉํ ๋ชจ๋ธ์ complex๋ฅผ input์ผ๋ก ๋ฐ๋ ๋ชจ๋ธ์ธ๋ฐ, ONNX format์ด complex๋ฅผ ์ง์ํ์ง ์์์ Netron์ .pth ํ์ผ๋ก๋ง ํด๋ณผ ์ ์์๊ณ , hiddenlayer ์ญ์ ์ค๊ฐ์ ONNX๋ก ๋ณํํ๋ ๊ณผ์ ์ ๊ฑฐ์ณ์ ์ฌ์ฉ์ ๋ชปํด๋ดค๋ค.
๊ฒฐ๋ก ์ ๊ฐ์ฅ ๊ฐ๋จํ ๊ฒ์ Torchviz, customizing์ด ๊ฐ๋ฅํ๊ณ ๊ธฐ๋ฅ์ด ๋ง์ ๊ฒ์ HiddenLayer, browser-base๋ผ ํธ๋ฆฌํ ๊ฒ์ Netron์ธ๋ฏ โ
'๐ Python & library > PyTorch' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[PyTorch] Weight clipping (0) | 2022.01.29 |
---|---|
[PyTorch] Enable anomaly detection (torch.autograd.detect_anomaly() / torch.autograd.set_detect_anomaly(True)) (0) | 2022.01.29 |
[PyTorch/Tensorflow v1, v2] Gradient Clipping ์ถ๊ฐํ๊ธฐ (0) | 2022.01.12 |
[PyTorch] Scheduler ์๊ฐํํ๊ธฐ (Visualize scheduler) (2) | 2021.11.24 |
[PyTorch] ReduceLROnPlateau (0) | 2021.10.26 |