기본편 - 자동 저장
PyTorch Lightning의 Trainer을 이용해 학습을 진행하면, 자동으로 가장 마지막 training epoch의 checkpoint를 저장해준다.
trainer = Trainer()
만약 checkpoint가 저장되는 위치를 바꾸고 싶다면 다음과 같이 지정해줄 수 있다.
trainer = Trainer(default_root_dir='path/to/')
혹은 별도로 checkpoint 저장을 하지 않으려면 다음과 같이 지정하면 된다.
trainer = Trainer(enable_checkpointing=False)
심화편 1 - callback 이용
가장 마지막 epoch가 아니라 매 epoch마다 checkpoint를 저장하는 등,
checkpointing을 좀더 세부적으로 설정하고 싶다면 ModelCheckpoint object를 생성해서 인자로 넘겨줄 수 있다.
몇가지 예시는 다음과 같다.
from pytorch_lightning.callbacks import ModelCheckpoint
# "val_loss" metric이 높은 상위 10개 checkpoint를 저장
checkpoint_callback = ModelCheckpoint(
save_top_k=10,
monitor="val_loss",
mode="min",
dirpath="my/path/",
filename="sample-mnist-{epoch:02d}-{val_loss:.2f}",
)
# "global_step" 기준으로 마지막 10개의 checkpoint만 저장
# make sure you log it inside your LightningModule
checkpoint_callback = ModelCheckpoint(
save_top_k=10,
monitor="global_step",
mode="max",
dirpath="my/path/",
filename="sample-mnist-{epoch:02d}-{global_step}",
)
이렇게 만든 ModelCheckpoint object는 Trainer에 callback으로 넘겨주면 된다.
trainer = Trainer(callbacks=[checkpoint_callback])
ModelCheckpoint object의 주요 인자는 다음과 같다.
- monitor: checkpoint 저장의 기준이 되는 metric
- save_top_k: 저장할 checkpoint의 수
- save_weights_only (bool): True로 설정할 시 model weight만 저장함 (optimizer state 등은 제외)
- dirpath: checkpoint가 저장될 위치
- filename: checkpoint 파일이름 형식
심화편 2 - 수동으로 저장
혹은 원하는 위치에서 Trainer object의 save_checkpoint() method를 호출해 checkpoint를 저장할 수 있다.
model = MyLightningModule(hparams)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
반응형
'🐍 Python & library > PyTorch Lightning' 카테고리의 다른 글
[PyTorch Lightning] 로그 기록, Tensorboard로 Logging하기 (0) | 2023.01.09 |
---|