Saving/loading models and checkpointing
Last updated: January 4, 2023
Saving models
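You can save a model by serializing its internal state dictionary. The state dictionary is a Python dictionary that maps each parameter name to its tensor: the learnable weights and biases of your model, plus persistent buffers such as batch-norm running statistics.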
torch.save(model.state_dict(), "model.pth")
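To make concrete what gets serialized, the sketch below prints the entries of a state dictionary before saving it. `Net` here is a hypothetical two-layer model standing in for your own class, and the temporary directory is an assumption made only so the example leaves no files behind.

```python
import os
import tempfile

import torch
from torch import nn


class Net(nn.Module):
    """Hypothetical stand-in for your own model class."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = Net()
state = model.state_dict()

# Each entry maps a parameter name to its tensor.
for name, tensor in state.items():
    print(name, tuple(tensor.shape))

# Save into a throwaway directory (illustration only).
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pth")
    torch.save(state, path)
    saved_ok = os.path.exists(path)
```

Note that only the tensors are written to disk. The class definition itself is not, which is why loading requires the model code to be available.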
Loading models
To recreate your model, you first need to recreate its structure:
model = Net()
Then you can load the state dictionary containing the parameter values into it:
model.load_state_dict(torch.load("model.pth"))
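A quick way to check that loading worked is a round trip: save one instance, load into a fresh one, and compare outputs. The following is a minimal sketch under assumptions; `Net` is again a hypothetical placeholder, and `map_location="cpu"` is an optional argument added here so that a file saved on a GPU would still load on a CPU-only machine.

```python
import os
import tempfile

import torch
from torch import nn


class Net(nn.Module):
    """Hypothetical stand-in for your own model class."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


original = Net()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pth")
    torch.save(original.state_dict(), path)

    # A fresh instance starts with its own random weights ...
    restored = Net()
    # ... until the saved values are copied into it.
    restored.load_state_dict(torch.load(path, map_location="cpu"))

# Switch layers such as dropout and batch norm to inference behaviour.
restored.eval()

x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(original(x), restored(x))
```

If the architecture of the new instance does not match the saved keys, `load_state_dict` raises an error by default rather than loading partially.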
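Create a checkpoint
A checkpoint bundles the model's state dictionary with the optimizer's state dictionary and any bookkeeping values (here the epoch and the loss), so that training can later resume where it stopped: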
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    ...
}, PATH)
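As an illustration of where such a call might sit, the sketch below saves a checkpoint at the end of every epoch of a tiny training loop. The linear model, the SGD settings, the synthetic data, and the temporary `PATH` are all assumptions chosen to keep the example runnable; storing `loss.item()` rather than the tensor is a design choice here, not a requirement.

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Linear(4, 1)  # hypothetical tiny model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
loss_fn = nn.MSELoss()

# Synthetic data, only for illustration.
inputs = torch.randn(16, 4)
targets = torch.randn(16, 1)

with tempfile.TemporaryDirectory() as tmp_dir:
    PATH = os.path.join(tmp_dir, "checkpoint.pth")

    for epoch in range(3):
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()

        # Overwrite the checkpoint at the end of every epoch.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),  # plain float instead of a graph-attached tensor
        }, PATH)

    checkpoint = torch.load(PATH)
    saved_keys = sorted(checkpoint.keys())
    last_epoch = checkpoint['epoch']
```

Saving the optimizer state matters for optimizers that keep per-parameter history, such as the momentum buffers used in this sketch.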
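Resume training from a checkpoint
Recreate the model and the optimizer first, then load their saved states and the bookkeeping values from the checkpoint: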
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
# Put the model back in training mode (call model.eval() instead for inference).
model.train()
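The restored `epoch` is typically used to decide where the loop picks up again. The sketch below simulates an interrupted run followed by a resumed one; `make_model_and_optimizer` and `train` are hypothetical helpers written for this example, and continuing at `checkpoint['epoch'] + 1` assumes the checkpoint was written after that epoch finished.

```python
import os
import tempfile

import torch
from torch import nn


def make_model_and_optimizer():
    # Hypothetical factory: must rebuild the same architecture and optimizer type.
    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    return model, optimizer


def train(model, optimizer, start_epoch, num_epochs, path):
    inputs, targets = torch.ones(8, 4), torch.zeros(8, 1)  # synthetic data
    model.train()
    for epoch in range(start_epoch, num_epochs):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()
        # Checkpoint after each completed epoch.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),
        }, path)


with tempfile.TemporaryDirectory() as tmp_dir:
    PATH = os.path.join(tmp_dir, "checkpoint.pth")

    # First run: completes epochs 0 and 1, then stops.
    model, optimizer = make_model_and_optimizer()
    train(model, optimizer, start_epoch=0, num_epochs=2, path=PATH)

    # Second run: rebuild, restore, and continue with the next epoch.
    model, optimizer = make_model_and_optimizer()
    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    train(model, optimizer, start_epoch=start_epoch, num_epochs=4, path=PATH)

    final_epoch = torch.load(PATH)['epoch']
```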