Skip to content

BERT数据集

一、概述

在之前的章节里我们学习了 BERT 模型,要预训练 BERT,我们需要把数据集处理成适合两个预训练任务(遮蔽语言模型和下一句预测)的格式。最初的 BERT 模型是在超大的图书语料库和英语维基百科上预训练的,但是对于我们普通学习者来说,用这么大的数据集太麻烦了,而且有时候我们需要在特定领域(比如医学)预训练 BERT,所以定制自己的数据集就很有必要。

这里我们用 WikiText-2 这个小一点的语料库来演示,它比之前用的 PTB 数据集更好用:

  1. 它保留了原来的标点符号,适合做下一句预测任务;
  2. 保留了原来的大小写和数字,信息更完整;
  3. 它的大小是 PTB 的两倍多,数据更丰富。

二、WikiText-2 数据集的读取

首先我们要读取 WikiText-2 数据集,这个数据集里每行是一个段落,我们只保留至少有两句话的段落,然后用句号作为分隔符把段落拆成句子。

代码示例(PyTorch 版)

python
import os
import random
import torch
import numpy as np
from d2l import torch as d2l

def _read_wiki(data_dir):
    # 读取训练集的文本文件
    file_name = os.path.join(data_dir, 'wiki.train.tokens')
    with open(file_name, 'r') as f:
        lines = f.readlines()
    # 把每个段落拆成句子,只保留至少有两句话的段落
    paragraphs = [line.strip().lower().split(' . ')
                 for line in lines if len(line.split(' . ')) >= 2]
    # 打乱段落顺序
    random.shuffle(paragraphs)
    return paragraphs

三、预训练任务的辅助函数

我们需要两个辅助函数,分别为下一句预测和遮蔽语言模型这两个预训练任务生成数据,还有一个填充输入的函数,把数据整理成模型需要的格式。

1. 生成下一句预测任务的数据

下一句预测任务是让模型判断两个句子是不是连续的,所以我们需要生成正样本(两个连续的句子)和负样本(两个不连续的句子),各占 50% 的比例。

python
def _get_next_sentence(sentence, next_sentence, paragraphs):
    # 50%的概率是正样本,也就是两个句子是连续的
    if random.random() < 0.5:
        is_next = True
    else:
        # 50%的概率是负样本,随机选一个其他段落里的句子作为下一句
        next_sentence = random.choice(random.choice(paragraphs))
        is_next = False
    return sentence, next_sentence, is_next

2. 生成遮蔽语言模型任务的数据

遮蔽语言模型是随机遮蔽 15% 的词元,然后让模型预测这些被遮蔽的词元。对于被遮蔽的词元,我们有三种处理方式:

  1. 80% 的时间用<mask>替换;
  2. 10% 的时间用一个随机的词元替换;
  3. 10% 的时间保持原来的词元不变。

首先我们先实现一个函数来处理单个词元:

python
def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds, vocab):
    # 复制原来的词元列表
    mlm_input_tokens = tokens.copy()
    # 用来存储预测的位置和对应的标签
    pred_positions_and_labels = []
    # 打乱候选预测位置
    random.shuffle(candidate_pred_positions)
    for mlm_pred_position in candidate_pred_positions:
        if len(pred_positions_and_labels) >= num_mlm_preds:
            break
        # 随机选择替换的方式
        masked_token = None
        if random.random() < 0.8:
            # 80%的概率用<mask>替换
            masked_token = '<mask>'
        else:
            if random.random() < 0.5:
                # 10%的概率用随机词元替换
                masked_token = random.choice(list(vocab.token_to_idx.items()))[0]
            else:
                # 10%的概率保持原来的词元不变
                masked_token = tokens[mlm_pred_position]
        # 替换词元
        mlm_input_tokens[mlm_pred_position] = masked_token
        # 记录预测位置和标签
        pred_positions_and_labels.append((mlm_pred_position, tokens[mlm_pred_position]))
    return mlm_input_tokens, pred_positions_and_labels

然后我们实现一个函数来处理整个序列:

python
def _get_mlm_data_from_tokens(tokens, vocab):
    # 候选预测位置是所有不是特殊词元的位置
    candidate_pred_positions = []
    for i, token in enumerate(tokens):
        if token not in ['<cls>', '<sep>']:
            candidate_pred_positions.append(i)
    # 预测15%的词元
    num_mlm_preds = max(1, round(len(tokens) * 0.15))
    # 处理词元
    mlm_input_tokens, pred_positions_and_labels = _replace_mlm_tokens(
        tokens, candidate_pred_positions, num_mlm_preds, vocab)
    # 把预测位置排序
    pred_positions_and_labels = sorted(pred_positions_and_labels, key=lambda x: x[0])
    pred_positions = [v[0] for v in pred_positions_and_labels]
    mlm_pred_labels = [vocab[v[1]] for v in pred_positions_and_labels]
    return mlm_input_tokens, pred_positions, mlm_pred_labels

3. 填充输入数据

因为每个序列的长度不一样,我们需要把它们填充到固定的长度,这样才能批量处理。

python
def _pad_bert_inputs(examples, max_len, vocab):
    # 最大的预测数量是max_len的15%
    max_num_mlm_preds = round(max_len * 0.15)
    all_token_ids, all_segments, valid_lens = [], [], []
    all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], []
    nsp_labels = []
    for (token_ids, pred_positions, mlm_pred_label_ids, segments, is_next) in examples:
        # 填充词元id
        all_token_ids.append(np.array(token_ids + [vocab['<pad>']] * (max_len - len(token_ids)), dtype='int32'))
        # 填充段落标识
        all_segments.append(np.array(segments + [0] * (max_len - len(segments)), dtype='int32'))
        # 有效长度,不包括填充的部分
        valid_lens.append(np.array(len(token_ids), dtype='float32'))
        # 填充预测位置
        all_pred_positions.append(np.array(pred_positions + [0] * (max_num_mlm_preds - len(pred_positions)), dtype='int32'))
        # 填充权重,填充的部分权重为0,不会计算损失
        all_mlm_weights.append(np.array([1.0] * len(mlm_pred_label_ids) + [0.0] * (max_num_mlm_preds - len(pred_positions)), dtype='float32'))
        # 填充预测标签
        all_mlm_labels.append(np.array(mlm_pred_label_ids + [0] * (max_num_mlm_preds - len(mlm_pred_label_ids)), dtype='int32'))
        # 下一句预测的标签
        nsp_labels.append(np.array(is_next))
    return (all_token_ids, all_segments, valid_lens, all_pred_positions, all_mlm_weights, all_mlm_labels, nsp_labels)

四、把文本转换成预训练数据集

我们定义一个数据集类,把上面的辅助函数整合起来,把原始的文本转换成 BERT 预训练需要的格式。

python
class _WikiTextDataset(torch.utils.data.Dataset):
    def __init__(self, paragraphs, max_len):
        # 把每个段落拆成词元
        paragraphs = [d2l.tokenize(paragraph, token='word') for paragraph in paragraphs]
        # 把所有句子收集起来
        sentences = [sentence for paragraph in paragraphs for sentence in paragraph]
        # 构建词表,过滤掉出现次数少于5次的词元,添加特殊词元
        self.vocab = d2l.Vocab(sentences, min_freq=5, reserved_tokens=['<pad>', '<mask>', '<cls>', '<sep>'])
        # 生成预训练样本
        examples = []
        for paragraph in paragraphs:
            examples.extend(_get_nsp_data_from_paragraph(paragraph, paragraphs, self.vocab, max_len))
        # 填充数据
        (self.all_token_ids, self.all_segments, self.valid_lens,
         self.all_pred_positions, self.all_mlm_weights, self.all_mlm_labels,
         self.nsp_labels) = _pad_bert_inputs(examples, max_len, self.vocab)

    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx], self.valid_lens[idx],
                self.all_pred_positions[idx], self.all_mlm_weights[idx], self.all_mlm_labels[idx],
                self.nsp_labels[idx])

    def __len__(self):
        return len(self.all_token_ids)

def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):
    # 生成下一句预测的样本
    nsp_examples = []
    for i in range(len(paragraph) - 1):
        # 取连续的两个句子作为正样本的候选
        tokens_a, tokens_b, is_next = _get_next_sentence(paragraph[i], paragraph[i+1], paragraphs)
        # 把两个句子拼接成BERT的输入格式
        tokens = ['<cls>'] + tokens_a + ['<sep>'] + tokens_b + ['<sep>']
        # 如果长度超过max_len,截断
        if len(tokens) > max_len:
            tokens = tokens[:max_len]
        # 生成遮蔽语言模型的数据
        mlm_input_tokens, pred_positions, mlm_pred_labels = _get_mlm_data_from_tokens(tokens, vocab)
        # 把词元转换成id
        token_ids = vocab[mlm_input_tokens]
        # 段落标识,<cls>和tokens_a是0,tokens_b是1
        segments = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
        if len(segments) > max_len:
            segments = segments[:max_len]
        # 记录样本
        nsp_examples.append((token_ids, pred_positions, mlm_pred_labels, segments, is_next))
    return nsp_examples

五、加载数据集

最后我们定义一个函数来加载数据集,生成数据迭代器。

python
def load_data_wiki(batch_size, max_len):
    """加载WikiText-2数据集"""
    num_workers = d2l.get_dataloader_workers()
    # 下载并解压数据集
    data_dir = d2l.download_extract('wikitext-2', 'wikitext-2')
    # 读取段落
    paragraphs = _read_wiki(data_dir)
    # 创建数据集
    train_set = _WikiTextDataset(paragraphs, max_len)
    # 创建数据迭代器
    train_iter = torch.utils.data.DataLoader(train_set, batch_size,
                                            shuffle=True, num_workers=num_workers)
    return train_iter, train_set.vocab

我们可以测试一下这个函数:

python
batch_size = 512
max_len = 64
train_iter, vocab = load_data_wiki(batch_size, max_len)
# 打印小批量的形状
for batch in train_iter:
    tokens, segments, valid_lens, pred_positions, mlm_weights, mlm_labels, nsp_labels = batch
    print('词元id形状:', tokens.shape)
    print('段落标识形状:', segments.shape)
    print('有效长度形状:', valid_lens.shape)
    print('预测位置形状:', pred_positions.shape)
    print('遮蔽语言模型权重形状:', mlm_weights.shape)
    print('遮蔽语言模型标签形状:', mlm_labels.shape)
    print('下一句预测标签形状:', nsp_labels.shape)
    break

运行结果大概是:

Plain
词元id形状: torch.Size([512, 64])
段落标识形状: torch.Size([512, 64])
有效长度形状: torch.Size([512])
预测位置形状: torch.Size([512, 10])
遮蔽语言模型权重形状: torch.Size([512, 10])
遮蔽语言模型标签形状: torch.Size([512, 10])
下一句预测标签形状: torch.Size([512])

六、小结

  1. 我们可以用 WikiText-2 数据集来预训练 BERT,它比 PTB 数据集更适合,保留了更多的信息。

  2. 我们需要为 BERT 的两个预训练任务生成数据:下一句预测任务生成正样本和负样本,遮蔽语言模型任务随机遮蔽 15% 的词元并生成预测标签。

  3. 我们需要把原始文本转换成 BERT 需要的格式,包括词元 id、段落标识、有效长度、预测位置、权重和标签等,并且填充到固定长度。

  4. 我们可以用自定义的数据集类来处理这些数据,生成数据迭代器供模型训练使用。

(注:文档部分内容可能由 AI 生成) 源地址

京ICP备2024093538号-1