×

PyTorch教程19.2之超参数优化API

消耗积分:0 | 格式:pdf | 大小:0.26 MB | 2023-06-05

分享资料个

在我们深入研究该方法之前,我们将首先讨论一个基本的代码结构,它使我们能够有效地实现各种 HPO 算法。一般来说,这里考虑的所有 HPO 算法都需要实现两个决策原语,即搜索调度首先,他们需要对新的超参数配置进行采样,这通常涉及对配置空间的某种搜索。其次,对于每个配置,HPO 算法需要安排其评估并决定为其分配多少资源。一旦我们开始评估配置,我们就会将其称为试用我们将这些决定映射到两个类,HPOSearcherHPOScheduler除此之外,我们还提供HPOTuner执行优化过程的类。

这种调度器和搜索器的概念也在流行的 HPO 库中实现,例如 Syne Tune Salinas等人,2022 年、Ray Tune Liaw等人,2018 年或 Optuna Akiba等人,2019 年

import time
from scipy import stats
from d2l import torch as d2l

19.2.1。搜寻器

下面我们定义一个搜索器的基类,通过函数提供一个新的候选配置sample_configuration实现此功能的一种简单方法是随机对配置进行统一采样,就像我们在 第 19.1 节中对随机搜索所做的那样。更复杂的算法,例如贝叶斯优化,将根据先前试验的表现做出这些决定。因此,随着时间的推移,这些算法能够对更有希望的候选人进行抽样。我们添加该update 功能是为了更新以前试验的历史,然后可以利用它来改进我们的抽样分布。

class HPOSearcher(d2l.HyperParameters): #@save
  def sample_configuration() -> dict:
    raise NotImplementedError

  def update(self, config: dict, error: float, additional_info=None):
    pass

以下代码显示了如何在此 API 中实现我们上一节中的随机搜索优化器。作为一个轻微的扩展,我们允许用户通过 指定要评估的第一个配置 initial_config,而随后的配置是随机抽取的。

class RandomSearcher(HPOSearcher): #@save
  def __init__(self, config_space: dict, initial_config=None):
    self.save_hyperparameters()

  def sample_configuration(self) -> dict:
    if self.initial_config is not None:
      result = self.initial_config
      self.initial_config = None
    else:
      result = {
        name: domain.rvs()
        for name, domain in self.config_space.items()
      }
    return result

19.2.2。调度程序

除了新试验的采样配置外,我们还需要决定何时进行试验以及进行多长时间。实际上,所有这些决定都是由 完成的HPOScheduler,它将新配置的选择委托给HPOSearcher. suggest只要某些训练资源可用,就会调用该方法。除了调用sample_configuration搜索器之外,它还可以决定诸如max_epochs(即训练模型的时间)之类的参数。update每当试验返回新观察时调用该方法。

class HPOScheduler(d2l.HyperParameters): #@save
  def suggest(self) -> dict:
    raise NotImplementedError

  def update(self, config: dict, error: float, info=None):
    raise NotImplementedError

要实现随机搜索以及其他 HPO 算法,我们只需要一个基本的调度程序,它可以在每次新资源可用时调度新的配置。

class BasicScheduler(HPOScheduler): #@save
  def __init__(self, searcher: HPOSearcher):
    self.save_hyperparameters()

  def suggest(self) -> dict:
    return self.searcher.sample_configuration()

  def update(self, config: dict, error: float, info=None):
    self.searcher.update(config, error, additional_info=info)

19.2.3。调谐器

最后,我们需要一个组件来运行调度器/搜索器并对结果进行一些簿记。下面的代码实现了 HPO 试验的顺序执行,在下一个训练作业之后评估一个训练作业,并将作为一个基本示例。我们稍后将使用 Syne Tune来处理更具可扩展性的分布式 HPO 案例。

class HPOTuner(d2l.HyperParameters): #@save
  def __init__(self, scheduler: HPOScheduler, objective: callable):
    self.save_hyperparameters()
    # Bookeeping results for plotting
    self.incumbent = None
    self.incumbent_error = None
    self.incumbent_trajectory = []
    self.cumulative_runtime = []
    self.current_runtime = 0
    self.records = []

  def run(self, number_of_trials):
    for i in range(number_of_trials):
      start_time = time.time()
      config = self.scheduler.suggest()
      print(f"Trial {i}: config = {config}")
      error = self.objective(**config)
      error = float(error.cpu().detach().numpy())
      self.scheduler.update(config, error)
      runtime = time.time() - start_time
      self.bookkeeping(config, error, runtime)
      print(f"  error = {error}, runtime = {runtime}")

19.2.4。簿记 HPO 算法的性能

对于任何 HPO 算法,我们最感兴趣的是性能最佳的配置(称为incumbent)及其在给定挂钟时间后的验证错误。这就是我们跟踪runtime每次迭代的原因,其中包括运行评估的时间(调用 objective)和做出决策的时间(调用 scheduler.suggest)。在续集中,我们将绘制 cumulative_runtimeagainstincumbent_trajectory以可视化根据( 和) 定义的 HPO 算法的任何时间性能这使我们不仅可以量化优化器找到的配置的工作情况,还可以量化优化器找到它的速度。schedulersearcher

@d2l.add_to_class(HPOTuner) #@save
def bookkeeping(self, config: dict, error: float, runtime: float):
  self.records.append({"config": config, "error": error, "runtime": runtime})
  # Check if the last hyperparameter configuration performs better
  # than the incumbent
  if self.incumbent is None or self.incumbent_error > error:
    self.incumbent = config
    self.incumbent_error = error
  # Add current best observed performance to the optimization trajectory
  self.incumbent_trajectory.append(self.incumbent_error)
  # Update runtime
  self.current_runtime += runtime
  self.cumulative_runtime.append(self.current_runtime)

19.2.5。示例:优化卷积神经网络的超参数

我们现在使用随机搜索的新实现来优化 第 7.6 节中卷积神经网络批量大小学习率我们通过定义目标函数,这将再次成为验证错误。LeNet

def hpo_objective_lenet(learning_rate, batch_size, max_epochs=10): #@save
  model = d2l.LeNet(lr=learning_rate, num_classes=10)
  trainer = d2l.HPOTrainer(max_epochs=max_epochs, num_gpus=1)
  data = d2l.FashionMNIST(batch_size=batch_size)
  model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)
  trainer.fit(model=model, data=data)
  validation_error = trainer.validation_error()
  return validation_error

我们还需要定义配置空间。此外,要评估的第一个配置是 第 7.6 节中使用的默认设置。

config_space = {
  "learning_rate": stats.loguniform(1e-2, 1),
  "batch_size": stats.randint(32, 256),
}
initial_config = {
  "learning_rate": 0.1,
  "batch_size": 128,
}

现在我们可以开始随机搜索了:

searcher = RandomSearcher(config_space, initial_config=initial_config)
scheduler = BasicScheduler(searcher=searcher)
tuner = HPOTuner(scheduler=scheduler, objective=hpo_objective_lenet)
tuner.run(number_of_trials=5)
  error = 0.17130666971206665, runtime = 125.33143877983093
https://file.elecfans.com/web2/M00/AA/48/pYYBAGR9PVuAO21vAAF9e-RRQjc464.svg
https://file.elecfans.com/web2/M00/A9/CE/poYBAGR9PV2ARgCBAAF-SCs89bw491.svg
https://file.elecfans.com/web2/M00/AA/48/pYYBAGR9PV-AcPXNAAF_kYZ_xQw068.svg
https://file.elecfans.com/web2/M00/A9/CE/poYBAGR9PWGAWTGGAAF_K5I3kmI689.svg
https://file.elecfans.com/web2/M00/AA/48/pYYBAGR9PWSAQ2zDAAFzyB-zwLc643.svg

下面我们绘制了现任者的优化轨迹,以获得随机搜索的任何时间性能:

board = d2l.ProgressBoard(xlabel="time", ylabel="error")
for time_stamp, error in zip(
  tuner.cumulative_runtime, tuner.incumbent_trajectory
):
  board.draw(time_stamp, error, "random search", every_n=1)
https://file.elecfans.com/web2/M00/AA/48/pYYBAGR9PWaAXsTGAAD-6d95H6c198.svg

19.2.6. 比较 HPO 算法

正如训练算法或模型架构一样,了解如何最好地比较不同的 HPO 算法非常重要。每次 HPO 运行取决于随机性的两个主要来源:训练过程的随机效应,例如随机权重初始化或小批量排序,以及 HPO 算法本身的内在随机性,例如随机搜索的随机抽样。因此,在比较不同的算法时,至关重要的是多次运行每个实验并报告基于随机数生成器的不同种子的算法多次重复的总体统计数据,例如平均值或中值。

为了说明这一点,我们比较随机搜索(参见第 19.1.2 节)和贝叶斯优化Snoek等人,2012 年在调整前馈神经网络的超参数方面的作用。每个算法都经过评估

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

评论(0)
发评论

下载排行榜

全部0条评论

快来发表一下你的评论吧 !

'+ '

'+ '

'+ ''+ '
'+ ''+ ''+ '
'+ ''+ '' ); $.get('/article/vipdownload/aid/'+webid,function(data){ if(data.code ==5){ $(pop_this).attr('href',"/login/index.html"); return false } if(data.code == 2){ //跳转到VIP升级页面 window.location.href="//m.obk20.com/vip/index?aid=" + webid return false } //是会员 if (data.code > 0) { $('body').append(htmlSetNormalDownload); var getWidth=$("#poplayer").width(); $("#poplayer").css("margin-left","-"+getWidth/2+"px"); $('#tips').html(data.msg) $('.download_confirm').click(function(){ $('#dialog').remove(); }) } else { var down_url = $('#vipdownload').attr('data-url'); isBindAnalysisForm(pop_this, down_url, 1) } }); }); //是否开通VIP $.get('/article/vipdownload/aid/'+webid,function(data){ if(data.code == 2 || data.code ==5){ //跳转到VIP升级页面 $('#vipdownload>span').text("开通VIP 免费下载") return false }else{ // 待续费 if(data.code == 3) { vipExpiredInfo.ifVipExpired = true vipExpiredInfo.vipExpiredDate = data.data.endoftime } $('#vipdownload .icon-vip-tips').remove() $('#vipdownload>span').text("VIP免积分下载") } }); }).on("click",".download_cancel",function(){ $('#dialog').remove(); }) var setWeixinShare={};//定义默认的微信分享信息,页面如果要自定义分享,直接更改此变量即可 if(window.navigator.userAgent.toLowerCase().match(/MicroMessenger/i) == 'micromessenger'){ var d={ title:'PyTorch教程19.2之超参数优化API',//标题 desc:$('[name=description]').attr("content"), //描述 imgUrl:'https://'+location.host+'/static/images/ele-logo.png',// 分享图标,默认是logo link:'',//链接 type:'',// 分享类型,music、video或link,不填默认为link dataUrl:'',//如果type是music或video,则要提供数据链接,默认为空 success:'', // 用户确认分享后执行的回调函数 cancel:''// 用户取消分享后执行的回调函数 } setWeixinShare=$.extend(d,setWeixinShare); $.ajax({ url:"//www.obk20.com/app/wechat/index.php?s=Home/ShareConfig/index", data:"share_url="+encodeURIComponent(location.href)+"&format=jsonp&domain=m", type:'get', dataType:'jsonp', success:function(res){ if(res.status!="successed"){ return false; } $.getScript('https://res.wx.qq.com/open/js/jweixin-1.0.0.js',function(result,status){ if(status!="success"){ return false; } var getWxCfg=res.data; wx.config({ //debug: true, // 开启调试模式,调用的所有api的返回值会在客户端alert出来,若要查看传入的参数,可以在pc端打开,参数信息会通过log打出,仅在pc端时才会打印。 appId:getWxCfg.appId, // 必填,公众号的唯一标识 timestamp:getWxCfg.timestamp, // 必填,生成签名的时间戳 nonceStr:getWxCfg.nonceStr, // 必填,生成签名的随机串 signature:getWxCfg.signature,// 必填,签名,见附录1 jsApiList:['onMenuShareTimeline','onMenuShareAppMessage','onMenuShareQQ','onMenuShareWeibo','onMenuShareQZone'] // 必填,需要使用的JS接口列表,所有JS接口列表见附录2 }); wx.ready(function(){ //获取“分享到朋友圈”按钮点击状态及自定义分享内容接口 wx.onMenuShareTimeline({ title: setWeixinShare.title, // 分享标题 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享给朋友”按钮点击状态及自定义分享内容接口 wx.onMenuShareAppMessage({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 type: setWeixinShare.type, // 分享类型,music、video或link,不填默认为link dataUrl: setWeixinShare.dataUrl, // 如果type是music或video,则要提供数据链接,默认为空 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到QQ”按钮点击状态及自定义分享内容接口 wx.onMenuShareQQ({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到腾讯微博”按钮点击状态及自定义分享内容接口 wx.onMenuShareWeibo({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到QQ空间”按钮点击状态及自定义分享内容接口 wx.onMenuShareQZone({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); }); }); } }); } function openX_ad(posterid, htmlid, width, height) { if ($(htmlid).length > 0) { var randomnumber = Math.random(); var now_url = encodeURIComponent(window.location.href); var ga = document.createElement('iframe'); ga.src = 'https://www1.elecfans.com/www/delivery/myafr.php?target=_blank&cb=' + randomnumber + '&zoneid=' + posterid+'&prefer='+now_url; ga.width = width; ga.height = height; ga.frameBorder = 0; ga.scrolling = 'no'; var s = $(htmlid).append(ga); } } openX_ad(828, '#berry-300', 300, 250);