Skip to content

图像增强

一、什么是图像增强

图像增强(图像增广)是在训练图像上做一系列随机变化,生成相似但不同的训练样本,从而扩大训练集的规模。

图像增强目的:

  1. 减少模型对某些属性的依赖,提高模型的泛化能力,比如让模型不那么依赖物体的位置、颜色这些属性。
  2. 解决数据集不够大的问题,让模型能学到更多的特征,避免过拟合

比如 AlexNet 的成功就离不开图像增强技术,它通过随机裁剪、翻转等方式生成了更多的训练样本,让模型的泛化能力更好。

二、常用的图像增强方法

我们用一张猫咪的图片作为例子,看看各种图像增强的效果。

!alt text(image.png)

1. 翻转和裁剪

(1)左右翻转

alt text

左右翻转图像通常不会改变对象的类别,是最常用的图像增强方法之一。比如我们可以让图像有 50% 的概率左右翻转,这样可以减少模型对物体左右位置的依赖。

在 PyTorch 里,我们可以用torchvision.transforms.RandomHorizontalFlip()来实现:

python
import torchvision.transforms as transforms
from d2l import torch as d2l
from PIL import Image

# 加载图片
img = Image.open('../img/cat1.jpg')
# 定义左右翻转的增强方法
aug = transforms.RandomHorizontalFlip()
# 展示效果
d2l.apply(img, aug)

(2)上下翻转

上下翻转不如左右翻转常用,但对于一些图像来说,上下翻转也不会影响识别。我们可以用torchvision.transforms.RandomVerticalFlip()来实现,让图像有 50% 的概率上下翻转。

(3)随机裁剪

alt text

随机裁剪可以让物体以不同的比例出现在图像的不同位置,减少模型对物体位置的依赖。比如我们可以随机裁剪一个面积为原始面积 10% 到 100% 的区域,然后把这个区域缩放到 200×200 像素。

在 PyTorch 里,我们可以用torchvision.transforms.RandomResizedCrop()来实现:

python
shape_aug = transforms.RandomResizedCrop((200, 200), scale=(0.1, 1), ratio=(0.5, 2))
d2l.apply(img, shape_aug)

2. 改变颜色

我们可以改变图像的亮度、对比度、饱和度和色调,降低模型对颜色的敏感度。 alt text

(1)改变亮度

我们可以随机更改图像的亮度,比如让亮度在原始图像的 50% 到 150% 之间变化:

python
brightness_aug = transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0)
d2l.apply(img, brightness_aug)

(2)改变色调

我们也可以随机更改图像的色调,让图像的颜色发生变化:

python
hue_aug = transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0.5)
d2l.apply(img, hue_aug)

(3)同时改变多个颜色属性

我们可以用ColorJitter同时随机更改图像的亮度、对比度、饱和度和色调:

python
color_aug = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
d2l.apply(img, color_aug)

3. 结合多种图像增强方法

alt text

在实际使用中,我们通常会把多种图像增强方法结合起来使用,比如先左右翻转,再改变颜色,最后随机裁剪。

在 PyTorch 里,我们可以用transforms.Compose()把多个增强方法组合起来:

python
augs = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    color_aug,
    shape_aug
])
d2l.apply(img, augs)

三、使用图像增强进行训练

我们用 CIFAR-10 数据集来演示如何用图像增强训练模型。CIFAR-10 数据集里的图片颜色和大小差异比较大,很适合用图像增强。

1. 定义训练和测试的增强方法

在训练的时候,我们用随机左右翻转和ToTensor(把图片转换成模型需要的格式);在测试的时候,我们只需要ToTensor,不需要随机操作,这样才能得到确定的结果。

python
# 训练集的增强方法
train_augs = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

# 测试集的增强方法
test_augs = transforms.Compose([
    transforms.ToTensor()
])

2. 加载数据集

我们定义一个函数来加载 CIFAR-10 数据集,并应用图像增强:

python
def load_cifar10(is_train, augs, batch_size):
    dataset = torchvision.datasets.CIFAR10(root="../data", train=is_train,
                                           transform=augs, download=True)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             shuffle=is_train, num_workers=d2l.get_dataloader_workers())
    return dataloader

3. 训练模型

我们用 ResNet-18 模型来训练,使用 Adam 优化器,用多 GPU 训练:

python
import torch
import torch.nn as nn
from d2l import torch as d2l

# 定义模型
batch_size, devices, net = 256, d2l.try_all_gpus(), d2l.resnet18(10, 3)

# 初始化模型参数
def init_weights(m):
    if type(m) in [nn.Linear, nn.Conv2d]:
        nn.init.xavier_uniform_(m.weight)
net.apply(init_weights)

# 训练函数
def train_with_data_aug(train_augs, test_augs, net, lr=0.001):
    train_iter = load_cifar10(True, train_augs, batch_size)
    test_iter = load_cifar10(False, test_augs, batch_size)
    loss = nn.CrossEntropyLoss(reduction="none")
    trainer = torch.optim.Adam(net.parameters(), lr=lr)
    d2l.train_ch13(net, train_iter, test_iter, loss, trainer, 10, devices)

# 开始训练
train_with_data_aug(train_augs, test_augs, net)

训练完成后,我们可以看到使用图像增强的模型测试精度比不使用的要高,说明图像增强确实能提高模型的泛化能力。

四、小结

  1. 图像增强是通过随机改变训练图像来生成更多的训练样本,提高模型的泛化能力,避免过拟合

  2. 常用的图像增强方法有翻转、裁剪、改变颜色等,我们可以把这些方法结合起来使用。

  3. 在训练的时候使用图像增强,测试的时候不使用随机的图像增强,这样才能得到确定的结果。

  4. 图像增强在计算机视觉里应用很广泛,尤其是在数据集不够大的时候,能有效提升模型的性能。

(注:文档部分内容可能由 AI 生成) 源地址

京ICP备2024093538号-1