Skip to content

读写文件

一、为什么要学习读写文件

我们训练深度学习模型的时候,经常会遇到这些情况:

  1. 训练一个模型花了好几天,要是突然电脑断电了,之前的训练成果就全没了,所以我们需要定期把训练好的参数保存下来,相当于给模型做 “备份”。

  2. 训练好的模型要用到其他地方,比如把训练好的图像识别模型放到手机 APP 里,这时候需要把模型的参数保存下来,然后在手机里加载这些参数,让模型能正常工作。 所以读写文件就是帮我们把模型的 “数据” 和 “记忆” 存起来,之后再拿出来用的技能。

二、加载和保存张量:单个数据盒子的存取

我们可以把张量当成一个 “数据盒子”,里面装着我们需要的数据,比如一个数组。我们可以把这个盒子存到文件里,之后再从文件里拿出来。

1. 保存和读取单个张量

以 PyTorch 为例,我们可以用torch.save把张量存到文件里,用torch.load把张量读出来:

python
import torch

# 创建一个张量,相当于一个装着[0,1,2,3]的盒子
x = torch.arange(4)
# 把这个盒子存到文件里,文件名叫x-file
torch.save(x, 'x-file')
# 从文件里把盒子拿出来
x2 = torch.load('x-file')
print(x2)

运行结果是tensor([0, 1, 2, 3]),和原来的张量一样,说明我们成功把数据存起来又拿出来了。

2. 保存和读取多个张量

我们也可以把多个 “数据盒子” 放到一个文件里,比如两个张量:

python
# 再创建一个装着全0的盒子
y = torch.zeros(4)
# 把两个盒子一起存到文件里
torch.save([x, y], 'x-y-files')
# 把两个盒子一起拿出来
x2, y2 = torch.load('x-y-files')
print(x2, y2)

运行结果是tensor([0, 1, 2, 3]) tensor([0., 0., 0., 0.]),两个张量都被正确读取了。

3. 保存和读取字典形式的张量

有时候我们需要给每个 “盒子” 起个名字,比如用字典来存储,这样更方便找到我们需要的数据:

python
# 创建一个字典,给每个张量起个名字
mydict = {'x': x, 'y': y}
# 把字典存到文件里
torch.save(mydict, 'mydict')
# 把字典拿出来
mydict2 = torch.load('mydict')
print(mydict2)

运行结果是{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])},我们可以通过名字找到对应的张量,很方便。

三、加载和保存模型参数:模型记忆的存取

模型的参数就像是模型的 “记忆”,训练模型的过程就是让模型记住这些参数,这样模型才能正确预测。一个模型有很多参数,我们不可能一个个保存,所以框架给我们提供了专门的方法来保存和加载模型的所有参数。

1. 定义一个简单的模型

我们先定义一个简单的多层感知机模型,作为例子:

python
import torch.nn.functional as F
from torch import nn

# 定义一个多层感知机
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        # 隐藏层
        self.hidden = nn.Linear(20, 256)
        # 输出层
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        # 前向传播
        return self.output(F.relu(self.hidden(x)))

# 创建模型实例
net = MLP()
# 给模型输入一个随机数据,测试模型能不能正常工作
X = torch.randn(size=(2, 20))
Y = net(X)
print(Y)

2. 保存模型参数

我们可以用state_dict()方法获取模型的所有参数,这个方法会返回一个字典,里面装着模型的所有参数,然后把这个字典存到文件里:

python
# 获取模型的参数字典
params = net.state_dict()
# 把参数字典存到文件里,文件名叫mlp.params
torch.save(params, 'mlp.params')

3. 加载模型参数

要加载模型参数,我们需要先创建一个和原来一样的模型架构,然后把参数加载进去:

python
# 创建一个新的模型实例,和原来的模型架构一样
clone = MLP()
# 从文件里加载参数字典,然后放到新模型里
clone.load_state_dict(torch.load('mlp.params'))
# 把模型设置为评估模式,这样模型就不会再训练了
clone.eval()

4. 验证加载的参数是否正确

我们给新模型输入和原来一样的数据,看看输出是不是一样的:

python
# 用新模型计算
Y_clone = clone(X)
# 看看两个模型的输出是不是一样的
print(Y_clone == Y)

运行结果是全True,说明新模型的参数和原来的模型一样,加载成功了。

四、小结

  1. 我们可以用torch.savetorch.load来保存和读取张量,不管是单个张量、多个张量还是字典形式的张量都可以。

  2. 保存模型参数的时候,我们用state_dict()获取模型的参数字典,然后保存;加载的时候,先创建一样的模型架构,再用load_state_dict加载参数。

  3. 要注意,我们保存的是模型的参数,不是模型的架构,所以加载参数的时候,必须先有和原来一样的模型架构。

五、练习

  1. 即使不需要把模型部署到不同的设备上,保存模型参数还有什么好处?

    • 可以备份训练结果,防止电脑断电或者其他意外情况导致训练成果丢失;

    • 可以在之后继续训练这个模型,比如训练了 10 轮,保存了参数,之后可以从第 10 轮继续训练,不用从头开始。

  2. 假设我们只想复用网络的一部分,比如想在一个新的网络里使用之前网络的前两层,该怎么做?

    • 我们可以把原来模型的前两层提取出来,作为新模型的一部分,比如把原来模型的hidden层拿出来,放到新模型里用。
  3. 如何同时保存网络架构和参数?需要对架构加上什么限制?

    • 可以用 Python 的pickle模块来保存整个模型,包括架构和参数,但是要注意模型的架构不能有太复杂的自定义代码,否则可能无法正确保存和加载。

(注:文档部分内容可能由 AI 生成) 源地址

京ICP备2024093538号-1