PyTorch Lightning provides sophisticated checkpointing capabilities that go beyond simple model state saving. When you use torch.save() directly, you miss out on several important features and can encounter problems in distributed training scenarios.
Lightning’s checkpointing system handles the complete training state, including optimizer states, learning rate schedulers, epoch counters, and random number generator states. This ensures that training can be resumed exactly where it left off. In distributed training setups, especially with strategies like FSDP (Fully Sharded Data Parallel), manual checkpointing with torch.save() can fail to capture the distributed state correctly, leading to training inconsistencies or failures when resuming.
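As a rough sketch of what this looks like in practice (MyModel stands for a LightningModule like the one in the examples below, and the checkpoint path is illustrative), the complete training state can be written and later restored through the Trainer; the same flow applies under distributed strategies such as FSDP:
import pytorch_lightning as pl

# Sketch only: MyModel and "last.ckpt" are placeholders for illustration.
model = MyModel()
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model)

# Writes model weights plus optimizer, scheduler, loop counters and RNG state.
trainer.save_checkpoint("last.ckpt")

# Later: restore the complete state and continue training where it stopped.
trainer = pl.Trainer(max_epochs=20)
trainer.fit(model, ckpt_path="last.ckpt")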
The built-in checkpointing also provides advanced features such as metric-based saving (for example, keeping the model with the lowest validation loss), retaining only the top-k checkpoints to save disk space, and configurable saving intervals. These features help manage storage efficiently and ensure you always have access to your best-performing models.
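For instance, interval-based saving can be configured roughly as follows (the values are illustrative; a metric-based configuration is shown in the compliant example further down):
from pytorch_lightning.callbacks import ModelCheckpoint

# Illustrative values: save a checkpoint every 5 epochs and always keep a
# "last" checkpoint for easy resuming.
interval_checkpoint = ModelCheckpoint(
    every_n_epochs=5,
    save_last=True,
)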
Additionally, Lightning’s checkpointing integrates seamlessly with the training loop and callbacks, making it easier to implement complex
checkpointing strategies without interfering with the training process.
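As a minimal sketch of such a strategy (the callback name and checkpoint path are placeholders), a custom callback can trigger Lightning’s own checkpointing from a hook instead of calling torch.save() directly:
import pytorch_lightning as pl

# Placeholder names: save a full checkpoint after every validation run by
# delegating state collection to Lightning.
class ValidationCheckpoint(pl.Callback):
    def on_validation_end(self, trainer, pl_module):
        trainer.save_checkpoint("after-validation.ckpt")

trainer = pl.Trainer(callbacks=[ValidationCheckpoint()])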
What is the potential impact?
Using manual checkpointing can lead to incomplete state saving, making it impossible to properly resume training. In distributed training
scenarios, this can cause training failures or inconsistent model states across processes. You may also lose important training metadata like
optimizer states and learning rate schedules, forcing you to restart training from scratch rather than resuming from a checkpoint.
How to fix in PyTorch Lightning?
Replace manual torch.save() calls with Lightning’s trainer.save_checkpoint() method. This ensures proper handling of the complete training state, including distributed training scenarios.
Non-compliant code example
import pytorch_lightning as pl
import torch

class MyModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)
        # Manual checkpoint saving
        if batch_idx % 100 == 0:
            torch.save(self.state_dict(), 'checkpoint.pth')  # Noncompliant
        return loss
Compliant code example
import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)
        # Lightning handles checkpointing automatically
        return loss

# Use the trainer's checkpointing method when needed
model = MyModel()
trainer = pl.Trainer()
trainer.fit(model)
trainer.save_checkpoint("path/to/checkpoint/file")
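A checkpoint saved this way can then be reused; roughly (reusing the names and path from the example above):
# Restore only the model (weights and hyperparameters), e.g. for inference:
model = MyModel.load_from_checkpoint("path/to/checkpoint/file")

# Or resume the complete training state (optimizer, schedulers, counters)
# with a fresh Trainer:
trainer = pl.Trainer()
trainer.fit(model, ckpt_path="path/to/checkpoint/file")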
Use the ModelCheckpoint callback for automatic checkpoint management. This provides advanced features like metric-based saving, top-k checkpoint
retention, and configurable saving intervals.
Non-compliant code example
import pytorch_lightning as pl
import torch
# Manual checkpoint saving outside training loop
torch.save(model.state_dict(), "checkpoint.pth") # Noncompliant
trainer = pl.Trainer(max_epochs=100)
trainer.fit(model)
Compliant code example
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Configure automatic checkpointing
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    save_top_k=3,
    mode='min',
    filename='model-{epoch:02d}-{val_loss:.2f}'
)

trainer = pl.Trainer(
    max_epochs=100,
    callbacks=[checkpoint_callback]
)
trainer.fit(model)
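As a brief follow-up (assuming the MyModel class from earlier), the callback keeps track of the best checkpoint it wrote, so you can reload it after training:
# The callback records where the best checkpoint was written.
best_path = checkpoint_callback.best_model_path
best_model = MyModel.load_from_checkpoint(best_path)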