[PyTorch] Livelossplot Usage Example
๋ณต๋ง
2021. 4. 3. 22:34
www.kaggle.com/pmigdal/livelossplot-for-training-loss-tracking
Livelossplot is a package that plots experiment logs, such as the training and validation loss, in real time while a model is training.
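It can be installed as follows (the leading ! runs pip from a Jupyter or Kaggle notebook cell; drop it when installing from a terminal):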
!pip install livelossplot
Below is a usage example that logs the training and validation RMSE of a regression model once per epoch.
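import numpy as np
import torch
import torch.nn as nn
from livelossplot import PlotLosses

# Assumes model, device, optimizer, num_epochs and
# dataloaders = {'train': ..., 'validation': ...} are defined beforehand.
liveloss = PlotLosses()

model = model.to(device)
loss_f = nn.MSELoss()

for epoch in range(num_epochs):
    logs = {}
    for phase in ['train', 'validation']:
        if phase == 'train':
            model.train()
        else:
            model.eval()

        running_loss = 0.0
        for x, y in dataloaders[phase]:
            x, y = x.to(device), y.to(device)
            # Build the autograd graph only during the training phase
            with torch.set_grad_enabled(phase == 'train'):
                y_pred = model(x)
                loss = loss_f(y_pred, y)
            if phase == 'train':
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # loss is a per-batch mean, so weight it by the batch size
            running_loss += loss.item() * x.size(0)

        epoch_loss = running_loss / len(dataloaders[phase].dataset)
        # 'loss' for training, 'val_loss' for validation
        prefix = 'val_' if phase == 'validation' else ''
        # Log the RMSE (square root of the epoch MSE)
        logs[prefix + 'loss'] = np.sqrt(epoch_loss)

    # Update and redraw once per epoch, after both phases have run
    liveloss.update(logs)
    liveloss.draw()  # newer livelossplot versions prefer liveloss.send()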
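The per-epoch bookkeeping in the loop above can be checked on its own. nn.MSELoss() averages over the batch by default (reduction='mean'), so multiplying each loss.item() by the batch size and dividing by the dataset length recovers the exact per-sample mean even when the last batch is smaller; np.sqrt then turns that into an RMSE. The sketch below reproduces this arithmetic in plain Python. The helper epoch_log and the batch numbers are hypothetical, made up for illustration and not part of the original post.

```python
import math

def epoch_log(batches, dataset_size, phase):
    """Hypothetical helper: turn per-batch mean losses into one epoch entry.

    batches: list of (batch_mean_loss, batch_size) pairs, i.e. what
    loss.item() and x.size(0) give for each batch in the training loop.
    """
    # mean * batch_size = summed loss of that batch
    running_loss = sum(mean_loss * size for mean_loss, size in batches)
    epoch_mse = running_loss / dataset_size
    prefix = 'val_' if phase == 'validation' else ''
    # Same key scheme as the loop: 'loss' for training, 'val_loss' for validation
    return {prefix + 'loss': math.sqrt(epoch_mse)}

# 10 training samples split into batches of 4, 4 and 2 (made-up numbers)
train_batches = [(0.5, 4), (0.5, 4), (3.0, 2)]

logs = {}
logs.update(epoch_log(train_batches, 10, 'train'))
logs.update(epoch_log([(0.25, 10)], 10, 'validation'))
print(logs)  # → {'loss': 1.0, 'val_loss': 0.5}

# For comparison: naively averaging the batch means over-weights the short last batch
naive_mse = sum(mean_loss for mean_loss, _ in train_batches) / len(train_batches)
print(naive_mse)
```

Averaging the three batch means directly gives about 1.33 instead of the true per-sample MSE of 1.0, because the two-sample batch counts as much as a full one; weighting by batch size avoids that bias.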