×

PyTorch教程9.2之将原始文本转换为序列数据

消耗积分:0 | 格式:pdf | 大小:0.33 MB | 2023-06-05

分享资料个

在本书中,我们经常会使用表示为单词、字符或单词序列的文本数据。首先,我们需要一些基本工具来将原始文本转换为适当形式的序列。典型的预处理流水线执行以下步骤:

  1. 将文本作为字符串加载到内存中。

  2. 将字符串拆分为标记(例如,单词或字符)。

  3. 构建一个词汇词典,将每个词汇元素与一个数字索引相关联。

  4. 将文本转换为数字索引序列。

import collections
import random
import re
import torch
from d2l import torch as d2l
import collections
import random
import re
from mxnet import np, npx
from d2l import mxnet as d2l

npx.set_np()
import collections
import random
import re
import jax
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import collections
import random
import re
import tensorflow as tf
from d2l import tensorflow as d2l

9.2.1. 读取数据集

在这里,我们将使用 HG Wells 的The Time Machine,这是一本 30000 多字的书。虽然实际应用程序通常会涉及大得多的数据集,但这足以演示预处理管道。以下_download方法将原始文本读入字符串。

class TimeMachine(d2l.DataModule): #@save
  """The Time Machine dataset."""
  def _download(self):
    fname = d2l.download(d2l.DATA_URL + 'timemachine.txt', self.root,
               '090b5e7e70c295757f55df93cb0a180b9691891a')
    with open(fname) as f:
      return f.read()

data = TimeMachine()
raw_text = data._download()
raw_text[:60]
'时间机器,HG Wells [1898]nnnnnInnnThe Time Tra'
class TimeMachine(d2l.DataModule): #@save
  """The Time Machine dataset."""
  def _download(self):
    fname = d2l.download(d2l.DATA_URL + 'timemachine.txt', self.root,
               '090b5e7e70c295757f55df93cb0a180b9691891a')
    with open(fname) as f:
      return f.read()

data = TimeMachine()
raw_text = data._download()
raw_text[:60]
Downloading ../data/timemachine.txt from http://d2l-data.s3-accelerate.amazonaws.com/timemachine.txt...
'The Time Machine, by H. G. Wells [1898]nnnnnInnnThe Time Tra'
class TimeMachine(d2l.DataModule): #@save
  """The Time Machine dataset."""
  def _download(self):
    fname = d2l.download(d2l.DATA_URL + 'timemachine.txt', self.root,
               '090b5e7e70c295757f55df93cb0a180b9691891a')
    with open(fname) as f:
      return f.read()

data = TimeMachine()
raw_text = data._download()
raw_text[:60]
'The Time Machine, by H. G. Wells [1898]nnnnnInnnThe Time Tra'
class TimeMachine(d2l.DataModule): #@save
  """The Time Machine dataset."""
  def _download(self):
    fname = d2l.download(d2l.DATA_URL + 'timemachine.txt', self.root,
               '090b5e7e70c295757f55df93cb0a180b9691891a')
    with open(fname) as f:
      return f.read()

data = TimeMachine()
raw_text = data._download()
raw_text[:60]
'The Time Machine, by H. G. Wells [1898]nnnnnInnnThe Time Tra'

为简单起见,我们在预处理原始文本时忽略标点符号和大写字母。

@d2l.add_to_class(TimeMachine) #@save
def _preprocess(self, text):
  return re.sub('[^A-Za-z]+', ' ', text).lower()

text = data._preprocess(raw_text)
text[:60]
'the time machine by h g wells i the time traveller for so it'
@d2l.add_to_class(TimeMachine) #@save
def _preprocess(self, text):
  return re.sub('[^A-Za-z]+', ' ', text).lower()

text = data._preprocess(raw_text)
text[:60]
'the time machine by h g wells i the time traveller for so it'
@d2l.add_to_class(TimeMachine) #@save
def _preprocess(self, text):
  return re.sub('[^A-Za-z]+', ' ', text).lower()

text = data._preprocess(raw_text)
text[:60]
'the time machine by h g wells i the time traveller for so it'
@d2l.add_to_class(TimeMachine) #@save
def _preprocess(self, text):
  return re.sub('[^A-Za-z]+', ' ', text).lower()

text = data._preprocess(raw_text)
text[:60]
'the time machine by h g wells i the time traveller for so it'

9.2.2. 代币化

标记是文本的原子(不可分割)单元。每个时间步对应 1 个 token,但究竟什么是 token 是一种设计选择。例如,我们可以将句子“Baby needs a new pair of shoes”表示为一个包含 7 个单词的序列,其中所有单词的集合包含一个很大的词汇表(通常是数万或数十万个单词)。或者我们将同一个句子表示为更长的 30 个字符序列,使用更小的词汇表(只有 256 个不同的 ASCII 字符)。下面,我们将预处理后的文本标记为一系列字符。

@d2l.add_to_class(TimeMachine) #@save
def _tokenize(self, text):
  return list(text)

tokens = data._tokenize(text)
','.join(tokens[:30])
't,h,e, ,t,i,m,e, ,m,a,c,h,i,n,e, ,b,y, ,h, ,g, ,w,e,l,l,s, '
@d2l.add_to_class(TimeMachine) #@save
def _tokenize(self, text):
  return list(text)

tokens = data._tokenize(text)
','.join(tokens[:30])
't,h,e, ,t,i,m,e, ,m,a,c,h,i,n,e, ,b,y, ,h, ,g, ,w,e,l,l,s, '
@d2l.add_to_class(TimeMachine) #@save
def _tokenize(self, text):
  return list(text)

tokens = data._tokenize(text)
','.join(tokens[:30])
't,h,e, ,t,i,m,e, ,m,a,c,h,i,n,e, ,b,y, ,h, ,g, ,w,e,l,l,s, '
@d2l.add_to_class(TimeMachine) #@save
def _tokenize(self, text):
  return list(text)

tokens = data._tokenize(text)
','.join(tokens[:30])
't,h,e, ,t,i,m,e, ,m,a,c,h,i,n,e, ,b,y, ,h, ,g, ,w,e,l,l,s, '

9.2.3. 词汇

这些标记仍然是字符串。然而,我们模型的输入最终必须由数值输入组成。接下来,我们介绍一个用于构建词汇表的类,即,将每个不同的标记值与唯一索引相关联的对象。首先,我们确定训练语料库中的唯一标记集然后我们为每个唯一标记分配一个数字索引。为方便起见,通常会删除不常用的词汇元素。Whenever we encounter a token at training or test time that had not been previously seen or was dropped from the vocabulary, we represent it by a special “” token, signifying that this is an unknown value.

class Vocab: #@save
  """Vocabulary for text."""
  def __init__(self, tokens=[], min_freq=0, reserved_tokens=[]):
    # Flatten a 2D list if needed
    if tokens and isinstance(tokens[0], list):
      tokens = [token for line in tokens for token in line]
    # Count token frequencies
    counter = collections.Counter(tokens)
    self.token_freqs = sorted(counter.items(), key=lambda x: x[1],
                 reverse=True)
    # The list of unique tokens
    self.idx_to_token = list(sorted(set([''] + reserved_tokens + [
      token for token, freq in self.token_freqs if freq >= min_freq])))
    self.token_to_idx = {token: idx
               for idx, token in enumerate(self.idx_to_token)}

  def __len__(self):
    return len(self.idx_to_token)

  def __getitem__(self, tokens):
    if not isinstance(tokens, (list, tuple)):
      return self.token_to_idx.get(tokens,

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

评论(0)
发评论

下载排行榜

全部0条评论

快来发表一下你的评论吧 !

'+ '

'+ '

'+ ''+ '
'+ ''+ ''+ '
'+ ''+ '' ); $.get('/article/vipdownload/aid/'+webid,function(data){ if(data.code ==5){ $(pop_this).attr('href',"/login/index.html"); return false } if(data.code == 2){ //跳转到VIP升级页面 window.location.href="//m.obk20.com/vip/index?aid=" + webid return false } //是会员 if (data.code > 0) { $('body').append(htmlSetNormalDownload); var getWidth=$("#poplayer").width(); $("#poplayer").css("margin-left","-"+getWidth/2+"px"); $('#tips').html(data.msg) $('.download_confirm').click(function(){ $('#dialog').remove(); }) } else { var down_url = $('#vipdownload').attr('data-url'); isBindAnalysisForm(pop_this, down_url, 1) } }); }); //是否开通VIP $.get('/article/vipdownload/aid/'+webid,function(data){ if(data.code == 2 || data.code ==5){ //跳转到VIP升级页面 $('#vipdownload>span').text("开通VIP 免费下载") return false }else{ // 待续费 if(data.code == 3) { vipExpiredInfo.ifVipExpired = true vipExpiredInfo.vipExpiredDate = data.data.endoftime } $('#vipdownload .icon-vip-tips').remove() $('#vipdownload>span').text("VIP免积分下载") } }); }).on("click",".download_cancel",function(){ $('#dialog').remove(); }) var setWeixinShare={};//定义默认的微信分享信息,页面如果要自定义分享,直接更改此变量即可 if(window.navigator.userAgent.toLowerCase().match(/MicroMessenger/i) == 'micromessenger'){ var d={ title:'PyTorch教程9.2之将原始文本转换为序列数据',//标题 desc:$('[name=description]').attr("content"), //描述 imgUrl:'https://'+location.host+'/static/images/ele-logo.png',// 分享图标,默认是logo link:'',//链接 type:'',// 分享类型,music、video或link,不填默认为link dataUrl:'',//如果type是music或video,则要提供数据链接,默认为空 success:'', // 用户确认分享后执行的回调函数 cancel:''// 用户取消分享后执行的回调函数 } setWeixinShare=$.extend(d,setWeixinShare); $.ajax({ url:"//www.obk20.com/app/wechat/index.php?s=Home/ShareConfig/index", data:"share_url="+encodeURIComponent(location.href)+"&format=jsonp&domain=m", type:'get', dataType:'jsonp', success:function(res){ if(res.status!="successed"){ return false; } $.getScript('https://res.wx.qq.com/open/js/jweixin-1.0.0.js',function(result,status){ if(status!="success"){ return false; } var getWxCfg=res.data; wx.config({ //debug: true, // 开启调试模式,调用的所有api的返回值会在客户端alert出来,若要查看传入的参数,可以在pc端打开,参数信息会通过log打出,仅在pc端时才会打印。 appId:getWxCfg.appId, // 必填,公众号的唯一标识 timestamp:getWxCfg.timestamp, // 必填,生成签名的时间戳 nonceStr:getWxCfg.nonceStr, // 必填,生成签名的随机串 signature:getWxCfg.signature,// 必填,签名,见附录1 jsApiList:['onMenuShareTimeline','onMenuShareAppMessage','onMenuShareQQ','onMenuShareWeibo','onMenuShareQZone'] // 必填,需要使用的JS接口列表,所有JS接口列表见附录2 }); wx.ready(function(){ //获取“分享到朋友圈”按钮点击状态及自定义分享内容接口 wx.onMenuShareTimeline({ title: setWeixinShare.title, // 分享标题 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享给朋友”按钮点击状态及自定义分享内容接口 wx.onMenuShareAppMessage({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 type: setWeixinShare.type, // 分享类型,music、video或link,不填默认为link dataUrl: setWeixinShare.dataUrl, // 如果type是music或video,则要提供数据链接,默认为空 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到QQ”按钮点击状态及自定义分享内容接口 wx.onMenuShareQQ({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到腾讯微博”按钮点击状态及自定义分享内容接口 wx.onMenuShareWeibo({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到QQ空间”按钮点击状态及自定义分享内容接口 wx.onMenuShareQZone({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); }); }); } }); } function openX_ad(posterid, htmlid, width, height) { if ($(htmlid).length > 0) { var randomnumber = Math.random(); var now_url = encodeURIComponent(window.location.href); var ga = document.createElement('iframe'); ga.src = 'https://www1.elecfans.com/www/delivery/myafr.php?target=_blank&cb=' + randomnumber + '&zoneid=' + posterid+'&prefer='+now_url; ga.width = width; ga.height = height; ga.frameBorder = 0; ga.scrolling = 'no'; var s = $(htmlid).append(ga); } } openX_ad(828, '#berry-300', 300, 250);