๐Ÿ Python & library/PyTorch

[PyTorch] Livelossplot Usage Example

복만 2021. 4. 3. 22:34

www.kaggle.com/pmigdal/livelossplot-for-training-loss-tracking

Livelossplot์€ ํ•™์Šต ๊ณผ์ •์—์„œ real-time์œผ๋กœ ์‹คํ–‰ ๋กœ๊ทธ๋ฅผ ๋ณด์—ฌ์ฃผ๋Š” ํŒจํ‚ค์ง€์ด๋‹ค.

 

๋‹ค์Œ๊ณผ ๊ฐ™์ด ์„ค์น˜ํ•  ์ˆ˜ ์žˆ๋‹ค.

!pip install livelossplot

 

์•„๋ž˜๋Š” ์‚ฌ์šฉ ์˜ˆ์ œ์ด๋‹ค.

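import numpy as np
import torch
import torch.nn as nn
from livelossplot import PlotLosses

# model, device, num_epochs, optimizer and dataloaders (a dict holding
# 'train' and 'validation' DataLoaders) are assumed to be defined earlier.
liveloss = PlotLosses()
model = model.to(device)
loss_f = nn.MSELoss()

for epoch in range(num_epochs):
    logs = {}
    for phase in ['train', 'validation']:
        # Dropout/BatchNorm behave differently in train and eval mode
        if phase == 'train':
            model.train()
        else:
            model.eval()

        running_loss = 0.0

        for x, y in dataloaders[phase]:
            x, y = x.to(device), y.to(device)

            # Build the autograd graph only during the training phase
            with torch.set_grad_enabled(phase == 'train'):
                y_pred = model(x)
                loss = loss_f(y_pred, y)

                if phase == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            # loss is a per-batch mean, so weight it by the batch size
            running_loss += loss.item() * x.size(0)

        # Mean squared error over the whole dataset for this phase
        epoch_loss = running_loss / len(dataloaders[phase].dataset)

        # livelossplot draws 'loss' and 'val_loss' on the same chart
        prefix = ''
        if phase == 'validation':
            prefix = 'val_'

        # Log the square root of the MSE, i.e. the RMSE
        logs[prefix + 'loss'] = np.sqrt(epoch_loss)

    liveloss.update(logs)
    liveloss.send()  # redraws the plot; older releases used liveloss.draw()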
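The loop above accumulates loss.item() * x.size(0) rather than summing the batch losses directly. The minimal pure-Python sketch below (no PyTorch or livelossplot needed) isolates that bookkeeping. The helper name epoch_rmse and the batch numbers are hypothetical, chosen only for illustration. It shows that when the last batch is smaller, weighting each batch mean by its size gives the true dataset mean, while a plain average of batch means over-weights the small batch.

```python
import math


def epoch_rmse(batch_losses, batch_sizes):
    """Hypothetical helper: RMSE over a dataset from per-batch mean losses.

    Each batch loss is assumed to be a mean over that batch (as nn.MSELoss
    returns with its default reduction), so it is multiplied by the batch
    size before dividing by the total number of samples.
    """
    running_loss = 0.0
    for loss, n in zip(batch_losses, batch_sizes):
        running_loss += loss * n
    epoch_loss = running_loss / sum(batch_sizes)  # dataset-level MSE
    return math.sqrt(epoch_loss)


# Made-up example: 10 samples split into batches of 4, 4 and 2
batch_losses = [1.0, 4.0, 16.0]
batch_sizes = [4, 4, 2]

weighted = epoch_rmse(batch_losses, batch_sizes)
# Unweighted mean of batch means, for comparison
naive = math.sqrt(sum(batch_losses) / len(batch_losses))
print(weighted, naive)
```

Here the weighted version gives sqrt(52 / 10) = sqrt(5.2), about 2.28, whereas the naive average gives sqrt(21 / 3) = sqrt(7), about 2.65, because the two-sample batch with the largest loss counts as much as each four-sample batch.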
๋ฐ˜์‘ํ˜•