Appearance
延后初始化
一、什么是延后初始化
我们平时定义深度学习模型的时候,可能会有一个疑问:我只说了每个层有多少个输出神经元,没说输入是多大的,框架怎么知道参数的大小呢?比如我定义一个全连接层有 256 个输出,但是没说输入是多少维的,这个层的权重矩阵应该是几行几列?
这就用到了延后初始化(也叫延迟初始化),简单来说就是:框架一开始不会着急初始化模型的参数,而是等第一次把数据传入模型的时候,再根据数据的形状自动推断每个层的参数形状,然后完成初始化。这样我们定义模型的时候就不用手动指定每个层的输入维度了,大大简化了模型的定义和修改,尤其是在卷积神经网络里,输入图像的分辨率会影响后续层的维度,有了这个功能就方便多了。
二、实例化网络:刚开始参数是 “未知” 的
我们先定义一个简单的多层感知机,看看刚开始的时候参数是什么样子的。
代码示例(以 PyTorch 为例,其他框架类似)
python
import torch
from torch import nn
class SmartMLP(nn.Module):
def __init__(self, hidden=256, out=10):
super().__init__()
self.hidden = hidden
self.out = out
self.net = None # 延后初始化,先不定义具体层
def _build(self, in_dim):
return nn.Sequential(
nn.Linear(in_dim, self.hidden),
nn.ReLU(),
nn.Linear(self.hidden, self.out)
)
def forward(self, x):
if self.net is None:
self.net = self._build(x.size(1)).to(x.device)
return self.net(x)
# 使用示例
net = SmartMLP() # 完美替代原始函数这时候我们看看模型的参数,会发现参数的形状是未知的,比如第一个全连接层的权重形状是(256, -1),这里的-1就表示输入维度还不知道,框架还没初始化这个参数。就算我们调用初始化函数,也不会真正初始化参数:
python
# 尝试初始化参数
net.initialize()
# 再看参数,还是未知的
print(net.collect_params())输出里的参数形状还是带-1的,说明框架只是记了要初始化参数,但还没确定参数的大小。
三、传入数据:第一次跑模型的时候自动初始化
当我们第一次把数据传入模型的时候,框架就会根据数据的形状自动推断每个层的参数形状,然后完成初始化。比如我们传入一个 2×20 的随机数据(2 个样本,每个样本 20 维):
python
# 生成一个2×20的随机数据
X = torch.rand(size=(2, 20))
# 把数据传入模型
net(X)
# 现在再看参数的形状
print(net.collect_params())这时候我们会发现,参数的形状已经确定了:第一个全连接层的权重变成了(256, 20)(因为输入是 20 维,输出是 256 维,权重矩阵就是 256 行 20 列),第二个全连接层的权重变成了(10, 256)(输入是 256 维,输出是 10 维),框架已经自动完成了参数的初始化。
四、小结
延后初始化就是框架在第一次传入数据的时候,才自动推断参数的形状并完成初始化,这样我们定义模型的时候不用手动指定输入维度,简化了模型的定义和修改。
刚开始实例化模型的时候,参数的形状是未知的,就算调用初始化函数也不会真正初始化。
只有当第一次把数据传入模型的时候,框架才会根据数据的形状确定所有参数的形状,然后完成初始化。
(注:文档部分内容可能由 AI 生成) 源地址