Skip to content

延后初始化

一、什么是延后初始化

我们平时定义深度学习模型的时候,可能会有一个疑问:我只说了每个层有多少个输出神经元,没说输入是多大的,框架怎么知道参数的大小呢?比如我定义一个全连接层有 256 个输出,但是没说输入是多少维的,这个层的权重矩阵应该是几行几列?

这就用到了延后初始化(也叫延迟初始化),简单来说就是:框架一开始不会着急初始化模型的参数,而是等第一次把数据传入模型的时候,再根据数据的形状自动推断每个层的参数形状,然后完成初始化。这样我们定义模型的时候就不用手动指定每个层的输入维度了,大大简化了模型的定义和修改,尤其是在卷积神经网络里,输入图像的分辨率会影响后续层的维度,有了这个功能就方便多了。

二、实例化网络:刚开始参数是 “未知” 的

我们先定义一个简单的多层感知机,看看刚开始的时候参数是什么样子的。

代码示例(以 PyTorch 为例,其他框架类似)

python
import torch
from torch import nn

class SmartMLP(nn.Module):
    def __init__(self, hidden=256, out=10):
        super().__init__()
        self.hidden = hidden
        self.out = out
        self.net = None  # 延后初始化,先不定义具体层
        
    def _build(self, in_dim): 
        return nn.Sequential(
            nn.Linear(in_dim, self.hidden),
            nn.ReLU(),
            nn.Linear(self.hidden, self.out)
        )
        
    def forward(self, x):
        if self.net is None:
            self.net = self._build(x.size(1)).to(x.device)
        return self.net(x)

# 使用示例
net = SmartMLP()  # 完美替代原始函数

这时候我们看看模型的参数,会发现参数的形状是未知的,比如第一个全连接层的权重形状是(256, -1),这里的-1就表示输入维度还不知道,框架还没初始化这个参数。就算我们调用初始化函数,也不会真正初始化参数:

python
# 尝试初始化参数
net.initialize()
# 再看参数,还是未知的
print(net.collect_params())

输出里的参数形状还是带-1的,说明框架只是记了要初始化参数,但还没确定参数的大小。

三、传入数据:第一次跑模型的时候自动初始化

当我们第一次把数据传入模型的时候,框架就会根据数据的形状自动推断每个层的参数形状,然后完成初始化。比如我们传入一个 2×20 的随机数据(2 个样本,每个样本 20 维):

python
# 生成一个2×20的随机数据
X = torch.rand(size=(2, 20))
# 把数据传入模型
net(X)
# 现在再看参数的形状
print(net.collect_params())

这时候我们会发现,参数的形状已经确定了:第一个全连接层的权重变成了(256, 20)(因为输入是 20 维,输出是 256 维,权重矩阵就是 256 行 20 列),第二个全连接层的权重变成了(10, 256)(输入是 256 维,输出是 10 维),框架已经自动完成了参数的初始化。

四、小结

  1. 延后初始化就是框架在第一次传入数据的时候,才自动推断参数的形状并完成初始化,这样我们定义模型的时候不用手动指定输入维度,简化了模型的定义和修改。

  2. 刚开始实例化模型的时候,参数的形状是未知的,就算调用初始化函数也不会真正初始化。

  3. 只有当第一次把数据传入模型的时候,框架才会根据数据的形状确定所有参数的形状,然后完成初始化。

(注:文档部分内容可能由 AI 生成) 源地址

京ICP备2024093538号-1