×

PyTorch教程4.3之基本分类模型

消耗积分:0 | 格式:pdf | 大小:0.13 MB | 2023-06-05

李雪

分享资料个

您可能已经注意到,在回归的情况下,从头开始的实现和使用框架功能的简洁实现非常相似。分类也是如此。由于本书中的许多模型都处理分类,因此值得添加专门支持此设置的功能。本节为分类模型提供了一个基类,以简化以后的代码。

import torch
from d2l import torch as d2l
from mxnet import autograd, gluon, np, npx
from d2l import mxnet as d2l

npx.set_np()
from functools import partial
import jax
import optax
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import tensorflow as tf
from d2l import tensorflow as d2l

4.3.1. Classifier_

我们在下面定义Classifier类。在中,validation_step我们报告了验证批次的损失值和分类准确度。我们为每个批次绘制一个更新num_val_batches这有利于在整个验证数据上生成平均损失和准确性。如果最后一批包含的示例较少,则这些平均数并不完全正确,但我们忽略了这一微小差异以保持代码简单。

class Classifier(d2l.Module): #@save
  """The base class of classification models."""
  def validation_step(self, batch):
    Y_hat = self(*batch[:-1])
    self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
    self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)

We define the Classifier class below. In the validation_step we report both the loss value and the classification accuracy on a validation batch. We draw an update for every num_val_batches batches. This has the benefit of generating the averaged loss and accuracy on the whole validation data. These average numbers are not exactly correct if the last batch contains fewer examples, but we ignore this minor difference to keep the code simple.

class Classifier(d2l.Module): #@save
  """The base class of classification models."""
  def validation_step(self, batch):
    Y_hat = self(*batch[:-1])
    self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
    self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)

We define the Classifier class below. In the validation_step we report both the loss value and the classification accuracy on a validation batch. We draw an update for every num_val_batches batches. This has the benefit of generating the averaged loss and accuracy on the whole validation data. These average numbers are not exactly correct if the last batch contains fewer examples, but we ignore this minor difference to keep the code simple.

We also redefine the training_step method for JAX since all models that will subclass Classifier later will have a loss that returns auxiliary data. This auxiliary data can be used for models with batch normalization (to be explained in Section 8.5), while in all other cases we will make the loss also return a placeholder (empty dictionary) to represent the auxiliary data.

class Classifier(d2l.Module): #@save
  """The base class of classification models."""
  def training_step(self, params, batch, state):
    # Here value is a tuple since models with BatchNorm layers require
    # the loss to return auxiliary data
    value, grads = jax.value_and_grad(
      self.loss, has_aux=True)(params, batch[:-1], batch[-1], state)
    l, _ = value
    self.plot("loss", l, train=True)
    return value, grads

  def validation_step(self, params, batch, state):
    # Discard the second returned value. It is used for training models
    # with BatchNorm layers since loss also returns auxiliary data
    l, _ = self.loss(params, batch[:-1], batch[-1], state)
    self.plot('loss', l, train=False)
    self.plot('acc', self.accuracy(params, batch[:-1], batch[-1], state),
         train=False)

We define the Classifier class below. In the validation_step we report both the loss value and the classification accuracy on a validation batch. We draw an update for every num_val_batches batches. This has the benefit of generating the averaged loss and accuracy on the whole validation data. These average numbers are not exactly correct if the last batch contains fewer examples, but we ignore this minor difference to keep the code simple.

class Classifier(d2l.Module): #@save
  """The base class of classification models."""
  def validation_step(self, batch):
    Y_hat = self(*batch[:-1])
    self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
    self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)

默认情况下,我们使用随机梯度下降优化器,在小批量上运行,就像我们在线性回归的上下文中所做的那样。

@d2l.add_to_class(d2l.Module) #@save
def configure_optimizers(self):
  return torch.optim.SGD(self.parameters(), lr=self.lr)
@d2l.add_to_class(d2l.Module) #@save
def configure_optimizers(self):
  params = self.parameters()
  if isinstance(params, list):
    return d2l.SGD(params, self.lr)
  return gluon.Trainer(params, 'sgd', {'learning_rate': self.lr})
@d2l.add_to_class(d2l.Module) #@save
def configure_optimizers(self):
  return optax.sgd(self.lr)
@d2l.add_to_class(d2l.Module) #@save
def configure_optimizers(self):
  return tf.keras.optimizers.SGD(self.lr)

4.3.2. 准确性

给定预测概率分布y_hat,每当我们必须输出硬预测时,我们通常会选择预测概率最高的类别。事实上,许多应用程序需要我们做出选择。例如,Gmail 必须将电子邮件分类为“主要”、“社交”、“更新”、“william hill官网 ”或“垃圾邮件”。它可能会在内部估计概率,但最终它必须在类别中选择一个。

当预测与标签 class 一致时y,它们是正确的。分类准确度是所有正确预测的分数。尽管直接优化精度可能很困难(不可微分),但它通常是我们最关心的性能指标。它通常是基准测试中的相关数量因此,我们几乎总是在训练分类器时报告它。

准确度计算如下。首先,如果y_hat是一个矩阵,我们假设第二个维度存储每个类别的预测分数。我们使用argmax每行中最大条目的索引来获取预测类。然后我们将预测的类别与真实的元素进行比较y由于相等运算符== 对数据类型敏感,因此我们转换 的y_hat数据类型以匹配 的数据类型y结果是一个包含条目 0(假)和 1(真)的张量。求和得出正确预测的数量。

@d2l.add_to_class(Classifier) #@save
def accuracy(self, Y_hat, Y, averaged=True):
  """Compute the number of correct predictions."""
  Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
  preds = Y_hat.argmax(axis=1).type(Y.dtype)
  compare = (preds == Y.reshape(-1)).type(torch.float32)
  return compare.mean() if averaged else compare
@d2l.add_to_class(Classifier) #@save
def accuracy(self, Y_hat, Y, averaged=True):
  """Compute the number of correct predictions."""
  Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
  preds = Y_hat.argmax(axis=1).astype(Y.dtype)
  compare = (preds == Y.reshape(-1)).astype(np.float32)
  return compare.mean() if averaged else compare

@d2l.add_to_class(d2l.Module) #@save
def get_scratch_params(self):
  params = []
  for attr in dir(self):
    a = getattr(self, attr)
    if isinstance(a, np.ndarray):
      params.append(a)
    if isinstance(a, d2l.Module):
      params.extend(a.get_scratch_params())
  return params

@d2l.add_to_class(d2l.Module) #@save
def parameters(self):
  params = self.collect_params()
  return params if isinstance(params, gluon.parameter.ParameterDict) and

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

评论(0)
发评论

下载排行榜

全部0条评论

快来发表一下你的评论吧 !

'+ '

'+ '

'+ ''+ '
'+ ''+ ''+ '
'+ ''+ '' ); $.get('/article/vipdownload/aid/'+webid,function(data){ if(data.code ==5){ $(pop_this).attr('href',"/login/index.html"); return false } if(data.code == 2){ //跳转到VIP升级页面 window.location.href="//m.obk20.com/vip/index?aid=" + webid return false } //是会员 if (data.code > 0) { $('body').append(htmlSetNormalDownload); var getWidth=$("#poplayer").width(); $("#poplayer").css("margin-left","-"+getWidth/2+"px"); $('#tips').html(data.msg) $('.download_confirm').click(function(){ $('#dialog').remove(); }) } else { var down_url = $('#vipdownload').attr('data-url'); isBindAnalysisForm(pop_this, down_url, 1) } }); }); //是否开通VIP $.get('/article/vipdownload/aid/'+webid,function(data){ if(data.code == 2 || data.code ==5){ //跳转到VIP升级页面 $('#vipdownload>span').text("开通VIP 免费下载") return false }else{ // 待续费 if(data.code == 3) { vipExpiredInfo.ifVipExpired = true vipExpiredInfo.vipExpiredDate = data.data.endoftime } $('#vipdownload .icon-vip-tips').remove() $('#vipdownload>span').text("VIP免积分下载") } }); }).on("click",".download_cancel",function(){ $('#dialog').remove(); }) var setWeixinShare={};//定义默认的微信分享信息,页面如果要自定义分享,直接更改此变量即可 if(window.navigator.userAgent.toLowerCase().match(/MicroMessenger/i) == 'micromessenger'){ var d={ title:'PyTorch教程4.3之基本分类模型',//标题 desc:$('[name=description]').attr("content"), //描述 imgUrl:'https://'+location.host+'/static/images/ele-logo.png',// 分享图标,默认是logo link:'',//链接 type:'',// 分享类型,music、video或link,不填默认为link dataUrl:'',//如果type是music或video,则要提供数据链接,默认为空 success:'', // 用户确认分享后执行的回调函数 cancel:''// 用户取消分享后执行的回调函数 } setWeixinShare=$.extend(d,setWeixinShare); $.ajax({ url:"//www.obk20.com/app/wechat/index.php?s=Home/ShareConfig/index", data:"share_url="+encodeURIComponent(location.href)+"&format=jsonp&domain=m", type:'get', dataType:'jsonp', success:function(res){ if(res.status!="successed"){ return false; } $.getScript('https://res.wx.qq.com/open/js/jweixin-1.0.0.js',function(result,status){ if(status!="success"){ return false; } var getWxCfg=res.data; wx.config({ //debug: true, // 开启调试模式,调用的所有api的返回值会在客户端alert出来,若要查看传入的参数,可以在pc端打开,参数信息会通过log打出,仅在pc端时才会打印。 appId:getWxCfg.appId, // 必填,公众号的唯一标识 timestamp:getWxCfg.timestamp, // 必填,生成签名的时间戳 nonceStr:getWxCfg.nonceStr, // 必填,生成签名的随机串 signature:getWxCfg.signature,// 必填,签名,见附录1 jsApiList:['onMenuShareTimeline','onMenuShareAppMessage','onMenuShareQQ','onMenuShareWeibo','onMenuShareQZone'] // 必填,需要使用的JS接口列表,所有JS接口列表见附录2 }); wx.ready(function(){ //获取“分享到朋友圈”按钮点击状态及自定义分享内容接口 wx.onMenuShareTimeline({ title: setWeixinShare.title, // 分享标题 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享给朋友”按钮点击状态及自定义分享内容接口 wx.onMenuShareAppMessage({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 type: setWeixinShare.type, // 分享类型,music、video或link,不填默认为link dataUrl: setWeixinShare.dataUrl, // 如果type是music或video,则要提供数据链接,默认为空 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到QQ”按钮点击状态及自定义分享内容接口 wx.onMenuShareQQ({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到腾讯微博”按钮点击状态及自定义分享内容接口 wx.onMenuShareWeibo({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到QQ空间”按钮点击状态及自定义分享内容接口 wx.onMenuShareQZone({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); }); }); } }); } function openX_ad(posterid, htmlid, width, height) { if ($(htmlid).length > 0) { var randomnumber = Math.random(); var now_url = encodeURIComponent(window.location.href); var ga = document.createElement('iframe'); ga.src = 'https://www1.elecfans.com/www/delivery/myafr.php?target=_blank&cb=' + randomnumber + '&zoneid=' + posterid+'&prefer='+now_url; ga.width = width; ga.height = height; ga.frameBorder = 0; ga.scrolling = 'no'; var s = $(htmlid).append(ga); } } openX_ad(828, '#berry-300', 300, 250);