PyTorch教程-14.12. 神经风格迁移

电子说

1.3w人已加入

描述

如果你是摄影爱好者,你可能对滤镜并不陌生。它可以改变照片的色彩风格,使风景照片变得更清晰或肖像照片皮肤变白。但是,一个滤镜通常只会改变照片的一个方面。要为照片应用理想的风格,您可能需要尝试多种不同的滤镜组合。这个过程与调整模型的超参数一样复杂。

在本节中,我们将利用 CNN 的分层表示将一幅图像的风格自动应用到另一幅图像,即 风格迁移 (Gatys等人,2016 年)。此任务需要两张输入图像:一张是内容图像,另一张是风格图像。我们将使用神经网络修改内容图像,使其在风格上接近风格图像。例如 图14.12.1中的内容图片是我们在西雅图郊区雷尼尔山国家公园拍摄的风景照,而风格图是一幅以秋天的橡树为主题的油画。在输出的合成图像中,应用了样式图像的油画笔触,使颜色更加鲜艳,同时保留了内容图像中对象的主要形状。

pytorch

图 14.12.1给定内容和风格图像,风格迁移输出合成图像。

14.12.1。方法

图 14.12.2用一个简化的例子说明了基于 CNN 的风格迁移方法。首先,我们将合成图像初始化为内容图像。这张合成图像是风格迁移过程中唯一需要更新的变量,即训练期间要更新的模型参数。然后我们选择一个预训练的 CNN 来提取图像特征并在训练期间冻结其模型参数。这种深度 CNN 使用多层来提取图像的层次特征。我们可以选择其中一些层的输出作为内容特征或样式特征。如图14.12.2举个例子。这里的预训练神经网络有 3 个卷积层,其中第二层输出内容特征,第一层和第三层输出风格特征。

pytorch

图 14.12.2基于 CNN 的风格迁移过程。实线表示正向传播方向,虚线表示反向传播。

接下来,我们通过正向传播(实线箭头方向)计算风格迁移的损失函数,并通过反向传播(虚线箭头方向)更新模型参数(输出的合成图像)。风格迁移中常用的损失函数由三部分组成:(i)内容损失使合成图像和内容图像在内容特征上接近;(ii)风格损失使得合成图像和风格图像在风格特征上接近;(iii) 总变差损失有助于减少合成图像中的噪声。最后,当模型训练结束后,我们输出风格迁移的模型参数,生成最终的合成图像。

下面,我们将通过一个具体的实验来解释风格迁移的技术细节。

14.12.2。阅读内容和样式图像

首先,我们阅读内容和样式图像。从它们打印的坐标轴,我们可以看出这些图像具有不同的尺寸。

 

%matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

d2l.set_figsize()
content_img = d2l.Image.open('../img/rainier.jpg')
d2l.plt.imshow(content_img);

 

pytorch

 

style_img = d2l.Image.open('../img/autumn-oak.jpg')
d2l.plt.imshow(style_img);

 

pytorch

 

%matplotlib inline
from mxnet import autograd, gluon, image, init, np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()

d2l.set_figsize()
content_img = image.imread('../img/rainier.jpg')
d2l.plt.imshow(content_img.asnumpy());

 

pytorch

 

style_img = image.imread('../img/autumn-oak.jpg')
d2l.plt.imshow(style_img.asnumpy());

 

pytorch

14.12.3。预处理和后处理

下面,我们定义了两个用于预处理和后处理图像的函数。该preprocess函数对输入图像的三个 RGB 通道中的每一个进行标准化,并将结果转换为 CNN 输入格式。该postprocess函数将输出图像中的像素值恢复为标准化前的原始值。由于图像打印功能要求每个像素都有一个从0到1的浮点值,我们将任何小于0或大于1的值分别替换为0或1。

 

rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])

def preprocess(img, image_shape):
  transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(image_shape),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])
  return transforms(img).unsqueeze(0)

def postprocess(img):
  img = img[0].to(rgb_std.device)
  img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)
  return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))

 

 

rgb_mean = np.array([0.485, 0.456, 0.406])
rgb_std = np.array([0.229, 0.224, 0.225])

def preprocess(img, image_shape):
  img = image.imresize(img, *image_shape)
  img = (img.astype('float32') / 255 - rgb_mean) / rgb_std
  return np.expand_dims(img.transpose(2, 0, 1), axis=0)

def postprocess(img):
  img = img[0].as_in_ctx(rgb_std.ctx)
  return (img.transpose(1, 2, 0) * rgb_std + rgb_mean).clip(0, 1)

 

14.12.4。提取特征

我们使用在 ImageNet 数据集上预训练的 VGG-19 模型来提取图像特征( Gatys et al. , 2016 )。

 

pretrained_net = torchvision.models.vgg19(pretrained=True)

 

 

pretrained_net = gluon.model_zoo.vision.vgg19(pretrained=True)

 

为了提取图像的内容特征和风格特征,我们可以选择VGG网络中某些层的输出。一般来说,越靠近输入层越容易提取图像的细节,反之越容易提取图像的全局信息。为了避免在合成图像中过度保留内容图像的细节,我们选择了一个更接近输出的VGG层作为内容层来输出图像的内容特征。我们还选择不同 VGG 层的输出来提取局部和全局风格特征。这些图层也称为样式图层。如第 8.2 节所述,VGG 网络使用 5 个卷积块。在实验中,我们选择第四个卷积块的最后一个卷积层作为内容层,每个卷积块的第一个卷积层作为样式层。这些层的索引可以通过打印pretrained_net实例来获得。

 

style_layers, content_layers = [0, 5, 10, 19, 28], [25]

 

 

style_layers, content_layers = [0, 5, 10, 19, 28], [25]

 

当使用 VGG 层提取特征时,我们只需要使用从输入层到最接近输出层的内容层或样式层的所有那些。让我们构建一个新的网络实例net,它只保留所有用于特征提取的 VGG 层。

 

net = nn.Sequential(*[pretrained_net.features[i] for i in
           range(max(content_layers + style_layers) + 1)])

 

 

net = nn.Sequential()
for i in range(max(content_layers + style_layers) + 1):
  net.add(pretrained_net.features[i])

 

给定输入X,如果我们简单地调用前向传播 net(X),我们只能得到最后一层的输出。由于我们还需要中间层的输出,因此我们需要逐层计算并保留内容层和样式层的输出。

 

def extract_features(X, content_layers, style_layers):
  contents = []
  styles = []
  for i in range(len(net)):
    X = net[i](X)
    if i in style_layers:
      styles.append(X)
    if i in content_layers:
      contents.append(X)
  return contents, styles

 

 

def extract_features(X, content_layers, style_layers):
  contents = []
  styles = []
  for i in range(len(net)):
    X = net[i](X)
    if i in style_layers:
      styles.append(X)
    if i in content_layers:
      contents.append(X)
  return contents, styles

 

下面定义了两个函数:get_contents函数从内容图像中提取内容特征,函数get_styles从风格图像中提取风格特征。由于在训练期间不需要更新预训练 VGG 的模型参数,我们甚至可以在训练开始之前提取内容和风格特征。由于合成图像是一组需要更新的模型参数以进行风格迁移,因此我们只能extract_features 在训练时通过调用函数来提取合成图像的内容和风格特征。

 

def get_contents(image_shape, device):
  content_X = preprocess(content_img, image_shape).to(device)
  contents_Y, _ = extract_features(content_X, content_layers, style_layers)
  return content_X, contents_Y

def get_styles(image_shape, device):
  style_X = preprocess(style_img, image_shape).to(device)
  _, styles_Y = extract_features(style_X, content_layers, style_layers)
  return style_X, styles_Y

 

 

def get_contents(image_shape, device):
  content_X = preprocess(content_img, image_shape).copyto(device)
  contents_Y, _ = extract_features(content_X, content_layers, style_layers)
  return content_X, contents_Y

def get_styles(image_shape, device):
  style_X = preprocess(style_img, image_shape).copyto(device)
  _, styles_Y = extract_features(style_X, content_layers, style_layers)
  return style_X, styles_Y

 

14.12.5。定义损失函数

现在我们将描述风格迁移的损失函数。损失函数包括内容损失、风格损失和全变损失。

14.12.5.1。内容丢失

类似于线性回归中的损失函数,内容损失通过平方损失函数衡量合成图像和内容图像之间内容特征的差异。平方损失函数的两个输入都是该extract_features函数计算的内容层的输出。

 

def content_loss(Y_hat, Y):
  # We detach the target content from the tree used to dynamically compute
  # the gradient: this is a stated value, not a variable. Otherwise the loss
  # will throw an error.
  return torch.square(Y_hat - Y.detach()).mean()

 

 

def content_loss(Y_hat, Y):
  return np.square(Y_hat - Y).mean()

 

14.12.5.2。风格损失

风格损失与内容损失类似,也是使用平方损失函数来衡量合成图像与风格图像之间的风格差异。为了表达任何样式层的样式输出,我们首先使用函数extract_features来计算样式层输出。假设输出有 1 个示例,c渠道,高度 h, 和宽度w,我们可以将此输出转换为矩阵 X和c行和hw列。这个矩阵可以被认为是串联c载体 x1,…,xc, 其中每一个的长度为hw. 在这里,矢量xi表示频道的风格特征i.

在这些向量的 Gram 矩阵中XX⊤∈Rc×c, 元素 xij在排队i和专栏j是向量的点积xi和xj. 表示渠道风格特征的相关性i和 j. 我们使用这个 Gram 矩阵来表示任何样式层的样式输出。请注意,当值hw越大,它可能会导致 Gram 矩阵中的值越大。还要注意,Gram矩阵的高和宽都是通道数c. 为了让风格损失不受这些值的影响,gram 下面的函数将 Gram 矩阵除以其元素的数量,即chw.

 

def gram(X):
  num_channels, n = X.shape[1], X.numel() // X.shape[1]
  X = X.reshape((num_channels, n))
  return torch.matmul(X, X.T) / (num_channels * n)

 

 

def gram(X):
  num_channels, n = X.shape[1], d2l.size(X) // X.shape[1]
  X = X.reshape((num_channels, n))
  return np.dot(X, X.T) / (num_channels * n)

 

显然,风格损失的平方损失函数的两个格拉姆矩阵输入是基于合成图像和风格图像的风格层输出。这里假设 gram_Y基于风格图像的 Gram 矩阵已经预先计算好了。

 

def style_loss(Y_hat, gram_Y):
  return torch.square(gram(Y_hat) - gram_Y.detach()).mean()

 

 

def style_loss(Y_hat, gram_Y):
  return np.square(gram(Y_hat) - gram_Y).mean()

 

14.12.5.3。总变异损失

有时,学习到的合成图像有很多高频噪声,即特别亮或特别暗的像素。一种常见的降噪方法是全变差去噪。表示为 xi,j坐标处的像素值(i,j). 减少总变异损失

(14.12.1)∑i,j|xi,j−xi+1,j|+|xi,j−xi,j+1|

使合成图像上相邻像素的值更接近。

 

def tv_loss(Y_hat):
  return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
         torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())

 

 

def tv_loss(Y_hat):
  return 0.5 * (np.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
         np.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())

 

14.12.5.4。损失函数

风格迁移的损失函数是内容损失、风格损失和总变异损失的加权和。通过调整这些权重超参数,我们可以在合成图像的内容保留、风格迁移和降噪之间取得平衡。

 

content_weight, style_weight, tv_weight = 1, 1e4, 10

def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
  # Calculate the content, style, and total variance losses respectively
  contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
    contents_Y_hat, contents_Y)]
  styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
    styles_Y_hat, styles_Y_gram)]
  tv_l = tv_loss(X) * tv_weight
  # Add up all the losses
  l = sum(styles_l + contents_l + [tv_l])
  return contents_l, styles_l, tv_l, l

 

 

content_weight, style_weight, tv_weight = 1, 1e4, 10

def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
  # Calculate the content, style, and total variance losses respectively
  contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
    contents_Y_hat, contents_Y)]
  styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
    styles_Y_hat, styles_Y_gram)]
  tv_l = tv_loss(X) * tv_weight
  # Add up all the losses
  l = sum(styles_l + contents_l + [tv_l])
  return contents_l, styles_l, tv_l, l

 

14.12.6. 初始化合成图像

在风格迁移中,合成图像是训练期间唯一需要更新的变量。因此,我们可以定义一个简单的模型, SynthesizedImage并将合成图像作为模型参数。在这个模型中,前向传播只返回模型参数。

 

class SynthesizedImage(nn.Module):
  def __init__(self, img_shape, **kwargs):
    super(SynthesizedImage, self).__init__(**kwargs)
    self.weight = nn.Parameter(torch.rand(*img_shape))

  def forward(self):
    return self.weight

 

 

class SynthesizedImage(nn.Block):
  def __init__(self, img_shape, **kwargs):
    super(SynthesizedImage, self).__init__(**kwargs)
    self.weight = self.params.get('weight', shape=img_shape)

  def forward(self):
    return self.weight.data()

 

接下来,我们定义get_inits函数。此函数创建一个合成图像模型实例并将其初始化为 image X。styles_Y_gram在训练之前计算各种样式层的样式图像的 Gram 矩阵 。

 

def get_inits(X, device, lr, styles_Y):
  gen_img = SynthesizedImage(X.shape).to(device)
  gen_img.weight.data.copy_(X.data)
  trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)
  styles_Y_gram = [gram(Y) for Y in styles_Y]
  return gen_img(), styles_Y_gram, trainer

 

 

def get_inits(X, device, lr, styles_Y):
  gen_img = SynthesizedImage(X.shape)
  gen_img.initialize(init.Constant(X), ctx=device, force_reinit=True)
  trainer = gluon.Trainer(gen_img.collect_params(), 'adam',
              {'learning_rate': lr})
  styles_Y_gram = [gram(Y) for Y in styles_Y]
  return gen_img(), styles_Y_gram, trainer

 

14.12.7. 训练

在训练风格迁移模型时,我们不断提取合成图像的内容特征和风格特征,并计算损失函数。下面定义了训练循环。

 

def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
  X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
  scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)
  animator = d2l.Animator(xlabel='epoch', ylabel='loss',
              xlim=[10, num_epochs],
              legend=['content', 'style', 'TV'],
              ncols=2, figsize=(7, 2.5))
  for epoch in range(num_epochs):
    trainer.zero_grad()
    contents_Y_hat, styles_Y_hat = extract_features(
      X, content_layers, style_layers)
    contents_l, styles_l, tv_l, l = compute_loss(
      X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
    l.backward()
    trainer.step()
    scheduler.step()
    if (epoch + 1) % 10 == 0:
      animator.axes[1].imshow(postprocess(X))
      animator.add(epoch + 1, [float(sum(contents_l)),
                   float(sum(styles_l)), float(tv_l)])
  return X

 

 

def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
  X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
  animator = d2l.Animator(xlabel='epoch', ylabel='loss',
              xlim=[10, num_epochs], ylim=[0, 20],
              legend=['content', 'style', 'TV'],
              ncols=2, figsize=(7, 2.5))
  for epoch in range(num_epochs):
    with autograd.record():
      contents_Y_hat, styles_Y_hat = extract_features(
        X, content_layers, style_layers)
      contents_l, styles_l, tv_l, l = compute_loss(
        X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
    l.backward()
    trainer.step(1)
    if (epoch + 1) % lr_decay_epoch == 0:
      trainer.set_learning_rate(trainer.learning_rate * 0.8)
    if (epoch + 1) % 10 == 0:
      animator.axes[1].imshow(postprocess(X).asnumpy())
      animator.add(epoch + 1, [float(sum(contents_l)),
                   float(sum(styles_l)), float(tv_l)])
  return X

 

现在我们开始训练模型。我们将内容和样式图像的高度和宽度重新调整为 300 x 450 像素。我们使用内容图像来初始化合成图像。

 

device, image_shape = d2l.try_gpu(), (300, 450) # PIL Image (h, w)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)

 

pytorch

 

device, image_shape = d2l.try_gpu(), (450, 300)
net.collect_params().reset_ctx(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.9, 500, 50)

 

pytorch

我们可以看到,合成图保留了内容图的景物和物体,同时传递了风格图的颜色。例如,合成图像具有风格图像中的颜色块。其中一些块甚至具有笔触的微妙纹理。

14.12.8。概括

风格迁移中常用的损失函数由三部分组成:(i)内容损失使合成图像和内容图像在内容特征上接近;(ii) 风格损失使得合成图像和风格图像在风格特征上接近;(iii) 总变差损失有助于减少合成图像中的噪声。

我们可以使用预训练的 CNN 提取图像特征并最小化损失函数,以在训练期间不断更新合成图像作为模型参数。

我们使用 Gram 矩阵来表示样式层的样式输出。

14.12.9。练习

当您选择不同的内容和样式层时,输出如何变化?

调整损失函数中的权重超参数。输出是保留更多内容还是噪音更少?

使用不同的内容和样式图像。你能创造出更有趣的合成图像吗?

我们可以对文本应用样式转换吗?提示:您可以参考Hu等人的调查论文。( 2022 )。

打开APP阅读更多精彩内容
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

全部0条评论

快来发表一下你的评论吧 !

×
20
完善资料,
赚取积分