×

PyTorch教程19.1之什么是超参数优化

消耗积分:0 | 格式:pdf | 大小:0.29 MB | 2023-06-05

李志静

分享资料个

正如我们在前几章中所见,深度神经网络带有大量在训练过程中学习的参数或权重。除此之外,每个神经网络都有额外的 超参数需要用户配置。例如,为了确保随机梯度下降收敛到训练损失的局部最优(参见第 12 节),我们必须调整学习率和批量大小。为了避免在训练数据集上过度拟合,我们可能必须设置正则化参数,例如权重衰减(参见第 3.7 节)或 dropout(参见 第 5.6 节)). 我们可以通过设置层数和每层单元或过滤器的数量(即权重的有效数量)来定义模型的容量和归纳偏差。

不幸的是,我们不能简单地通过最小化训练损失来调整这些超参数,因为这会导致训练数据过度拟合。例如,将正则化参数(如 dropout 或权重衰减)设置为零会导致较小的训练损失,但可能会损害泛化性能。

https://file.elecfans.com/web2/M00/AA/48/pYYBAGR9PUqAbQbSAAB0OkZMfC0966.svg

图 19.1.1机器学习中的典型工作流程,包括使用不同的超参数多次训练模型。

如果没有不同形式的自动化,就必须以反复试验的方式手动设置超参数,这相当于机器学习工作流程中耗时且困难的部分。例如,考虑在 CIFAR-10 上训练 ResNet(参见第 8.6 节g4dn.xlarge),这需要在 Amazon Elastic Cloud Compute (EC2)实例上训练 2 个多小时。即使只是依次尝试十个超参数配置,这也已经花费了我们大约一天的时间。更糟糕的是,超参数通常不能直接跨架构和数据集传输 Bardenet等人,2013 年Feurer等人,2022 年Wistuba等人,2018 年,并且需要针对每个新任务重新优化。此外,对于大多数超参数,没有经验法则,需要专业知识才能找到合理的值。

超参数优化 (HPO)算法旨在以一种有原则的和自动化的方式解决这个问题 Feurer 和 Hutter,2018 年,将其定义为一个全局优化问题。默认目标是保留验证数据集上的错误,但原则上可以是任何其他业务指标。它可以与次要目标结合或受其约束,例如训练时间、推理时间或模型复杂性。

最近,超参数优化已扩展到神经架构搜索 (NAS) Elsken等人,2018 年Wistuba等人,2019 年,目标是找到全新的神经网络架构。与经典 HPO 相比,NAS 在计算方面的成本更高,并且需要额外的努力才能在实践中保持可行性。HPO 和 NAS 都可以被视为 AutoML 的子领域 ( Hutter et al. , 2019 ),旨在自动化整个 ML 管道。

在本节中,我们将介绍 HPO 并展示我们如何自动找到第 4.5 节介绍的逻辑回归示例的最佳超参数

19.1.1. 优化问题

我们将从一个简单的玩具问题开始:搜索第 4.5 节SoftmaxRegression中 的多类逻辑回归模型的学习率,以最小化 Fashion MNIST 数据集上的验证错误。虽然批量大小或轮数等其他超参数也值得调整,但为简单起见,我们只关注学习率。

import numpy as np
import torch
from scipy import stats
from torch import nn
from d2l import torch as d2l

在运行 HPO 之前,我们首先需要定义两个要素:目标函数和配置空间。

19.1.1.1。目标函数

学习算法的性能可以看作是一个函数 f:X→R从超参数空间映射x∈X到验证损失。对于每一个评价f(x),我们必须训练和验证我们的机器学习模型,对于在大型数据集上训练的深度神经网络,这可能是时间和计算密集型的。鉴于我们的标准f(x)我们的目标是找到 x⋆∈argminx∈Xf(x).

没有简单的方法来计算的梯度f关于 x,因为它需要在整个训练过程中传播梯度。虽然最近有工作 Franceschi等人,2017 年Maclaurin等人,2015 年通过近似“超梯度”驱动 HPO,但现有方法中没有一种与最先进的方法具有竞争力,我们将不在这里讨论它们。此外,评估的计算负担f 要求 HPO 算法以尽可能少的样本接近全局最优。

神经网络的训练是随机的(例如,权重是随机初始化的,mini-batches 是随机采样的),因此我们的观察结果会很嘈杂:y∼f(x)+ϵ,我们通常假设ϵ∼N(0,σ) 观察噪声呈高斯分布。

面对所有这些挑战,我们通常会尝试快速识别一小组性能良好的超参数配置,而不是准确地达到全局最优值。然而,由于大多数神经网络模型的大量计算需求,即使这样也可能需要数天或数周的计算时间。我们将在19.4 节中探讨如何通过分布搜索或使用目标函数的评估成本更低的近似值来加快优化过程。

我们从计算模型验证误差的方法开始。

class HPOTrainer(d2l.Trainer): #@save
  def validation_error(self):
    self.model.eval()
    accuracy = 0
    val_batch_idx = 0
    for batch in self.val_dataloader:
      with torch.no_grad():
        x, y = self.prepare_batch(batch)
        y_hat = self.model(x)
        accuracy += self.model.accuracy(y_hat, y)
      val_batch_idx += 1
    return 1 - accuracy / val_batch_idx

我们优化了关于超参数配置的验证错误config,由learning_rate. 对于每个评估,我们训练我们的模型max_epochsepochs,然后计算并返回其验证错误:

def hpo_objective_softmax_classification(config, max_epochs=8):
  learning_rate = config["learning_rate"]
  trainer = d2l.HPOTrainer(max_epochs=max_epochs)
  data = d2l.FashionMNIST(batch_size=16)
  model = d2l.SoftmaxRegression(num_outputs=10, lr=learning_rate)
  trainer.fit(model=model, data=data)
  return trainer.validation_error().detach().numpy()

19.1.1.2。配置空间

随着目标函数f(x),我们还需要定义可行集x∈X优化过来,称为配置空间搜索空间对于我们的逻辑回归示例,我们将使用:

config_space = {"learning_rate": stats.loguniform(1e-4, 1)}

这里我们使用loguniformSciPy 中的对象,它表示对数空间中 -4 和 -1 之间的均匀分布。这个对象允许我们从这个分布中抽样随机变量。

每个超参数都有一个数据类型,例如floatfor learning_rate,以及一个封闭的有界范围(即下限和上限)。我们通常为每个超参数分配一个先验分布(例如,均匀分布或对数均匀分布)以从中进行采样。一些正参数(例如learning_rate)最好用对数标度表示,因为最佳值可能相差几个数量级,而其他参数(例如动量)则采用线性标度。

下面我们展示了一个配置空间的简单示例,该配置空间由多层感知器的典型超参数组成,包括它们的类型和标准范围。

表 19.1.1多层感知机配置空间示例

姓名

类型

超参数范围

对数刻度

学习率

漂浮

[10−6,10−1]

是的

批量大小

整数

[8,256]

是的

势头

 

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

评论(0)
发评论

下载排行榜

全部0条评论

快来发表一下你的评论吧 !

'+ '

'+ '

'+ ''+ '
'+ ''+ ''+ '
'+ ''+ '' ); $.get('/article/vipdownload/aid/'+webid,function(data){ if(data.code ==5){ $(pop_this).attr('href',"/login/index.html"); return false } if(data.code == 2){ //跳转到VIP升级页面 window.location.href="//m.obk20.com/vip/index?aid=" + webid return false } //是会员 if (data.code > 0) { $('body').append(htmlSetNormalDownload); var getWidth=$("#poplayer").width(); $("#poplayer").css("margin-left","-"+getWidth/2+"px"); $('#tips').html(data.msg) $('.download_confirm').click(function(){ $('#dialog').remove(); }) } else { var down_url = $('#vipdownload').attr('data-url'); isBindAnalysisForm(pop_this, down_url, 1) } }); }); //是否开通VIP $.get('/article/vipdownload/aid/'+webid,function(data){ if(data.code == 2 || data.code ==5){ //跳转到VIP升级页面 $('#vipdownload>span').text("开通VIP 免费下载") return false }else{ // 待续费 if(data.code == 3) { vipExpiredInfo.ifVipExpired = true vipExpiredInfo.vipExpiredDate = data.data.endoftime } $('#vipdownload .icon-vip-tips').remove() $('#vipdownload>span').text("VIP免积分下载") } }); }).on("click",".download_cancel",function(){ $('#dialog').remove(); }) var setWeixinShare={};//定义默认的微信分享信息,页面如果要自定义分享,直接更改此变量即可 if(window.navigator.userAgent.toLowerCase().match(/MicroMessenger/i) == 'micromessenger'){ var d={ title:'PyTorch教程19.1之什么是超参数优化',//标题 desc:$('[name=description]').attr("content"), //描述 imgUrl:'https://'+location.host+'/static/images/ele-logo.png',// 分享图标,默认是logo link:'',//链接 type:'',// 分享类型,music、video或link,不填默认为link dataUrl:'',//如果type是music或video,则要提供数据链接,默认为空 success:'', // 用户确认分享后执行的回调函数 cancel:''// 用户取消分享后执行的回调函数 } setWeixinShare=$.extend(d,setWeixinShare); $.ajax({ url:"//www.obk20.com/app/wechat/index.php?s=Home/ShareConfig/index", data:"share_url="+encodeURIComponent(location.href)+"&format=jsonp&domain=m", type:'get', dataType:'jsonp', success:function(res){ if(res.status!="successed"){ return false; } $.getScript('https://res.wx.qq.com/open/js/jweixin-1.0.0.js',function(result,status){ if(status!="success"){ return false; } var getWxCfg=res.data; wx.config({ //debug: true, // 开启调试模式,调用的所有api的返回值会在客户端alert出来,若要查看传入的参数,可以在pc端打开,参数信息会通过log打出,仅在pc端时才会打印。 appId:getWxCfg.appId, // 必填,公众号的唯一标识 timestamp:getWxCfg.timestamp, // 必填,生成签名的时间戳 nonceStr:getWxCfg.nonceStr, // 必填,生成签名的随机串 signature:getWxCfg.signature,// 必填,签名,见附录1 jsApiList:['onMenuShareTimeline','onMenuShareAppMessage','onMenuShareQQ','onMenuShareWeibo','onMenuShareQZone'] // 必填,需要使用的JS接口列表,所有JS接口列表见附录2 }); wx.ready(function(){ //获取“分享到朋友圈”按钮点击状态及自定义分享内容接口 wx.onMenuShareTimeline({ title: setWeixinShare.title, // 分享标题 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享给朋友”按钮点击状态及自定义分享内容接口 wx.onMenuShareAppMessage({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 type: setWeixinShare.type, // 分享类型,music、video或link,不填默认为link dataUrl: setWeixinShare.dataUrl, // 如果type是music或video,则要提供数据链接,默认为空 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到QQ”按钮点击状态及自定义分享内容接口 wx.onMenuShareQQ({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到腾讯微博”按钮点击状态及自定义分享内容接口 wx.onMenuShareWeibo({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到QQ空间”按钮点击状态及自定义分享内容接口 wx.onMenuShareQZone({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); }); }); } }); } function openX_ad(posterid, htmlid, width, height) { if ($(htmlid).length > 0) { var randomnumber = Math.random(); var now_url = encodeURIComponent(window.location.href); var ga = document.createElement('iframe'); ga.src = 'https://www1.elecfans.com/www/delivery/myafr.php?target=_blank&cb=' + randomnumber + '&zoneid=' + posterid+'&prefer='+now_url; ga.width = width; ga.height = height; ga.frameBorder = 0; ga.scrolling = 'no'; var s = $(htmlid).append(ga); } } openX_ad(828, '#berry-300', 300, 250);