๐Ÿ Python & library/PyTorch

[PyTorch] Visualizing a scheduler

๋ณต๋งŒ 2021. 11. 24. 17:05

๋‹ค์Œ ํ•จ์ˆ˜๋ฅผ ์ด์šฉํ•ด PyTorch scheduler์„ ์‹œ๊ฐํ™”ํ•  ์ˆ˜ ์žˆ๋‹ค.

import matplotlib.pyplot as plt

def visualize_scheduler(optimizer, scheduler, epochs):
    lrs = []
    for _ in range(epochs):
        optimizer.step()  # dummy optimizer step so scheduler.step() does not warn about call order
        lrs.append(optimizer.param_groups[0]['lr'])  # record the current learning rate
        scheduler.step()  # advance the scheduler by one epoch

    plt.plot(lrs)
    plt.show()


The learning rate is read from optimizer.param_groups[0]['lr'] rather than scheduler.get_lr() because some schedulers, such as ReduceLROnPlateau, do not provide a get_lr() method.


๋‹ค์Œ๊ณผ ๊ฐ™์ด ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋‹ค.

import torch
import torch.optim as optim

epochs = 300
# dummy parameter; only the learning rate schedule matters here
optimizer = optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1, momentum=0.9)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=0)

visualize_scheduler(optimizer, scheduler, epochs)

๋ฐ˜์‘ํ˜•