Appearance
风格迁移
一、什么是风格迁移
之前我们学的图像分类、目标检测都是对图像进行识别,而风格迁移是一种有趣的任务,它可以把一张图像的风格应用到另一张图像上,生成一张新的图像。比如我们有一张风景照片(内容图像)和一幅油画(风格图像),我们可以把油画的风格(笔触、颜色、纹理)应用到风景照片上,生成一张看起来像油画的风景照片。
风格迁移的应用场景很多,比如可以把普通照片变成名画风格,或者给照片添加不同的艺术效果,让照片更有艺术感。
二、风格迁移的原理
风格迁移的原理很简单,主要分三步:
抽取图像特征:我们用预训练好的卷积神经网络(比如 VGG-19)来抽取图像的特征。这个网络可以抽取图像的内容特征和风格特征:
内容特征:代表图像里的物体、场景等内容信息,比如风景照片里的山、树、湖等。
风格特征:代表图像的风格信息,比如油画的笔触、颜色、纹理等。
定义损失函数:我们需要定义三个损失函数:
内容损失:让合成图像和内容图像的内容特征尽可能接近,这样合成图像就能保留内容图像的内容。
风格损失:让合成图像和风格图像的风格特征尽可能接近,这样合成图像就能获得风格图像的风格。
全变分损失:用来减少合成图像里的噪点,让合成图像更平滑。
训练合成图像:我们把合成图像当作模型参数,不断更新合成图像,让损失函数尽可能小,这样就能得到一张既有内容图像的内容,又有风格图像的风格的合成图像。
简单来说,风格迁移就是通过调整合成图像,让它在内容上接近内容图像,在风格上接近风格图像,同时让图像更平滑。
代码实现(PyTorch 版)
我们可以用 PyTorch 来实现风格迁移:
python
import torch
import torchvision
from d2l import torch as d2l
# 加载预训练的VGG-19模型
pretrained_net = torchvision.models.vgg19(pretrained=True)
# 选择内容层和风格层
content_layers = [0]
style_layers = [0, 5, 10, 19, 28]
# 创建新的网络,只保留需要用到的层
net = torch.nn.Sequential(*[pretrained_net.features[i] for i in range(max(content_layers + style_layers) + 1)])三、抽取图像特征
我们用预训练好的 VGG-19 模型来抽取图像的内容特征和风格特征。VGG-19 模型有 5 个卷积块,我们选择不同的层来抽取内容特征和风格特征:
内容层:我们选择靠近输出的层(比如第 0 层)来抽取内容特征,这样可以避免合成图像过多保留内容图像的细节。
风格层:我们选择不同的层(比如第 0、5、10、19、28 层)来抽取风格特征,这样可以匹配局部和全局的风格。
定义抽取特征的函数
python
def extract_features(X, content_layers, style_layers):
contents = []
styles = []
for i in range(len(net)):
X = net[i](X)
if i in content_layers:
contents.append(X)
if i in style_layers:
styles.append(X)
return contents, styles
# 抽取内容图像的内容特征
content_img = d2l.Image.open('../img/rainier.jpg')
content_X = d2l.preprocess(content_img, (224, 224)).unsqueeze(0)
contents_Y, _ = extract_features(content_X, content_layers, style_layers)
# 抽取风格图像的风格特征
style_img = d2l.Image.open('../img/autumn_oak.jpg')
style_X = d2l.preprocess(style_img, (224, 224)).unsqueeze(0)
_, styles_Y = extract_features(style_X, content_layers, style_layers)四、定义损失函数
1. 内容损失
内容损失用来衡量合成图像和内容图像的内容特征的差异,我们用平方误差函数来计算:
python
def content_loss(Y_hat, Y):
# 计算平方误差
return torch.square(Y_hat - Y.detach()).mean()2. 风格损失
风格损失用来衡量合成图像和风格图像的风格特征的差异,我们用格拉姆矩阵来表达风格特征,然后计算平方误差:
python
def gram(X):
# 计算格拉姆矩阵
num_channels, n = X.shape[1], X.numel() // X.shape[1]
X = X.reshape((num_channels, n))
return torch.matmul(X, X.T) / (num_channels * n)
def style_loss(Y_hat, gram_Y):
# 计算风格损失
return torch.square(gram(Y_hat) - gram_Y.detach()).mean()
# 预先计算风格图像的格拉姆矩阵
styles_Y_gram = [gram(Y) for Y in styles_Y]3. 全变分损失
全变分损失用来减少合成图像里的噪点,让合成图像更平滑:
python
def tv_loss(Y_hat):
# 计算全变分损失
return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())4. 总损失函数
总损失函数是内容损失、风格损失和全变分损失的加权和:
python
def total_loss(Y_hat, contents_Y, styles_Y_gram, content_weight, style_weight, tv_weight):
# 计算内容损失
content_l = sum(content_loss(Y_hat[0], Y) for Y in contents_Y)
# 计算风格损失
style_l = sum(style_loss(Y_hat[1][i], styles_Y_gram[i]) for i in range(len(styles_Y_gram)))
# 计算全变分损失
tv_l = tv_loss(Y_hat[0])
# 总损失
return content_weight * content_l + style_weight * style_l + tv_weight * tv_l五、训练合成图像
我们把合成图像当作模型参数,不断更新合成图像,让总损失函数尽可能小。
初始化合成图像
我们可以把合成图像初始化为内容图像,这样合成图像一开始就有内容图像的内容:
python
class SynthesizedImage(torch.nn.Module):
def __init__(self, img_shape, **kwargs):
super(SynthesizedImage, self).__init__(**kwargs)
# 初始化合成图像为内容图像
self.weight = torch.nn.Parameter(content_X.clone())
def forward(self):
return self.weight
# 创建合成图像模型
net = SynthesizedImage(content_X.shape)训练模型
python
# 定义优化器
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
# 定义权重
content_weight, style_weight, tv_weight = 1, 1e3, 10
# 训练
num_epochs = 300
for epoch in range(num_epochs):
optimizer.zero_grad()
# 抽取合成图像的特征
contents_Y_hat, styles_Y_hat = extract_features(net(), content_layers, style_layers)
# 计算损失
l = total_loss((contents_Y_hat, styles_Y_hat), contents_Y, styles_Y_gram, content_weight, style_weight, tv_weight)
l.backward()
optimizer.step()
# 打印损失
if (epoch + 1) % 50 == 0:
print(f'epoch {epoch+1}, loss {l.item():.4f}')六、可视化合成图像
训练完成后,我们可以把合成图像可视化出来:
python
# 把合成图像转换成图片
synthesized_img = d2l.postprocess(net().detach())
# 显示图片
d2l.plt.imshow(synthesized_img)
d2l.plt.title('合成图像')
d2l.plt.show()这样我们就能看到合成图像,它既有内容图像的内容,又有风格图像的风格。
七、小结
风格迁移是一种有趣的任务,它可以把一张图像的风格应用到另一张图像上,生成一张新的图像。
风格迁移的原理是通过调整合成图像,让它在内容上接近内容图像,在风格上接近风格图像,同时让图像更平滑。
我们用预训练好的卷积神经网络来抽取图像的内容特征和风格特征,用格拉姆矩阵来表达风格特征。
训练合成图像的时候,我们把合成图像当作模型参数,不断更新合成图像,让总损失函数尽可能小。
(注:文档部分内容可能由 AI 生成) 源地址