🐍 Python & library/PyTorch

[PyTorch] Scheduler 시각화하기 (Visualize scheduler)

복만 2021. 11. 24. 17:05

다음 함수를 이용해 PyTorch scheduler을 시각화할 수 있다.

import matplotlib.pyplot as plt

def visualize_scheduler(optimizer, scheduler, epochs):
    lrs = []
    for _ in range(epochs):
        optimizer.step()
        lrs.append(optimizer.param_groups[0]['lr'])
        scheduler.step()

    plt.plot(lrs)
    plt.show()

 

scheduler.get_lr()로 learning rate를 얻어오지 않고

optimizer.param_groups[0]['lr']로 얻어오는 이유는,

ReduceLROnPlateau 등의 scheduler의 경우 get_lr() method가 없기 때문.

 

 

다음과 같이 사용할 수 있다.

import torch
import torch.optim as optim

epochs = 300
optimizer = optim.SGD([torch.tensor(1)], lr=0.1, momentum=0.9)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=0)

visualize_scheduler(optimizer, scheduler, epochs)

반응형