×

PyTorch教程10.7之用于机器翻译的编码器-解码器Seq2Seq

消耗积分:0 | 格式:pdf | 大小:0.37 MB | 2023-06-05

分享资料个

在所谓的 seq2seq 问题中,如机器翻译(如 第 10.5 节所述),其中输入和输出均由可变长度的未对齐序列组成,我们通常依赖编码器-解码器架构(第10.6 节)。在本节中,我们将演示编码器-解码器架构在机器翻译任务中的应用,其中编码器和解码器均作为 RNN 实现( Cho,2014 年Sutskever等人,2014 年

在这里,编码器 RNN 将可变长度序列作为输入并将其转换为固定形状的隐藏状态。稍后,在 第 11 节中,我们将介绍注意力机制,它允许我们访问编码输入,而无需将整个输入压缩为单个固定长度的表示形式。

然后,为了生成输出序列,一次一个标记,由一个单独的 RNN 组成的解码器模型将在给定输入序列和输出中的前一个标记的情况下预测每个连续的目标标记。在训练期间,解码器通常会以官方“ground-truth”标签中的前面标记为条件。然而,在测试时,我们希望根据已经预测的标记来调节解码器的每个输出。请注意,如果我们忽略编码器,则 seq2seq 架构中的解码器的行为就像普通语言模型一样。图 10.7.1说明了如何在机器翻译中使用两个 RNN 进行序列到序列学习。

https://file.elecfans.com/web2/M00/A9/C8/poYBAGR9N3mATR0fAAFvoV0b0sI161.svg

图 10.7.1使用 RNN 编码器和 RNN 解码器进行序列到序列学习。

图 10.7.1中,特殊的“”标记标志着序列的结束。一旦生成此令牌,我们的模型就可以停止进行预测。在 RNN 解码器的初始时间步,有两个特殊的设计决策需要注意:首先,我们以特殊的序列开始“”标记开始每个输入。其次,我们可以在每个解码时间步将编码器的最终隐藏状态输入解码器Cho等人,2014 年在其他一些设计中,例如Sutskever等人。( 2014 ),RNN 编码器的最终隐藏状态仅在第一个解码步骤用于启动解码器的隐藏状态。

import collections
import math
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
import collections
import math
from mxnet import autograd, gluon, init, np, npx
from mxnet.gluon import nn, rnn
from d2l import mxnet as d2l

npx.set_np()
import collections
import math
from functools import partial
import jax
import optax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
import collections
import math
import tensorflow as tf
from d2l import tensorflow as d2l

10.7.1。教师强迫

虽然在输入序列上运行编码器相对简单,但如何处理解码器的输入和输出则需要更加小心。最常见的方法有时称为 教师强制在这里,原始目标序列(标记标签)作为输入被送入解码器。更具体地说,特殊的序列开始标记和原始目标序列(不包括最终标记)被连接起来作为解码器的输入,而解码器输出(用于训练的标签)是原始目标序列,移动了一个标记: “”,“Ils”,“regardent”,“。” “Ils”、“regardent”、“.”、“”(图 10.7.1)。

我们在10.5.3 节中的实施为教师强制准备了训练数据,其中用于自监督学习的转移标记类似于9.3 节中的语言模型训练。另一种方法是将来自前一个时间步的预测标记作为当前输入提供给解码器。

下面,我们 将更详细地解释图 10.7.1中描绘的设计。我们将在第 10.5 节中介绍的英语-法语数据集上训练该模型进行机器翻译

10.7.2。编码器

回想一下,编码器将可变长度的输入序列转换为固定形状的上下文变量 c(见图 10.7.1)。

考虑一个单序列示例(批量大小 1)。假设输入序列是x1,…,xT, 这样xt是个 tth令牌。在时间步t, RNN 变换输入特征向量xt为了xt 和隐藏状态ht−1从上一次进入当前隐藏状态ht. 我们可以使用一个函数f表达RNN循环层的变换:

(10.7.1)ht=f(xt,ht−1).

通常,编码器通过自定义函数将所有时间步的隐藏状态转换为上下文变量q:

(10.7.2)c=q(h1,…,hT).

例如,在图 10.7.1中,上下文变量只是隐藏状态hT对应于编码器 RNN 在处理输入序列的最终标记后的表示。

在这个例子中,我们使用单向 RNN 来设计编码器,其中隐藏状态仅取决于隐藏状态时间步和之前的输入子序列。我们还可以使用双向 RNN 构建编码器。在这种情况下,隐藏状态取决于时间步长前后的子序列(包括当前时间步长的输入),它编码了整个序列的信息。

现在让我们来实现 RNN 编码器。请注意,我们使用嵌入层来获取输入序列中每个标记的特征向量。嵌入层的权重是一个矩阵,其中行数对应于输入词汇表的大小 ( vocab_size),列数对应于特征向量的维度 ( embed_size)。对于任何输入令牌索引i,嵌入层获取ith权矩阵的行(从 0 开始)返回其特征向量。在这里,我们使用多层 GRU 实现编码器。

def init_seq2seq(module): #@save
  """Initialize weights for Seq2Seq."""
  if type(module) == nn.Linear:
     nn.init.xavier_uniform_(module.weight)
  if type(module) == nn.GRU:
    for param in module._flat_weights_names:
      if "weight" in param:
        nn.init.xavier_uniform_(module._parameters[param])

class Seq2SeqEncoder(d2l.Encoder): #@save
  """The RNN encoder for sequence to sequence learning."""
  def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
         dropout=0):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, embed_size)
    self.rnn = d2l.GRU(embed_size, num_hiddens, num_layers, dropout)
    self.apply(init_seq2seq)

  def forward(self, X, *args):
    # X shape: (batch_size, num_steps)
    embs = self.embedding(X.t().type(torch.in

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

评论(0)
发评论

下载排行榜

全部0条评论

快来发表一下你的评论吧 !

'+ '

'+ '

'+ ''+ '
'+ ''+ ''+ '
'+ ''+ '' ); $.get('/article/vipdownload/aid/'+webid,function(data){ if(data.code ==5){ $(pop_this).attr('href',"/login/index.html"); return false } if(data.code == 2){ //跳转到VIP升级页面 window.location.href="//m.obk20.com/vip/index?aid=" + webid return false } //是会员 if (data.code > 0) { $('body').append(htmlSetNormalDownload); var getWidth=$("#poplayer").width(); $("#poplayer").css("margin-left","-"+getWidth/2+"px"); $('#tips').html(data.msg) $('.download_confirm').click(function(){ $('#dialog').remove(); }) } else { var down_url = $('#vipdownload').attr('data-url'); isBindAnalysisForm(pop_this, down_url, 1) } }); }); //是否开通VIP $.get('/article/vipdownload/aid/'+webid,function(data){ if(data.code == 2 || data.code ==5){ //跳转到VIP升级页面 $('#vipdownload>span').text("开通VIP 免费下载") return false }else{ // 待续费 if(data.code == 3) { vipExpiredInfo.ifVipExpired = true vipExpiredInfo.vipExpiredDate = data.data.endoftime } $('#vipdownload .icon-vip-tips').remove() $('#vipdownload>span').text("VIP免积分下载") } }); }).on("click",".download_cancel",function(){ $('#dialog').remove(); }) var setWeixinShare={};//定义默认的微信分享信息,页面如果要自定义分享,直接更改此变量即可 if(window.navigator.userAgent.toLowerCase().match(/MicroMessenger/i) == 'micromessenger'){ var d={ title:'PyTorch教程10.7之用于机器翻译的编码器-解码器Seq2Seq',//标题 desc:$('[name=description]').attr("content"), //描述 imgUrl:'https://'+location.host+'/static/images/ele-logo.png',// 分享图标,默认是logo link:'',//链接 type:'',// 分享类型,music、video或link,不填默认为link dataUrl:'',//如果type是music或video,则要提供数据链接,默认为空 success:'', // 用户确认分享后执行的回调函数 cancel:''// 用户取消分享后执行的回调函数 } setWeixinShare=$.extend(d,setWeixinShare); $.ajax({ url:"//www.obk20.com/app/wechat/index.php?s=Home/ShareConfig/index", data:"share_url="+encodeURIComponent(location.href)+"&format=jsonp&domain=m", type:'get', dataType:'jsonp', success:function(res){ if(res.status!="successed"){ return false; } $.getScript('https://res.wx.qq.com/open/js/jweixin-1.0.0.js',function(result,status){ if(status!="success"){ return false; } var getWxCfg=res.data; wx.config({ //debug: true, // 开启调试模式,调用的所有api的返回值会在客户端alert出来,若要查看传入的参数,可以在pc端打开,参数信息会通过log打出,仅在pc端时才会打印。 appId:getWxCfg.appId, // 必填,公众号的唯一标识 timestamp:getWxCfg.timestamp, // 必填,生成签名的时间戳 nonceStr:getWxCfg.nonceStr, // 必填,生成签名的随机串 signature:getWxCfg.signature,// 必填,签名,见附录1 jsApiList:['onMenuShareTimeline','onMenuShareAppMessage','onMenuShareQQ','onMenuShareWeibo','onMenuShareQZone'] // 必填,需要使用的JS接口列表,所有JS接口列表见附录2 }); wx.ready(function(){ //获取“分享到朋友圈”按钮点击状态及自定义分享内容接口 wx.onMenuShareTimeline({ title: setWeixinShare.title, // 分享标题 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享给朋友”按钮点击状态及自定义分享内容接口 wx.onMenuShareAppMessage({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 type: setWeixinShare.type, // 分享类型,music、video或link,不填默认为link dataUrl: setWeixinShare.dataUrl, // 如果type是music或video,则要提供数据链接,默认为空 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到QQ”按钮点击状态及自定义分享内容接口 wx.onMenuShareQQ({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到腾讯微博”按钮点击状态及自定义分享内容接口 wx.onMenuShareWeibo({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到QQ空间”按钮点击状态及自定义分享内容接口 wx.onMenuShareQZone({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); }); }); } }); } function openX_ad(posterid, htmlid, width, height) { if ($(htmlid).length > 0) { var randomnumber = Math.random(); var now_url = encodeURIComponent(window.location.href); var ga = document.createElement('iframe'); ga.src = 'https://www1.elecfans.com/www/delivery/myafr.php?target=_blank&cb=' + randomnumber + '&zoneid=' + posterid+'&prefer='+now_url; ga.width = width; ga.height = height; ga.frameBorder = 0; ga.scrolling = 'no'; var s = $(htmlid).append(ga); } } openX_ad(828, '#berry-300', 300, 250);