Saving and loading trained models in PyTorch

source from https://stackoverflow.com/questions/42703500/best-way-to-save-a-trained-model-in-pytorch

The question asks for the best way to save a trained model in PyTorch. The top answer is as follows:

It depends what you want to do.

Case # 1: Save the model to use it yourself for inference

You save the model, restore it, and then switch it to evaluation mode. This is done because models usually contain BatchNorm and Dropout layers, which are in train mode by default after construction:

torch.save(model.state_dict(), filepath)

# Later, to restore:
model.load_state_dict(torch.load(filepath))
model.eval()
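One detail the snippet above leaves implicit: `load_state_dict` fills weights into an existing model instance, so you must first construct the model from its class before restoring. A minimal runnable sketch, using a hypothetical two-layer `Net` purely for illustration:

```python
import torch
import torch.nn as nn

# Hypothetical model for illustration only.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.drop = nn.Dropout(p=0.5)

    def forward(self, x):
        return self.fc(self.drop(x))

model = Net()
torch.save(model.state_dict(), "model.pt")

# Restoring requires an instance of the same class; load_state_dict
# copies the weights in, and eval() switches Dropout/BatchNorm to
# deterministic inference behavior.
restored = Net()
restored.load_state_dict(torch.load("model.pt"))
restored.eval()

x = torch.ones(1, 4)
with torch.no_grad():
    assert torch.allclose(model.eval()(x), restored(x))
```

With Dropout still in train mode, the two models would generally disagree on the same input; after `eval()` the outputs match exactly.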

Case # 2: Save model to resume training later

If you need to keep training the model that you are about to save, you need to save more than just the model. You also need to save the state of the optimizer, epochs, score, etc. You would do it like this:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

To resume training you would do things like: state = torch.load(filepath), and then, to restore the state of each individual object, something like this:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Since you are resuming training, DO NOT call model.eval() once you restore the states when loading.
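Putting the save and resume halves together, a minimal end-to-end sketch might look like this (the model, optimizer, and epoch value are placeholders chosen for the example, not anything prescribed by the answer):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer actually has momentum buffers to save.
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

state = {
    'epoch': 5,  # example value
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}
torch.save(state, "checkpoint.pt")

# Resume later: rebuild the model and optimizer, then restore both states.
model2 = nn.Linear(4, 2)
optimizer2 = torch.optim.SGD(model2.parameters(), lr=0.1, momentum=0.9)
checkpoint = torch.load("checkpoint.pt")
model2.load_state_dict(checkpoint['state_dict'])
optimizer2.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1
# No model2.eval() here -- training continues in train mode.
```

Saving the optimizer state matters for optimizers with internal buffers (SGD momentum, Adam moment estimates); restoring only the weights would silently reset those buffers.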

Case # 3: Model to be used by someone else with no access to your code

In TensorFlow you can create a .pb file that defines both the architecture and the weights of the model. This is very handy, especially when using TensorFlow Serving. The equivalent way to do this in PyTorch would be:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

This approach is still not bulletproof, and since PyTorch is still undergoing many changes, I wouldn't recommend it.
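The fragility comes from pickling: `torch.save(model, …)` pickles the whole module object, so `torch.load` only works if the model's class definitions are importable at load time, which is why it is not truly code-free for the recipient. A small sketch of the round trip (the `Sequential` toy model is an assumption for illustration; note that recent PyTorch versions default `torch.load` to `weights_only=True`, which must be disabled to unpickle a full module):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 2))  # toy model for illustration
torch.save(model, "whole_model.pt")

# torch.load unpickles the object. This only succeeds if the classes the
# model was built from (here nn.Sequential/nn.Linear) are importable;
# weights_only=False is needed in recent PyTorch to allow full unpickling.
loaded = torch.load("whole_model.pt", weights_only=False)

x = torch.ones(1, 4)
with torch.no_grad():
    assert torch.allclose(model(x), loaded(x))
```

For built-in layers this works, but a model using custom classes breaks as soon as the recipient lacks (or renames) the defining module, which is the "not bulletproof" part.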