模型保存与加载
在PyTorch中state_dict
可以理解为Python中的字典对象,其中每一层映射到一个参数张量,只有包含待学习参数的网络层会在模型的state_dict
中存在元素值。当保存模型时,只有模型的可学习参数是有必要进行保存的。
可学习参数可以使用torch.save(model.state_dict, PATH)
进行保存,文件一般以.pt
或者.pth
为后缀。在加载时,需要使用语句model.load_state_dict(torch.load(PATH))
进行加载,在参数加载之后,需要调用model.eval()
来将模型中的dropout
和batch normalization
设置为评估模式。
这里需要注意的是model.load_state_dict()
可接受的参数是Python字典类型,而不是路径字符串,所以需要先使用tourch.load()
进行反序列化。
将优化器对象中的state_dict
与模型的state_dict
一起保存到文件,可以建立模型断点,用于推断或者恢复训练。在恢复模型训练时,需要调用model.train()
将模型中的各层都置于训练模式。
由于torch.load()
和torch.save()
采用Python的Pickle进行序列化和反序列化,所以可以利用字典将多个不同的模型或者优化器的数据保存到一个文件中。
在跨CPU和GPU设备进行参数保存和加载时,只需要将torch.device()
获取的设备类型传递给model.load()
的参数map_location
即可。当模型需要加载到GPU时,需要调用model.to()
来确保将模型转换为针对目标设备优化的模型。