Saving/loading models and checkpointing

Last updated: January 4, 2023

Saving models

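You can save a model by serializing its state dictionary (model.state_dict()). The state dictionary is a Python dictionary that maps each parameter name to its tensor. It also includes persistent buffers, such as the running statistics of batch-normalization layers.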

import torch

torch.save(model.state_dict(), "model.pth")  # writes the state dictionary to model.pth
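A minimal sketch of what the state dictionary contains, assuming a hypothetical Net with two linear layers and a batch-normalization layer (the class, the layer sizes, and the temporary file path are illustrative only). The keys are qualified names such as fc1.weight. Besides the six learnable parameters, the batch-norm layer contributes three buffers, so the dictionary has nine entries.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical model used only for illustration."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.bn = nn.BatchNorm1d(8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.bn(self.fc1(x))))


model = Net()
state_dict = model.state_dict()

# Keys are "<submodule>.<tensor name>"; values are tensors.
# bn.running_mean, bn.running_var and bn.num_batches_tracked are buffers, not parameters.
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))

print(len(list(model.parameters())), len(state_dict))  # 6 9

# Only the tensors are saved, not the Net class itself.
path = os.path.join(tempfile.mkdtemp(), "model.pth")
torch.save(state_dict, path)
```

Because the file holds only named tensors, loading it later requires code that builds a model with the same layer names and shapes.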

Loading models

A state dictionary only holds parameter values, so to recreate your model you first need to rebuild its structure by instantiating the same model class:

model = Net()  # Net is your model class; it must define the same layers as the saved model

Then load the saved state dictionary, which holds the parameter values, into it:

model.load_state_dict(torch.load("model.pth"))
model.eval()  # switch dropout and batch-norm layers to inference mode before predicting
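The full save/load round trip can be sketched as follows, again assuming a hypothetical Net with illustrative layer sizes and a temporary file path. load_state_dict is strict by default, so it raises an error if a key is missing or unexpected; the object it returns lists both kinds of key (empty here). map_location="cpu" is an optional torch.load argument that remaps tensors saved on a GPU so they can be loaded on a machine without one.

```python
import os
import tempfile

import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical model; it must have the same layers as the saved one."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.bn = nn.BatchNorm1d(8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.bn(self.fc1(x))))


torch.manual_seed(0)
original = Net()
path = os.path.join(tempfile.mkdtemp(), "model.pth")
torch.save(original.state_dict(), path)

# 1. Recreate the structure: this instance starts with fresh random weights.
restored = Net()
# 2. Overwrite those weights with the saved tensors.
result = restored.load_state_dict(torch.load(path, map_location="cpu"))
print(result.missing_keys, result.unexpected_keys)  # [] []

# 3. Use evaluation mode before inference, so batch norm uses its stored statistics.
original.eval()
restored.eval()
x = torch.randn(3, 4)
with torch.no_grad():
    print(torch.equal(original(x), restored(x)))  # True
```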

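Create a checkpoint

To resume training later, a checkpoint needs more than the model's parameters. Save the optimizer's state dictionary as well (it holds internal state such as momentum buffers, plus hyperparameters like the learning rate), together with bookkeeping values such as the current epoch and the latest loss. Bundle everything in one dictionary and save it with torch.save: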

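torch.save({
    'epoch': epoch,                                  # last completed epoch
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),  # optimizer buffers and hyperparameters
    'loss': loss,
    # ... any other items you want to restore later
}, PATH)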
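A sketch of checkpointing inside a training loop, assuming a hypothetical toy setup (a single linear layer, SGD with momentum, random data) and an arbitrary one-file-per-epoch naming scheme. Here the loss is stored with loss.item(), a plain Python float, rather than as a tensor. The optimizer's state dictionary always has the two top-level keys state and param_groups.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical toy setup: a linear model fitted to random data.
torch.manual_seed(0)
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
loss_fn = nn.MSELoss()
inputs, targets = torch.randn(16, 3), torch.randn(16, 1)

ckpt_dir = tempfile.mkdtemp()

for epoch in range(3):
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()

    # One file per epoch; the naming scheme is an arbitrary choice.
    path = os.path.join(ckpt_dir, f"checkpoint_epoch{epoch}.pth")
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),  # plain Python float rather than a tensor
    }, path)

print(sorted(os.listdir(ckpt_dir)))
# ['checkpoint_epoch0.pth', 'checkpoint_epoch1.pth', 'checkpoint_epoch2.pth']

saved = torch.load(path)
print(saved['epoch'], sorted(saved['optimizer_state_dict']))  # 2 ['param_groups', 'state']
```

Keeping one file per epoch lets you roll back to an earlier point; overwriting a single file instead saves disk space.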

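Resume training from a checkpoint

To resume, first create the model and the optimizer, then restore their states and the saved values from the checkpoint: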

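# TheModelClass and TheOptimizerClass stand for your own model and optimizer classes
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(model.parameters(), *args, **kwargs)  # optimize the new model's parameters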

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.train()  # back to training mode; call model.eval() instead if you only run inference
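A sketch that checks resuming actually works, assuming hypothetical helpers make_model_and_optimizer (standing in for TheModelClass and TheOptimizerClass) and train_one_epoch (one full-batch step per "epoch"), plus random toy data. A first run trains for two epochs and saves a checkpoint; a second run restores it and continues from checkpoint['epoch'] + 1. Because Adam's per-parameter state (its running moment estimates) is restored along with the weights, the resumed model ends up with the same weights as one trained without interruption.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model_and_optimizer():
    # Hypothetical stand-in for TheModelClass(...) and TheOptimizerClass(...).
    model = nn.Linear(3, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    return model, optimizer


def train_one_epoch(model, optimizer, inputs, targets):
    # Hypothetical "epoch": a single full-batch gradient step.
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
inputs, targets = torch.randn(16, 3), torch.randn(16, 1)
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")
num_epochs = 4

# First run: train for two epochs, then write a checkpoint.
model, optimizer = make_model_and_optimizer()
for epoch in range(2):
    loss = train_one_epoch(model, optimizer, inputs, targets)
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, path)

# Second run: rebuild the objects, then restore their saved state.
model2, optimizer2 = make_model_and_optimizer()
checkpoint = torch.load(path)
model2.load_state_dict(checkpoint['model_state_dict'])
optimizer2.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # first epoch that has not run yet
model2.train()

for epoch in range(start_epoch, num_epochs):
    train_one_epoch(model2, optimizer2, inputs, targets)

# Reference: the first run's objects, trained on without interruption.
for epoch in range(start_epoch, num_epochs):
    train_one_epoch(model, optimizer, inputs, targets)

print(start_epoch)                               # 2
print(torch.equal(model.weight, model2.weight))  # True
```

If only the model weights were restored (not the optimizer state), Adam would restart its moment estimates from zero and the resumed run would diverge from the uninterrupted one.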
