Appearance
读写文件
一、为什么要学习读写文件
我们训练深度学习模型的时候,经常会遇到这些情况:
训练一个模型花了好几天,要是突然电脑断电了,之前的训练成果就全没了,所以我们需要定期把训练好的参数保存下来,相当于给模型做 “备份”。
训练好的模型要用到其他地方,比如把训练好的图像识别模型放到手机 APP 里,这时候需要把模型的参数保存下来,然后在手机里加载这些参数,让模型能正常工作。 所以读写文件就是帮我们把模型的 “数据” 和 “记忆” 存起来,之后再拿出来用的技能。
二、加载和保存张量:单个数据盒子的存取
我们可以把张量当成一个 “数据盒子”,里面装着我们需要的数据,比如一个数组。我们可以把这个盒子存到文件里,之后再从文件里拿出来。
1. 保存和读取单个张量
以 PyTorch 为例,我们可以用torch.save把张量存到文件里,用torch.load把张量读出来:
python
import torch
# 创建一个张量,相当于一个装着[0,1,2,3]的盒子
x = torch.arange(4)
# 把这个盒子存到文件里,文件名叫x-file
torch.save(x, 'x-file')
# 从文件里把盒子拿出来
x2 = torch.load('x-file')
print(x2)运行结果是tensor([0, 1, 2, 3]),和原来的张量一样,说明我们成功把数据存起来又拿出来了。
2. 保存和读取多个张量
我们也可以把多个 “数据盒子” 放到一个文件里,比如两个张量:
python
# 再创建一个装着全0的盒子
y = torch.zeros(4)
# 把两个盒子一起存到文件里
torch.save([x, y], 'x-y-files')
# 把两个盒子一起拿出来
x2, y2 = torch.load('x-y-files')
print(x2, y2)运行结果是tensor([0, 1, 2, 3]) tensor([0., 0., 0., 0.]),两个张量都被正确读取了。
3. 保存和读取字典形式的张量
有时候我们需要给每个 “盒子” 起个名字,比如用字典来存储,这样更方便找到我们需要的数据:
python
# 创建一个字典,给每个张量起个名字
mydict = {'x': x, 'y': y}
# 把字典存到文件里
torch.save(mydict, 'mydict')
# 把字典拿出来
mydict2 = torch.load('mydict')
print(mydict2)运行结果是{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])},我们可以通过名字找到对应的张量,很方便。
三、加载和保存模型参数:模型记忆的存取
模型的参数就像是模型的 “记忆”,训练模型的过程就是让模型记住这些参数,这样模型才能正确预测。一个模型有很多参数,我们不可能一个个保存,所以框架给我们提供了专门的方法来保存和加载模型的所有参数。
1. 定义一个简单的模型
我们先定义一个简单的多层感知机模型,作为例子:
python
import torch.nn.functional as F
from torch import nn
# 定义一个多层感知机
class MLP(nn.Module):
def __init__(self):
super().__init__()
# 隐藏层
self.hidden = nn.Linear(20, 256)
# 输出层
self.output = nn.Linear(256, 10)
def forward(self, x):
# 前向传播
return self.output(F.relu(self.hidden(x)))
# 创建模型实例
net = MLP()
# 给模型输入一个随机数据,测试模型能不能正常工作
X = torch.randn(size=(2, 20))
Y = net(X)
print(Y)2. 保存模型参数
我们可以用state_dict()方法获取模型的所有参数,这个方法会返回一个字典,里面装着模型的所有参数,然后把这个字典存到文件里:
python
# 获取模型的参数字典
params = net.state_dict()
# 把参数字典存到文件里,文件名叫mlp.params
torch.save(params, 'mlp.params')3. 加载模型参数
要加载模型参数,我们需要先创建一个和原来一样的模型架构,然后把参数加载进去:
python
# 创建一个新的模型实例,和原来的模型架构一样
clone = MLP()
# 从文件里加载参数字典,然后放到新模型里
clone.load_state_dict(torch.load('mlp.params'))
# 把模型设置为评估模式,这样模型就不会再训练了
clone.eval()4. 验证加载的参数是否正确
我们给新模型输入和原来一样的数据,看看输出是不是一样的:
python
# 用新模型计算
Y_clone = clone(X)
# 看看两个模型的输出是不是一样的
print(Y_clone == Y)运行结果是全True,说明新模型的参数和原来的模型一样,加载成功了。
四、小结
我们可以用
torch.save和torch.load来保存和读取张量,不管是单个张量、多个张量还是字典形式的张量都可以。保存模型参数的时候,我们用
state_dict()获取模型的参数字典,然后保存;加载的时候,先创建一样的模型架构,再用load_state_dict加载参数。要注意,我们保存的是模型的参数,不是模型的架构,所以加载参数的时候,必须先有和原来一样的模型架构。
五、练习
即使不需要把模型部署到不同的设备上,保存模型参数还有什么好处?
可以备份训练结果,防止电脑断电或者其他意外情况导致训练成果丢失;
可以在之后继续训练这个模型,比如训练了 10 轮,保存了参数,之后可以从第 10 轮继续训练,不用从头开始。
假设我们只想复用网络的一部分,比如想在一个新的网络里使用之前网络的前两层,该怎么做?
- 我们可以把原来模型的前两层提取出来,作为新模型的一部分,比如把原来模型的
hidden层拿出来,放到新模型里用。
- 我们可以把原来模型的前两层提取出来,作为新模型的一部分,比如把原来模型的
如何同时保存网络架构和参数?需要对架构加上什么限制?
- 可以用 Python 的
pickle模块来保存整个模型,包括架构和参数,但是要注意模型的架构不能有太复杂的自定义代码,否则可能无法正确保存和加载。
- 可以用 Python 的
(注:文档部分内容可能由 AI 生成) 源地址