PyTorch教程-16.7。自然语言推理:微调 BERT

电子说

1.3w人已加入

描述

在本章前面的部分中,我们为 SNLI 数据集上的自然语言推理任务(如第 16.4 节所述)设计了一个基于注意力的架构(第16.5节)。现在我们通过微调 BERT 重新审视这个任务。正如16.6 节所讨论的 ,自然语言推理是一个序列级文本对分类问题,微调 BERT 只需要一个额外的基于 MLP 的架构,如图 16.7.1所示。

pytorch

图 16.7.1本节将预训练的 BERT 提供给基于 MLP 的自然语言推理架构。

在本节中,我们将下载预训练的小型 BERT 版本,然后对其进行微调以在 SNLI 数据集上进行自然语言推理。

 

import json
import multiprocessing
import os
import torch
from torch import nn
from d2l import torch as d2l

 

 

import json
import multiprocessing
import os
from mxnet import gluon, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()

 

16.7.1。加载预训练的 BERT

我们已经在第 15.9 节和第 15.10 节中解释了如何在 WikiText-2 数据集上预训练 BERT (请注意,原始 BERT 模型是在更大的语料库上预训练的)。如15.10 节所述,原始 BERT 模型有数亿个参数。在下文中,我们提供了两个版本的预训练 BERT:“bert.base”与需要大量计算资源进行微调的原始 BERT 基础模型差不多大,而“bert.small”是一个小版本方便演示。

 

d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.torch.zip',
               '225d66f04cae318b841a13d32af3acc165f253ac')
d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.torch.zip',
               'c72329e68a732bef0452e4b96a1c341c8910f81f')

 

 

d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.zip',
               '7b3820b35da691042e5d34c0971ac3edbd80d3f4')
d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.zip',
               'a4e718a47137ccd1809c9107ab4f5edd317bae2c')

 

预训练的 BERT 模型都包含一个定义词汇集的“vocab.json”文件和一个预训练参数的“pretrained.params”文件。我们实现以下load_pretrained_model 函数来加载预训练的 BERT 参数。

 

def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
             num_heads, num_blks, dropout, max_len, devices):
  data_dir = d2l.download_extract(pretrained_model)
  # Define an empty vocabulary to load the predefined vocabulary
  vocab = d2l.Vocab()
  vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
  vocab.token_to_idx = {token: idx for idx, token in enumerate(
    vocab.idx_to_token)}
  bert = d2l.BERTModel(
    len(vocab), num_hiddens, ffn_num_hiddens=ffn_num_hiddens, num_heads=4,
    num_blks=2, dropout=0.2, max_len=max_len)
  # Load pretrained BERT parameters
  bert.load_state_dict(torch.load(os.path.join(data_dir,
                         'pretrained.params')))
  return bert, vocab

 

 

def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
             num_heads, num_blks, dropout, max_len, devices):
  data_dir = d2l.download_extract(pretrained_model)
  # Define an empty vocabulary to load the predefined vocabulary
  vocab = d2l.Vocab()
  vocab.idx_to_token = json.load(open(os.path.join(data_dir, 'vocab.json')))
  vocab.token_to_idx = {token: idx for idx, token in enumerate(
    vocab.idx_to_token)}
  bert = d2l.BERTModel(len(vocab), num_hiddens, ffn_num_hiddens, num_heads,
             num_blks, dropout, max_len)
  # Load pretrained BERT parameters
  bert.load_parameters(os.path.join(data_dir, 'pretrained.params'),
             ctx=devices)
  return bert, vocab

 

为了便于在大多数机器上进行演示,我们将在本节中加载和微调预训练 BERT 的小型版本(“bert.small”)。在练习中,我们将展示如何微调更大的“bert.base”以显着提高测试准确性。

 

devices = d2l.try_all_gpus()
bert, vocab = load_pretrained_model(
  'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,
  num_blks=2, dropout=0.1, max_len=512, devices=devices)

 

 

Downloading ../data/bert.small.torch.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.torch.zip...

 

 

devices = d2l.try_all_gpus()
bert, vocab = load_pretrained_model(
  'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,
  num_blks=2, dropout=0.1, max_len=512, devices=devices)

 

 

Downloading ../data/bert.small.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.zip...

 

16.7.2。微调 BERT 的数据集

对于 SNLI 数据集上的下游任务自然语言推理,我们定义了一个自定义的数据集类SNLIBERTDataset。在每个示例中,前提和假设形成一对文本序列,并被打包到一个 BERT 输入序列中,如图 16.6.2所示。回想第 15.8.4 节 ,段 ID 用于区分 BERT 输入序列中的前提和假设。对于 BERT 输入序列 ( max_len) 的预定义最大长度,输入文本对中较长者的最后一个标记会不断被删除,直到max_len满足为止。为了加速生成用于微调 BERT 的 SNLI 数据集,我们使用 4 个工作进程并行生成训练或测试示例。

 

class SNLIBERTDataset(torch.utils.data.Dataset):
  def __init__(self, dataset, max_len, vocab=None):
    all_premise_hypothesis_tokens = [[
      p_tokens, h_tokens] for p_tokens, h_tokens in zip(
      *[d2l.tokenize([s.lower() for s in sentences])
       for sentences in dataset[:2]])]

    self.labels = torch.tensor(dataset[2])
    self.vocab = vocab
    self.max_len = max_len
    (self.all_token_ids, self.all_segments,
     self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
    print('read ' + str(len(self.all_token_ids)) + ' examples')

  def _preprocess(self, all_premise_hypothesis_tokens):
    pool = multiprocessing.Pool(4) # Use 4 worker processes
    out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
    all_token_ids = [
      token_ids for token_ids, segments, valid_len in out]
    all_segments = [segments for token_ids, segments, valid_len in out]
    valid_lens = [valid_len for token_ids, segments, valid_len in out]
    return (torch.tensor(all_token_ids, dtype=torch.long),
        torch.tensor(all_segments, dtype=torch.long),
        torch.tensor(valid_lens))

  def _mp_worker(self, premise_hypothesis_tokens):
    p_tokens, h_tokens = premise_hypothesis_tokens
    self._truncate_pair_of_tokens(p_tokens, h_tokens)
    tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
    token_ids = self.vocab[tokens] + [self.vocab['']] 
               * (self.max_len - len(tokens))
    segments = segments + [0] * (self.max_len - len(segments))
    valid_len = len(tokens)
    return token_ids, segments, valid_len

  def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
    # Reserve slots for '', '', and '' tokens for the BERT
    # input
    while len(p_tokens) + len(h_tokens) > self.max_len - 3:
      if len(p_tokens) > len(h_tokens):
        p_tokens.pop()
      else:
        h_tokens.pop()

  def __getitem__(self, idx):
    return (self.all_token_ids[idx], self.all_segments[idx],
        self.valid_lens[idx]), self.labels[idx]

  def __len__(self):
    return len(self.all_token_ids)

 

 

class SNLIBERTDataset(gluon.data.Dataset):
  def __init__(self, dataset, max_len, vocab=None):
    all_premise_hypothesis_tokens = [[
      p_tokens, h_tokens] for p_tokens, h_tokens in zip(
      *[d2l.tokenize([s.lower() for s in sentences])
       for sentences in dataset[:2]])]

    self.labels = np.array(dataset[2])
    self.vocab = vocab
    self.max_len = max_len
    (self.all_token_ids, self.all_segments,
     self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
    print('read ' + str(len(self.all_token_ids)) + ' examples')

  def _preprocess(self, all_premise_hypothesis_tokens):
    pool = multiprocessing.Pool(4) # Use 4 worker processes
    out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
    all_token_ids = [
      token_ids for token_ids, segments, valid_len in out]
    all_segments = [segments for token_ids, segments, valid_len in out]
    valid_lens = [valid_len for token_ids, segments, valid_len in out]
    return (np.array(all_token_ids, dtype='int32'),
        np.array(all_segments, dtype='int32'),
        np.array(valid_lens))

  def _mp_worker(self, premise_hypothesis_tokens):
    p_tokens, h_tokens = premise_hypothesis_tokens
    self._truncate_pair_of_tokens(p_tokens, h_tokens)
    tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
    token_ids = self.vocab[tokens] + [self.vocab['']] 
               * (self.max_len - len(tokens))
    segments = segments + [0] * (self.max_len - len(segments))
    valid_len = len(tokens)
    return token_ids, segments, valid_len

  def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
    # Reserve slots for '', '', and '' tokens for the BERT
    # input
    while len(p_tokens) + len(h_tokens) > self.max_len - 3:
      if len(p_tokens) > len(h_tokens):
        p_tokens.pop()
      else:
        h_tokens.pop()

  def __getitem__(self, idx):
    return (self.all_token_ids[idx], self.all_segments[idx],
        self.valid_lens[idx]), self.labels[idx]

  def __len__(self):
    return len(self.all_token_ids)

 

下载 SNLI 数据集后,我们通过实例化SNLIBERTDataset类来生成训练和测试示例。此类示例将在自然语言推理的训练和测试期间以小批量读取。

 

# Reduce `batch_size` if there is an out of memory error. In the original BERT
# model, `max_len` = 512
batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
data_dir = d2l.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
                  num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(test_set, batch_size,
                 num_workers=num_workers)

 

 

read 549367 examples
read 9824 examples

 

 

# Reduce `batch_size` if there is an out of memory error. In the original BERT
# model, `max_len` = 512
batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
data_dir = d2l.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
train_iter = gluon.data.DataLoader(train_set, batch_size, shuffle=True,
                  num_workers=num_workers)
test_iter = gluon.data.DataLoader(test_set, batch_size,
                 num_workers=num_workers)

 

 

read 549367 examples
read 9824 examples

 

16.7.3。微调 BERT

如图16.6.2所示,为自然语言推理微调 BERT 只需要一个额外的 MLP,该 MLP 由两个完全连接的层组成(参见下一类中的self.hidden和)。该 MLP 将特殊“”标记的 BERT 表示形式(对前提和假设的信息进行编码)转换为自然语言推理的三个输出:蕴含、矛盾和中性。self.outputBERTClassifier

 

class BERTClassifier(nn.Module):
  def __init__(self, bert):
    super(BERTClassifier, self).__init__()
    self.encoder = bert.encoder
    self.hidden = bert.hidden
    self.output = nn.LazyLinear(3)

  def forward(self, inputs):
    tokens_X, segments_X, valid_lens_x = inputs
    encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
    return self.output(self.hidden(encoded_X[:, 0, :]))

 

 

class BERTClassifier(nn.Block):
  def __init__(self, bert):
    super(BERTClassifier, self).__init__()
    self.encoder = bert.encoder
    self.hidden = bert.hidden
    self.output = nn.Dense(3)

  def forward(self, inputs):
    tokens_X, segments_X, valid_lens_x = inputs
    encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
    return self.output(self.hidden(encoded_X[:, 0, :]))

 

接下来,预训练的 BERT 模型bert被输入到 下游应用程序的BERTClassifier实例中。net在 BERT 微调的常见实现中,只会net.output从头学习附加 MLP ( ) 输出层的参数。net.encoder预训练的 BERT 编码器 ( ) 和附加 MLP 的隐藏层 ( )的所有参数都net.hidden将被微调。

 

net = BERTClassifier(bert)

 

 

net = BERTClassifier(bert)
net.output.initialize(ctx=devices)

 

回想一下15.8 节中类MaskLM和 NextSentencePred类在它们使用的 MLP 中都有参数。这些参数是预训练 BERT 模型中参数bert的一部分,因此也是net. 然而,这些参数仅用于计算预训练期间的掩码语言建模损失和下一句预测损失。MaskLM这两个损失函数与微调下游应用程序无关,因此在微调 BERT 时,在和中使用的 MLP 的参数NextSentencePred不会更新(失效)。

为了允许具有陈旧梯度的参数,在的函数 ignore_stale_grad=True中设置了标志 。我们使用此函数使用SNLI 的训练集 ( ) 和测试集 ( )来训练和评估模型。由于计算资源有限,训练和测试的准确性可以进一步提高:我们将其讨论留在练习中。stepd2l.train_batch_ch13nettrain_itertest_iter

 

lr, num_epochs = 1e-4, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction='none')
net(next(iter(train_iter))[0])
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)

 

 

loss 0.519, train acc 0.791, test acc 0.782
9226.8 examples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]

 

pytorch

 

lr, num_epochs = 1e-4, 5
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': lr})
loss = gluon.loss.SoftmaxCrossEntropyLoss()
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices,
        d2l.split_batch_multi_inputs)

 

 

loss 0.477, train acc 0.810, test acc 0.785
4626.9 examples/sec on [gpu(0), gpu(1)]

 

pytorch

16.7.4。概括

我们可以为下游应用微调预训练的 BERT 模型,例如 SNLI 数据集上的自然语言推理。

在微调期间,BERT 模型成为下游应用模型的一部分。仅与预训练损失相关的参数在微调期间不会更新。

16.7.5。练习

如果您的计算资源允许,微调一个更大的预训练 BERT 模型,该模型与原始 BERT 基础模型差不多大。将函数中的参数设置load_pretrained_model为:将“bert.small”替换为“bert.base”,将 、 、 和 的值分别增加到 num_hiddens=256768、3072、12ffn_num_hiddens=512和num_heads=412 num_blks=2。通过增加微调周期(并可能调整其他超参数),您能否获得高于 0.86 的测试精度?

如何根据长度比截断一对序列?比较这对截断方法和类中使用的方法 SNLIBERTDataset。他们的优缺点是什么?

打开APP阅读更多精彩内容
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

全部0条评论

快来发表一下你的评论吧 !

×
20
完善资料,
赚取积分