×

PyTorch教程23.7之实用函数和类

消耗积分:0 | 格式:pdf | 大小:0.18 MB | 2023-06-05

分享资料个

本节包含本书中使用的实用函数和类的实现。

import collections
import inspect
from IPython import display
from torch import nn
from d2l import torch as d2l
import collections
import inspect
import random
from IPython import display
from mxnet import autograd, gluon, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()
import collections
import inspect
import jax
from IPython import display
from d2l import jax as d2l
# No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import collections
import inspect
import tensorflow as tf
from IPython import display
from d2l import tensorflow as d2l

超参数。

@d2l.add_to_class(d2l.HyperParameters) #@save
def save_hyperparameters(self, ignore=None):
  """Save the calling function's arguments as attributes of ``self``.

  Reads the local variables of the frame that called this method. Every
  variable is kept unless its name is ``self``, appears in ``ignore``, or
  starts with ``_``. Kept variables are stored in the ``self.hparams`` dict
  and also set as attributes of the same name.

  Args:
    ignore: Optional iterable of variable names to leave out. Defaults to
      ``None``, meaning nothing beyond ``self`` and ``_``-prefixed names is
      excluded. Lists passed by existing callers keep working.
  """
  # A ``None`` default avoids a shared mutable default list. The exclusion
  # set is built once instead of once per local variable.
  skipped = set(ignore) if ignore is not None else set()
  skipped.add('self')
  frame = inspect.currentframe().f_back
  _, _, _, local_vars = inspect.getargvalues(frame)
  self.hparams = {k: v for k, v in local_vars.items()
                  if k not in skipped and not k.startswith('_')}
  for k, v in self.hparams.items():
    setattr(self, k, v)

进度板(ProgressBoard)。

@d2l.add_to_class(d2l.ProgressBoard) #@save
def draw(self, x, y, label, every_n=1):
  """Buffer a point for curve ``label`` and plot averaged points.

  Raw points are buffered per label. When exactly ``every_n`` of them have
  accumulated, their mean (x, y) is appended to the plotted curve and the
  buffer is emptied. If ``self.display`` is truthy, every curve is then
  re-plotted on ``self.fig``. The figure is shown via IPython ``display``,
  followed by ``clear_output(wait=True)``.

  Args:
    x: x-coordinate of the new raw point.
    y: y-coordinate of the new raw point.
    label: Name of the curve the point belongs to (also its legend entry).
    every_n: Number of raw points averaged into one plotted point.
  """
  Point = collections.namedtuple('Point', ['x', 'y'])
  # Per-label stores are created lazily on the first call.
  if not hasattr(self, 'raw_points'):
    self.raw_points = collections.OrderedDict()
    self.data = collections.OrderedDict()
  if label not in self.raw_points:
    self.raw_points[label] = []
    self.data[label] = []
  pending = self.raw_points[label]
  curve = self.data[label]
  pending.append(Point(x, y))
  if len(pending) != every_n:
    return

  count = len(pending)
  avg_x = sum(pt.x for pt in pending) / count
  avg_y = sum(pt.y for pt in pending) / count
  curve.append(Point(avg_x, avg_y))
  pending.clear()
  if not self.display:
    return

  d2l.use_svg_display()
  if self.fig is None:
    self.fig = d2l.plt.figure(figsize=self.figsize)
  handles, names = [], []
  # zip() stops at the shortest input, so only as many curves as there are
  # line styles/colors get drawn.
  for (name, pts), style, color in zip(self.data.items(), self.ls,
                                       self.colors):
    xs = [pt.x for pt in pts]
    ys = [pt.y for pt in pts]
    handles.append(d2l.plt.plot(xs, ys, linestyle=style, color=color)[0])
    names.append(name)
  ax = self.axes or d2l.plt.gca()
  if self.xlim:
    ax.set_xlim(self.xlim)
  if self.ylim:
    ax.set_ylim(self.ylim)
  if not self.xlabel:
    # NOTE(review): nothing visible here assigns ``self.x``; confirm the
    # attribute exists, otherwise this fallback raises AttributeError.
    self.xlabel = self.x
  ax.set_xlabel(self.xlabel)
  ax.set_ylabel(self.ylabel)
  ax.set_xscale(self.xscale)
  ax.set_yscale(self.yscale)
  ax.legend(handles, names)
  display.display(self.fig)
  display.clear_output(wait=True)

添加 FrozenLake 环境

def frozen_lake(seed): #@save
  """Create a non-slippery FrozenLake-v1 env and expose its MDP tables.

  See https://www.gymlibrary.dev/environments/toy_text/frozen_lake/ to learn
  more about this environment. How ``env.P`` is processed is adapted from
  https://sites.google.com/view/deep-rl-bootcamp/labs

  Args:
    seed: Seed applied to the environment and to its action space.

  Returns:
    A dict with the following keys:
      'desc': 2D array specifying what each grid cell means.
      'num_states': number of states (observation dimension).
      'num_actions': number of actions (action dimension).
      'trans_prob_idx', 'nextstate_idx', 'reward_idx', 'done_idx': indices
        0-3 into each ``(probability, next_state, reward, done)`` tuple.
      'mdp': dict mapping ``(state, action)`` to that pair's list of
        transition tuples, taken from ``env.P``.
      'env': the environment object itself.
  """
  # NOTE(review): ``gym`` is not imported in the visible part of this file;
  # confirm it is imported before this function is called.
  env = gym.make('FrozenLake-v1', is_slippery=False)
  env.seed(seed)
  # Space.seed() rebuilds the action space's RNG from ``seed``. The former
  # extra ``env.action_space.np_random.seed(seed)`` call was immediately
  # superseded by this one, and it is not supported by every gym release.
  env.action_space.seed(seed)

  # env.P maps s -> {a0: [(p(s'|s,a0), s', reward, done), ...], a1: [...]}.
  # Each list is e.g. [(0.3, 0, 0, False), (0.3, 0, 0, False), ...].
  mdp = {(s, a): pxrds
         for s, others in env.P.items()
         for a, pxrds in others.items()}

  return {
      'desc': env.desc,  # 2D array specifying what each grid item means
      # Sizes are read from the Discrete spaces rather than the nS/nA
      # attributes, which are not present on every gym release.
      'num_states': env.observation_space.n,
      'num_actions': env.action_space.n,
      # Indices into each (transition probability, nextstate, reward, done)
      # tuple.
      'trans_prob_idx': 0,
      'nextstate_idx': 1,
      'reward_idx': 2,
      'done_idx': 3,
      'mdp': mdp,
      'env': env,
  }

创建环境

def make_env(name='', seed=0): #@save
  """Build the environment-info dict for a supported gym environment.

  Args:
    name: Gym environment id. For value iteration only 'FrozenLake-v1' is
      supported.
    seed: Random seed forwarded to the environment builder.

  Returns:
    The result of the environment-specific builder (``frozen_lake(seed)``
    for 'FrozenLake-v1').

  Raises:
    ValueError: If ``name`` is not a supported environment id. The message
      names the rejected environment.
  """
  if name == 'FrozenLake-v1':
    return frozen_lake(seed)
  raise ValueError(f"{name} env is not supported in this Notebook")

显示值函数

def show_value_function_progress(env_desc, V, pi): #@save
  # This function visualizes how value and policy changes over time.
  # V: [num_iters, num_states]
  # pi: [num_iters, num_states]
  # How to visualize value function is adapted (but changed) from: https://sites.google.com/view/deep-rl-bootcamp/labs

  num_iters = V.shape[0]
  fig, ax = plt.subplots(figsize=(15, 15))

  for k in range(V.shape[0]):
    plt.subplot(4, 4, k + 1)
    plt.imshow(V[k].reshape(4,4), cmap="bone")
    ax = plt.gca()
    ax.set_xticks(np.arange(0, 5)-.5, minor=True)
    ax.set_yticks(np.arange(0, 5)-.5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)
    ax.set_xticks([])
    ax.set_yticks([])

    # LEFT action: 0, DOWN action: 1
    # RIGHT action: 2, UP action: 3
    action2dxdy = {0:(-.25, 0),1: (0, .25),
            2:(0.25, 0),3: (-.25, 0)}

    for y in range(4):
      for x in range(4):
        action = pi[k].reshape(4,4)[y, x]
        dx, dy = action2dxdy[action]

        if env_desc[y,x].decode() == 'H':
          ax.text(x, y, str(env_desc[y,x].decode()),
            ha="center", va="center", color="y",
             size=20, fontweight='bold')

        elif env_desc[y,x].decode() == 'G':
          ax.text(x, y, str(env_desc[y,x].decode()),
            ha="center", va="center", color="w",
             size=20, fontweight='bold')

        else:
          ax.text(x, y, str(env_desc[y,x].decode()),
            ha="center", va="center", color="g",
             size=15, fontweight='bold')

        # No arrow for cells with G and H labels
        if env_desc[y,x].decode() != 'G' and env_desc[y,x].decode() != 'H':
          ax.arrow(x, y, dx, dy, color='r', head_width=0.2, head_length=0.15)

    ax.set_title("Step = " + str(k + 1),

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

评论(0)
发评论

下载排行榜

全部0条评论

快来发表一下你的评论吧 !

'+ '

'+ '

'+ ''+ '
'+ ''+ ''+ '
'+ ''+ '' ); $.get('/article/vipdownload/aid/'+webid,function(data){ if(data.code ==5){ $(pop_this).attr('href',"/login/index.html"); return false } if(data.code == 2){ //跳转到VIP升级页面 window.location.href="//m.obk20.com/vip/index?aid=" + webid return false } //是会员 if (data.code > 0) { $('body').append(htmlSetNormalDownload); var getWidth=$("#poplayer").width(); $("#poplayer").css("margin-left","-"+getWidth/2+"px"); $('#tips').html(data.msg) $('.download_confirm').click(function(){ $('#dialog').remove(); }) } else { var down_url = $('#vipdownload').attr('data-url'); isBindAnalysisForm(pop_this, down_url, 1) } }); }); //是否开通VIP $.get('/article/vipdownload/aid/'+webid,function(data){ if(data.code == 2 || data.code ==5){ //跳转到VIP升级页面 $('#vipdownload>span').text("开通VIP 免费下载") return false }else{ // 待续费 if(data.code == 3) { vipExpiredInfo.ifVipExpired = true vipExpiredInfo.vipExpiredDate = data.data.endoftime } $('#vipdownload .icon-vip-tips').remove() $('#vipdownload>span').text("VIP免积分下载") } }); }).on("click",".download_cancel",function(){ $('#dialog').remove(); }) var setWeixinShare={};//定义默认的微信分享信息,页面如果要自定义分享,直接更改此变量即可 if(window.navigator.userAgent.toLowerCase().match(/MicroMessenger/i) == 'micromessenger'){ var d={ title:'PyTorch教程23.7之效用函数和类',//标题 desc:$('[name=description]').attr("content"), //描述 imgUrl:'https://'+location.host+'/static/images/ele-logo.png',// 分享图标,默认是logo link:'',//链接 type:'',// 分享类型,music、video或link,不填默认为link dataUrl:'',//如果type是music或video,则要提供数据链接,默认为空 success:'', // 用户确认分享后执行的回调函数 cancel:''// 用户取消分享后执行的回调函数 } setWeixinShare=$.extend(d,setWeixinShare); $.ajax({ url:"//www.obk20.com/app/wechat/index.php?s=Home/ShareConfig/index", data:"share_url="+encodeURIComponent(location.href)+"&format=jsonp&domain=m", type:'get', dataType:'jsonp', success:function(res){ if(res.status!="successed"){ return false; } $.getScript('https://res.wx.qq.com/open/js/jweixin-1.0.0.js',function(result,status){ if(status!="success"){ return false; } var 
getWxCfg=res.data; wx.config({ //debug: true, // 开启调试模式,调用的所有api的返回值会在客户端alert出来,若要查看传入的参数,可以在pc端打开,参数信息会通过log打出,仅在pc端时才会打印。 appId:getWxCfg.appId, // 必填,公众号的唯一标识 timestamp:getWxCfg.timestamp, // 必填,生成签名的时间戳 nonceStr:getWxCfg.nonceStr, // 必填,生成签名的随机串 signature:getWxCfg.signature,// 必填,签名,见附录1 jsApiList:['onMenuShareTimeline','onMenuShareAppMessage','onMenuShareQQ','onMenuShareWeibo','onMenuShareQZone'] // 必填,需要使用的JS接口列表,所有JS接口列表见附录2 }); wx.ready(function(){ //获取“分享到朋友圈”按钮点击状态及自定义分享内容接口 wx.onMenuShareTimeline({ title: setWeixinShare.title, // 分享标题 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享给朋友”按钮点击状态及自定义分享内容接口 wx.onMenuShareAppMessage({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 type: setWeixinShare.type, // 分享类型,music、video或link,不填默认为link dataUrl: setWeixinShare.dataUrl, // 如果type是music或video,则要提供数据链接,默认为空 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到QQ”按钮点击状态及自定义分享内容接口 wx.onMenuShareQQ({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到腾讯微博”按钮点击状态及自定义分享内容接口 wx.onMenuShareWeibo({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到QQ空间”按钮点击状态及自定义分享内容接口 wx.onMenuShareQZone({ title: setWeixinShare.title, // 分享标题 desc: 
setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); }); }); } }); } function openX_ad(posterid, htmlid, width, height) { if ($(htmlid).length > 0) { var randomnumber = Math.random(); var now_url = encodeURIComponent(window.location.href); var ga = document.createElement('iframe'); ga.src = 'https://www1.elecfans.com/www/delivery/myafr.php?target=_blank&cb=' + randomnumber + '&zoneid=' + posterid+'&prefer='+now_url; ga.width = width; ga.height = height; ga.frameBorder = 0; ga.scrolling = 'no'; var s = $(htmlid).append(ga); } } openX_ad(828, '#berry-300', 300, 250);