🐍 Python & library/PyTorch Lightning

[PyTorch Lightning] checkpoint 저장하기

복만 2023. 2. 7. 14:05

기본편 - 자동 저장

 

 

Saving and loading checkpoints (basic) — PyTorch Lightning 1.9.0 documentation

Shortcuts

pytorch-lightning.readthedocs.io

 

PyTorch Lightning의 Trainer을 이용해 학습을 진행하면, 자동으로 가장 마지막 training epoch의 checkpoint를 저장해준다.

 

trainer = Trainer()

 

만약 checkpoint가 저장되는 위치를 바꾸고 싶다면 다음과 같이 지정해줄 수 있다.

 

trainer = Trainer(default_root_dir='path/to/')

 

혹은 별도로 checkpoint 저장을 하지 않으려면 다음과 같이 지정하면 된다.

 

trainer = Trainer(enable_checkpointing=False)

 

 

 

심화편 1 - callback 이용

 

 

Customize checkpointing behavior (intermediate) — PyTorch Lightning 1.9.0 documentation

Customize checkpointing behavior (intermediate) Audience: Users looking to customize the checkpointing behavior Modify checkpointing behavior For fine-grained control over checkpointing behavior, use the ModelCheckpoint object from pytorch_lightning.callba

pytorch-lightning.readthedocs.io

 

가장 마지막 epoch가 아니라 매 epoch마다 checkpoint를 저장하는 등,

checkpointing을 좀더 세부적으로 설정하고 싶다면 ModelCheckpoint object를 생성해서 인자로 넘겨줄 수 있다.

 

몇가지 예시는 다음과 같다.

 

from pytorch_lightning.callbacks import ModelCheckpoint


# "val_loss" metric이 높은 상위 10개 checkpoint를 저장
checkpoint_callback = ModelCheckpoint(
    save_top_k=10,
    monitor="val_loss",
    mode="min",
    dirpath="my/path/",
    filename="sample-mnist-{epoch:02d}-{val_loss:.2f}",
)

# "global_step" 기준으로 마지막 10개의 checkpoint만 저장
# make sure you log it inside your LightningModule
checkpoint_callback = ModelCheckpoint(
    save_top_k=10,
    monitor="global_step",
    mode="max",
    dirpath="my/path/",
    filename="sample-mnist-{epoch:02d}-{global_step}",
)

 

이렇게 만든 ModelCheckpoint object는 Trainer에 callback으로 넘겨주면 된다.

 

trainer = Trainer(callbacks=[checkpoint_callback])

 

 

ModelCheckpoint object의 주요 인자는 다음과 같다.

  • monitor: checkpoint 저장의 기준이 되는 metric
  • save_top_k: 저장할 checkpoint의 수
  • save_weights_only (bool): True로 설정할 시 model weight만 저장함 (optimizer state 등은 제외)
  • dirpath: checkpoint가 저장될 위치
  • filename: checkpoint 파일이름 형식

 

 

 

심화편 2 - 수동으로 저장

 

 

Customize checkpointing behavior (intermediate) — PyTorch Lightning 1.9.0 documentation

Customize checkpointing behavior (intermediate) Audience: Users looking to customize the checkpointing behavior Modify checkpointing behavior For fine-grained control over checkpointing behavior, use the ModelCheckpoint object from pytorch_lightning.callba

pytorch-lightning.readthedocs.io

 

혹은 원하는 위치에서 Trainer object의 save_checkpoint() method를 호출해 checkpoint를 저장할 수 있다.

 

model = MyLightningModule(hparams)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")

 

반응형