This chapter covers the remaining parts of the model, together with data loading and training.

The following network produces copies of an input image at several different scales.
import torch
import torch.nn as nn

# AntiAliasInterpolation2d is the anti-aliased downsampling layer defined elsewhere in the repo
class ImagePyramide(torch.nn.Module):
    """
    Create image pyramide for computing pyramide perceptual loss. See Sec 3.3
    """
    def __init__(self, scales, num_channels):
        super(ImagePyramide, self).__init__()
        downs = {}
        # one downsampling module per scale; '.' is not allowed in ModuleDict keys
        for scale in scales:
            downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
        self.downs = nn.ModuleDict(downs)

    def forward(self, x):
        out_dict = {}
        # e.g. 'prediction_0.25' -> the image downscaled to a quarter of its resolution
        for scale, down_module in self.downs.items():
            out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
        return out_dict
Test code:
scales = [1, 0.5, 0.25, 0.125]
pyramide = ImagePyramide(scales, 3)
pyramide_source = pyramide(source)

pyramide_source_list = [w2 for (w1, w2) in pyramide_source.items()]
figure, ax = plt.subplots(2, 2, figsize=(8, 4))
for i in range(2):
    for j in range(2):
        show_item = pyramide_source_list[(2 * i) + j]
        ax[i, j].imshow(show_item[0].permute(1, 2, 0).data)
The output is the source image at each of the four scales.
Vgg19 is a pretrained network, a classic network used in style transfer. Several of VGG's convolutional layers each output feature maps, and these feature maps are compared with an L1 loss (mean absolute error). The feature maps capture the content of an image but not its appearance, so the perceptual loss measures how similar the content of two images is. This is exactly what we want: the generated image should contain the motion of the driving image.

The following code implements this:
import numpy as np
import torch
from torchvision import models

class Vgg19(torch.nn.Module):
    """
    Vgg19 network for perceptual loss. See Sec 3.3.
    """
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        # split the pretrained feature extractor into five slices,
        # each ending after a different ReLU layer
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])

        # ImageNet mean and std, stored as non-trainable parameters
        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
                                       requires_grad=False)
        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
                                      requires_grad=False)

        # VGG is only used as a fixed feature extractor
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        # normalize the input
        X = (X - self.mean) / self.std
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out
Test code:

vgg = Vgg19()
x_vgg = vgg(source)
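If `source` is a batch of RGB images (here assumed to be a tensor of shape `(1, 3, 256, 256)`), `x_vgg` is a list of five feature maps whose resolution shrinks and whose channel count grows from slice to slice. A quick way to inspect them:

for i, feat in enumerate(x_vgg):
    print(i, feat.shape)
# with a 1x3x256x256 input this prints:
# 0 torch.Size([1, 64, 256, 256])
# 1 torch.Size([1, 128, 128, 128])
# 2 torch.Size([1, 256, 64, 64])
# 3 torch.Size([1, 512, 32, 32])
# 4 torch.Size([1, 512, 16, 16])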
The key code for the perceptual loss is as follows:
# only compute the loss if the config assigns the perceptual term a non-zero weight
if sum(self.loss_weights['perceptual']) != 0:
    value_total = 0
    # loop over the different image sizes produced by the pyramid network
    for scale in self.scales:
        # feature maps of the generated image
        x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
        # feature maps of the real image
        y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
        # L1 distance per VGG layer, weighted by the configured weights
        for i, weight in enumerate(self.loss_weights['perceptual']):
            value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
            value_total += self.loss_weights['perceptual'][i] * value
    loss_values['perceptual'] = value_total
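To see the pieces working together outside of the training class, here is a minimal, self-contained sketch of the same computation. The scale list and per-layer weights are assumptions (they mirror typical values from the repo's config files), and `generated`/`real` stand for any two image batches of the same shape:

scales = [1, 0.5, 0.25, 0.125]             # assumed pyramid scales
perceptual_weights = [10, 10, 10, 10, 10]  # assumed per-layer weights

vgg = Vgg19()
pyramide = ImagePyramide(scales, num_channels=3)

def perceptual_loss(generated, real):
    pyramide_generated = pyramide(generated)
    pyramide_real = pyramide(real)
    value_total = 0
    for scale in scales:
        x_vgg = vgg(pyramide_generated['prediction_' + str(scale)])
        y_vgg = vgg(pyramide_real['prediction_' + str(scale)])
        for i, weight in enumerate(perceptual_weights):
            value_total += weight * torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
    return value_total

Note the `.detach()` on the real image's features: gradients only flow back through the generated image, never through the target.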
The term "discriminator" is used somewhat loosely here: this discriminator simply takes the image together with the keypoint information, turned into Gaussian confidence maps, and returns predictions built from them.
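As background for the Gaussian maps mentioned above, here is a simplified sketch of how keypoint coordinates (in [-1, 1]) can be turned into Gaussian heat maps; it is modeled after the kp2gaussian helper in the original repository, but the function name, arguments, and variance value are assumptions made for illustration. The discriminator itself is built from simple downsampling blocks such as the one below.

import torch

def keypoints_to_gaussian(kp_coords, spatial_size, variance=0.01):
    """
    kp_coords: (batch, num_kp, 2) keypoint coordinates in [-1, 1].
    Returns (batch, num_kp, H, W) Gaussian heat maps centered on the keypoints.
    """
    h, w = spatial_size
    # coordinate grid in [-1, 1], shape (H, W, 2), ordered as (x, y)
    ys = torch.linspace(-1, 1, h)
    xs = torch.linspace(-1, 1, w)
    grid = torch.stack(torch.meshgrid(ys, xs, indexing='ij'), dim=-1).flip(-1)
    grid = grid.reshape(1, 1, h, w, 2)
    kp = kp_coords.reshape(kp_coords.shape[0], kp_coords.shape[1], 1, 1, 2)
    # squared distance from every pixel to every keypoint
    dist = ((grid - kp) ** 2).sum(-1)
    return torch.exp(-0.5 * dist / variance)

With 10 keypoints and a 64x64 spatial size this yields a (batch, 10, 64, 64) tensor, which can be concatenated with the image along the channel dimension before it enters the discriminator blocks.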
import torch.nn as nn

class DownBlock2d_disc(nn.Module):
    """
    Simple block for processing video (encoder).
    """

    def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
        super(DownBlock2d_disc, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)

        # optional spectral normalization for training stability
        if sn:
            self.conv = nn.utils.spectral_norm(self.conv)

        # optional instance normalization
        if norm:
            self.norm = nn.InstanceNorm2d(out_features, affine=True)
        else:
            self.norm = None
        self.pool = pool
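The listing is cut off at this point. A forward pass consistent with the constructor above would look roughly like this (a sketch based on the downsampling block used in the original repository, assuming torch.nn.functional is imported as F):

    def forward(self, x):
        out = self.conv(x)
        if self.norm:
            out = self.norm(out)
        out = F.leaky_relu(out, 0.2)
        # optional 2x downsampling at the end of the block
        if self.pool:
            out = F.avg_pool2d(out, (2, 2))
        return out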