Skip to content

自然语言推断与数据集

一、概述

之前我们学的情感分析,是把单个文本分类成积极、消极这类固定类别。但实际场景里,我们经常需要处理两个文本之间的关系,比如判断一个句子能不能从另一个句子里推断出来,或者判断两个句子是不是语义重复,这就需要用到自然语言推断的任务。这个任务在信息检索、开放领域问答系统里都有很多应用,比如在问答系统里,我们需要判断用户的问题能不能从知识库的文本里推断出答案。

二、什么是自然语言推断

自然语言推断的核心是判断 ** 假设(hypothesis)能不能从前提(premise)** 里推断出来,前提和假设都是文本序列。简单来说,就是判断两个文本之间的逻辑关系,这种关系分成三种:

  1. 蕴涵(entailment):假设可以从前提里直接推断出来。 比如:

    • 前提:两个女人拥抱在一起。

    • 假设:两个女人在示爱。 从 “拥抱” 可以推断出 “示爱”,所以这是蕴涵关系。

  2. 矛盾(contradiction):假设的反面可以从前提里推断出来,也就是前提和假设完全相反。 比如:

    • 前提:一名男子正在运行 Dive Into Deep Learning 的编码示例。

    • 假设:该男子正在睡觉。 运行编码示例和睡觉是完全相反的状态,所以这是矛盾关系。

  3. 中性(neutral):既不是蕴涵也不是矛盾的情况,没法从前提里推断出假设的对错。 比如:

    • 前提:音乐家们正在为我们表演。

    • 假设:音乐家很有名。 从 “表演” 这个信息,既不能推断出音乐家 “有名”,也不能推断出 “没名”,所以这是中性关系。

三、斯坦福自然语言推断(SNLI)数据集

我们用的是斯坦福的 SNLI 数据集,这是一个很常用的自然语言推断基准数据集,它包含了 50 多万个带标签的英语句子对。其中训练集大概有 55 万对,测试集大概有 1 万对,而且 “蕴涵”“矛盾”“中性” 这三个标签的数量是平衡的,不会出现某个标签特别多的情况。

四、数据读取和预处理

1. 读取数据集

首先我们需要读取这个数据集,并且提取出我们需要的信息(前提、假设、标签),去掉一些没用的信息。我们可以用 Python 实现一个读取函数:

python
import os
import re
import torch
from d2l import torch as d2l

def read_snli(data_dir, is_train):
    """读取SNLI数据集,返回前提、假设和标签"""
    def extract_text(s):
        # 去掉文本里的括号,合并多余的空格
        s = re.sub('\(', '', s)
        s = re.sub('\)', '', s)
        s = re.sub('\s{2,}', ' ', s)
        return s.strip()
    
    # 把标签转换成数字,方便模型处理
    label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}
    # 选择读取训练集还是测试集
    file_name = os.path.join(data_dir, 'snli_1.0_train.txt' if is_train else 'snli_1.0_test.txt')
    
    with open(file_name, 'r') as f:
        # 跳过第一行表头,然后按制表符拆分每行内容
        rows = [row.split('\t') for row in f.readlines()[1:]]
        # 提取前提、假设和标签,只保留有有效标签的行
        premises = [extract_text(row[1]) for row in rows if row[0] in label_set]
        hypotheses = [extract_text(row[2]) for row in rows if row[0] in label_set]
        labels = [label_set[row[0]] for row in rows if row[0] in label_set]
    
    return premises, hypotheses, labels

我们可以测试一下这个函数,看看读取的结果:

python
# 读取数据集
data_dir = d2l.download_extract('SNLI')
train_data = read_snli(data_dir, True)
test_data = read_snli(data_dir, False)

# 打印几个例子
print("前提:", train_data[0][0])
print("假设:", train_data[1][0])
print("标签:", train_data[2][0])

运行结果大概是:

Plain
前提: A person on a horse jumps over a broken down airplane .
假设: A person is training his horse for a competition .
标签: 2

这里的标签 2 代表中性关系。

2. 自定义数据集类

因为每个文本的长度不一样,我们需要把它们处理成固定长度,这样才能批量输入模型。我们可以自定义一个数据集类,把文本转换成词元,然后填充或截断到固定长度:

python
class SNLIDataset(torch.utils.data.Dataset):
    """自定义SNLI数据集类,处理文本成固定长度"""
    def __init__(self, dataset, num_steps, vocab=None):
        self.num_steps = num_steps
        # 把前提和假设都拆分成词元
        all_premise_tokens = d2l.tokenize(dataset[0])
        all_hypothesis_tokens = d2l.tokenize(dataset[1])
        
        # 如果没有提供词表,就用训练数据创建词表,过滤掉出现次数少于5次的稀有词
        if vocab is None:
            self.vocab = d2l.Vocab(all_premise_tokens + all_hypothesis_tokens, min_freq=5, reserved_tokens=['<pad>'])
        else:
            self.vocab = vocab
        
        # 把词元序列填充或截断到固定长度
        self.premises = self._pad(all_premise_tokens)
        self.hypotheses = self._pad(all_hypothesis_tokens)
        # 把标签转换成张量
        self.labels = torch.tensor(dataset[2])
    
    def _pad(self, lines):
        # 把每个词元序列填充或截断到num_steps长度,用<pad>填充短序列
        return torch.tensor([d2l.truncate_pad(self.vocab[line], self.num_steps, self.vocab['<pad>']) for line in lines])
    
    def __getitem__(self, idx):
        # 返回指定索引的(前提,假设)和标签
        return (self.premises[idx], self.hypotheses[idx]), self.labels[idx]
    
    def __len__(self):
        # 返回数据集的大小
        return len(self.premises)

3. 整合代码加载数据

最后我们可以写一个函数,整合上面的步骤,下载数据集,创建数据迭代器,方便后续训练模型:

python
def load_data_snli(batch_size, num_steps=50):
    """加载SNLI数据集,返回训练迭代器、测试迭代器和词表"""
    num_workers = d2l.get_dataloader_workers()
    data_dir = d2l.download_extract('SNLI')
    
    # 读取训练和测试数据
    train_data = read_snli(data_dir, True)
    test_data = read_snli(data_dir, False)
    
    # 创建训练集和测试集,测试集用训练集的词表,避免出现新词
    train_set = SNLIDataset(train_data, num_steps)
    test_set = SNLIDataset(test_data, num_steps, train_set.vocab)
    
    # 创建数据迭代器
    train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(test_set, batch_size, shuffle=False, num_workers=num_workers)
    
    return train_iter, test_iter, train_set.vocab

我们可以测试一下这个函数:

python
# 加载数据,批量大小32,文本长度50
train_iter, test_iter, vocab = load_data_snli(32, 50)

# 查看一个批次的数据
batch = next(iter(train_iter))
print("前提的形状:", batch[0][0].shape)
print("假设的形状:", batch[0][1].shape)
print("标签的形状:", batch[1].shape)
print("词表大小:", len(vocab))

运行结果大概是:

Plain
前提的形状: torch.Size([32, 50])
假设的形状: torch.Size([32, 50])
标签的形状: torch.Size([32])
词表大小: 18678

这说明每个批次有 32 个样本,每个文本都被处理成了 50 个词元的长度,词表有 18678 个词。

五、小结

  1. 自然语言推断是判断两个文本之间的逻辑关系,有蕴涵、矛盾、中性三种关系,这个任务在信息检索、问答系统里都有应用。

  2. SNLI 是常用的自然语言推断数据集,有 50 多万个带标签的句子对,训练集和测试集的标签都是平衡的。

  3. 我们可以用自定义的数据集类把文本处理成固定长度的词元序列,方便模型批量处理,而且测试集要使用训练集的词表,避免出现模型没见过的新词。

(注:文档部分内容可能由 AI 生成) 源地址

京ICP备2024093538号-1