Skip to content

自然语言推断(BERT版)

一、概述

之前我们学过用注意力模型来完成自然语言推断任务,这次我们换用微调 BERT 的方式来重新做这个任务。自然语言推断的核心是判断 “假设” 能不能从 “前提” 里推断出来,一共有三种关系:蕴涵(假设能从前提里推出来)、矛盾(假设和前提完全相反)、中性(没法从前提里判断假设的对错)。

BERT 是一个预训练好的语言模型,已经在海量文本里学会了很多语言的语义信息,微调它可以让它快速适配自然语言推断这个任务,而且效果会比从零训练的模型更好。我们这次用的还是 SNLI 数据集,这个数据集有 50 多万个带标签的英语句子对,正好适合用来做这个任务。

二、加载预训练的 BERT

预训练的 BERT 有两个版本:

  1. bert.base:和原始的 BERT 基础模型一样大,有 1.1 亿个参数,微调的时候需要很多计算资源。

  2. bert.small:是小版本的 BERT,参数更少,适合我们用来演示。

我们用小版本的 BERT 来做实验,首先需要下载预训练的参数和词表,然后用代码加载这个预训练模型:

python
import os
import torch
from d2l import torch as d2l

def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
                         num_heads, num_layers, dropout, max_len, devices):
    # 下载并解压预训练模型
    data_dir = d2l.download_extract(pretrained_model)
    # 加载词表
    vocab = d2l.Vocab()
    vocab.token_to_idx = {token: idx for idx, token in enumerate(
        d2l.read_file(os.path.join(data_dir, 'vocab.txt')))}
    # 定义BERT模型
    bert = d2l.BERTModel(len(vocab), num_hiddens, ffn_num_hiddens,
                         num_heads, num_layers, dropout, max_len)
    # 加载预训练的参数
    bert.load_state_dict(torch.load(os.path.join(data_dir,
                                                'pretrained.params')))
    return bert, vocab

# 加载小版本的BERT
devices = d2l.try_all_gpus()
bert, vocab = load_pretrained_model(
    'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,
    num_layers=2, dropout=0.1, max_len=512, devices=devices)

三、微调 BERT 的数据集

我们需要把 SNLI 数据集转换成 BERT 能处理的格式,所以定义了一个SNLIBERTDataset类,它会把前提和假设打包成 BERT 的输入序列:

1. 数据处理步骤

  1. 首先把前提和假设的句子转换成小写,然后拆分成词元。

  2. 因为 BERT 的输入有最大长度限制,所以如果两个句子加起来太长,就把长的句子的最后一个词元删掉,直到满足长度要求(要给<cls><sep>这些特殊词元留位置)。

  3. 然后生成 BERT 的输入 token,还有片段索引(用来区分前提和假设,前提的片段索引是 0,假设的是 1),最后把这些转换成模型需要的格式,比如 token 的 id、片段索引、有效长度。

2. 数据集代码实现

python
class SNLIBERTDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, max_len, vocab):
        # 读取前提、假设和标签
        all_premise_hypothesis_tokens = [
            (p_tokens, h_tokens) for p_tokens, h_tokens in zip(
                *[d2l.tokenize([s.lower() for s in sentences])
                  for sentences in dataset[:2]])]
        self.labels = torch.tensor(dataset[2])
        self.vocab = vocab
        self.max_len = max_len
        # 预处理数据
        self.all_token_ids, self.all_segments, self.valid_lens = self._preprocess(
            all_premise_hypothesis_tokens)
        print(f'读取了 {len(self.all_token_ids)} 个样本')

    def _preprocess(self, all_premise_hypothesis_tokens):
        out = []
        for p_tokens, h_tokens in all_premise_hypothesis_tokens:
            # 截断句子
            self._truncate_pair_of_tokens(p_tokens, h_tokens)
            # 生成BERT的输入token和片段索引
            tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
            # 转换成token id,填充到最大长度
            token_ids = self.vocab[tokens] + [self.vocab['<pad>']] * (
                self.max_len - len(tokens))
            segments = segments + [0] * (self.max_len - len(segments))
            valid_len = len(tokens)
            out.append((token_ids, segments, valid_len))
        # 转换成张量
        return (torch.tensor([token_ids for token_ids, _, _ in out]),
                torch.tensor([segments for _, segments, _ in out]),
                torch.tensor([valid_len for _, _, valid_len in out]))

    def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
        # 给特殊词元留位置,所以两个句子的总长度不能超过max_len-3
        while len(p_tokens) + len(h_tokens) > self.max_len - 3:
            if len(p_tokens) > len(h_tokens):
                p_tokens.pop()
            else:
                h_tokens.pop()

    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx]), self.labels[idx]

    def __len__(self):
        return len(self.all_token_ids)

# 加载数据集
batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
data_dir = d2l.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
# 创建数据迭代器
train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
                                        num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(test_set, batch_size,
                                       num_workers=num_workers)

四、微调 BERT

我们的模型是 BERT 加上一个简单的多层感知机,这个多层感知机用来把 BERT 的输出转换成自然语言推断的三个结果:

1. 模型结构

python
class BERTClassifier(torch.nn.Module):
    def __init__(self, bert):
        super(BERTClassifier, self).__init__()
        self.encoder = bert.encoder  # BERT的编码器
        self.hidden = bert.hidden    # BERT的隐藏层
        self.output = torch.nn.Linear(256, 3)  # 输出层,输出3个类别

    def forward(self, inputs):
        tokens_X, segments_X, valid_lens_x = inputs
        # 用BERT编码输入
        encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
        # 取<cls>词元的表示,它包含了整个输入的信息
        return self.output(self.hidden(encoded_X[:, 0, :]))

# 创建模型
net = BERTClassifier(bert)

2. 训练模型

训练的时候,我们只从零开始训练输出层的参数,BERT 的编码器和隐藏层的参数是微调的(在原来的基础上稍微调整)。我们用 Adam 优化器,学习率设为 1e-4,训练 5 个轮次:

python
lr, num_epochs = 1e-4, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = torch.nn.CrossEntropyLoss(reduction='none')
# 训练模型
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)

训练完成后,结果大概是:训练准确率 0.79 左右,测试准确率 0.78 左右,用两个 GPU 的话每秒能处理 1 万多个样本。

五、小结

  1. 我们可以微调预训练的 BERT 来做自然语言推断任务,BERT 已经学会了很多语言的语义信息,微调后能快速适配这个任务,效果不错。

  2. 微调的时候,BERT 作为模型的一部分,只有和预训练相关的参数(比如用来计算遮蔽语言模型损失和下一句预测损失的参数)不会更新,其他参数都会微调。

  3. 我们需要把数据集转换成 BERT 能处理的格式,包括截断句子、生成 token id 和片段索引等。

(注:文档部分内容可能由 AI 生成) 源地址

京ICP备2024093538号-1