
Make Your Photos Move: First Order Motion Model (4). The Adversarial Network and Model Training


1. Overview

This chapter covers the remaining parts of the model, together with data loading and training.

2. GeneratorFullModel: the Full Generator

2.1 The Image Pyramid (ImagePyramide)

This network produces versions of the input image at several different scales.

```python
class ImagePyramide(torch.nn.Module):
    """
    Create image pyramide for computing pyramide perceptual loss. See Sec 3.3
    """
    def __init__(self, scales, num_channels):
        super(ImagePyramide, self).__init__()
        downs = {}
        for scale in scales:
            # AntiAliasInterpolation2d is the downsampler defined earlier in the series
            downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
        self.downs = nn.ModuleDict(downs)

    def forward(self, x):
        out_dict = {}
        for scale, down_module in self.downs.items():
            out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
        return out_dict
```
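ImagePyramide delegates the actual downscaling to AntiAliasInterpolation2d, which was defined earlier in this series. For readers jumping in here, a minimal stand-in could look like the sketch below (the kernel-size heuristic and exact class shape are assumptions, not the series' code): blur with a depthwise Gaussian first, then subsample, so the smaller images do not alias.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AntiAliasInterpolation2d(nn.Module):
    """Gaussian blur followed by subsampling, so downscaling does not alias.
    Minimal stand-in for the version defined earlier in the series."""
    def __init__(self, channels, scale):
        super().__init__()
        self.scale = scale
        self.groups = channels
        sigma = (1 / scale - 1) / 2
        kernel_size = 2 * round(sigma * 4) + 1
        self.ka = kernel_size // 2
        # separable 1D Gaussian -> 2D kernel, one copy per channel (depthwise conv)
        coords = torch.arange(kernel_size, dtype=torch.float32) - self.ka
        g1d = torch.exp(-coords ** 2 / (2 * sigma ** 2 + 1e-8))
        g1d = g1d / g1d.sum()
        kernel = (g1d[:, None] * g1d[None, :]).expand(channels, 1, -1, -1)
        self.register_buffer('weight', kernel.contiguous())

    def forward(self, x):
        if self.scale == 1.0:
            return x  # nothing to do at full resolution
        x = F.pad(x, [self.ka] * 4)
        x = F.conv2d(x, weight=self.weight, groups=self.groups)
        return F.interpolate(x, scale_factor=(self.scale, self.scale))
```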

Test code:

```python
import matplotlib.pyplot as plt

# `source` is a (batch, 3, H, W) image tensor loaded earlier in the series
scales = [1, 0.5, 0.25, 0.125]
pyramide = ImagePyramide(scales, 3)
pyramide_source = pyramide(source)
pyramide_source_list = [w2 for (w1, w2) in pyramide_source.items()]
figure, ax = plt.subplots(2, 2, figsize=(8, 4))
for i in range(2):
    for j in range(2):
        show_item = pyramide_source_list[(2 * i) + j]
        ax[i, j].imshow(show_item[0].permute(1, 2, 0).data)
```
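A side note on the `str(scale).replace('.', '-')` in the pyramid keys: `nn.ModuleDict` registers each entry as a submodule, and PyTorch forbids dots in module names, so a scale such as 0.5 has to be stored as `'0-5'` and mapped back to `'0.5'` in the output dict. A quick check:

```python
import torch.nn as nn

d = nn.ModuleDict()
try:
    d['0.5'] = nn.Identity()   # module names may not contain '.'
except KeyError as e:
    print('rejected:', e)
d['0-5'] = nn.Identity()       # hence the '.' -> '-' replacement
```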

The result is the source image at four successively smaller scales.

2.2 The Vgg19 Network and the Perceptual Loss

Vgg19 is a pretrained network, a classic choice in style transfer. Several of VGG's convolutional layers each output feature maps, and these feature maps are compared with an L1 loss (mean absolute error). The feature maps capture an image's content but not its appearance, so the perceptual loss measures how similar two images are in content. This is exactly what we want here, since the generated image should carry the motion of the driving image.
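As a minimal illustration of the idea (the tensors below are hypothetical, not the article's code): given feature maps for a generated and a real image from one VGG layer, the per-level perceptual term is just their mean absolute difference.

```python
import torch

# hypothetical feature maps from one VGG layer: (batch, channels, H, W)
feat_generated = torch.randn(1, 64, 32, 32)
feat_real = feat_generated + 0.1          # "real" features, shifted slightly

# mean absolute difference between the two feature maps;
# detach() keeps gradients from flowing into the real-image branch
level_loss = torch.abs(feat_generated - feat_real.detach()).mean()
# ~0.1 here, and exactly 0 for identical feature maps
```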

The code below does two main things:

  1. Normalizes the input with the specified mean and standard deviation.
  2. Extracts the feature outputs after layers 2, 7, 12, 21, and 30 of the VGG network and returns them.
```python
import numpy as np
import torch
from torchvision import models


class Vgg19(torch.nn.Module):
    """
    Vgg19 network for perceptual loss. See Sec 3.3.
    """
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
                                       requires_grad=False)
        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
                                      requires_grad=False)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        # normalize the input with the ImageNet mean and std
        X = (X - self.mean) / self.std
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out
```

Test code:

```python
vgg = Vgg19()
x_vgg = vgg(source)
```

The key part of the perceptual-loss computation is shown below.

```python
# compute the perceptual loss only if its weights are set in the config file
if sum(self.loss_weights['perceptual']) != 0:
    value_total = 0
    # loop over the differently sized images produced by the pyramid network
    for scale in self.scales:
        # feature maps of the generated image
        x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
        # feature maps of the real image
        y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
        # weight the per-level L1 differences
        for i, weight in enumerate(self.loss_weights['perceptual']):
            value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
            value_total += self.loss_weights['perceptual'][i] * value
    loss_values['perceptual'] = value_total
```
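To see the scale-and-level double loop in isolation, here is a self-contained toy version. Everything in it is a made-up stand-in (`toy_vgg`, the weights, and the pyramids are hypothetical); only the loop structure mirrors the real code.

```python
import torch

def toy_vgg(img):
    # pretend "feature extractor": returns a list of per-level feature maps
    return [img.mean(dim=1, keepdim=True), img[:, :1] * 0.5]

loss_weights = {'perceptual': [10, 10]}
scales = [1, 0.5]
pyramide_generated = {'prediction_1': torch.rand(1, 3, 8, 8),
                      'prediction_0.5': torch.rand(1, 3, 4, 4)}
# "real" pyramid: the generated one shifted by a constant 0.1
pyramide_real = {k: v + 0.1 for k, v in pyramide_generated.items()}

value_total = 0
for scale in scales:
    x_vgg = toy_vgg(pyramide_generated['prediction_' + str(scale)])
    y_vgg = toy_vgg(pyramide_real['prediction_' + str(scale)])
    for i, weight in enumerate(loss_weights['perceptual']):
        value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
        value_total += weight * value
# each scale contributes 10*0.1 + 10*0.05 = 1.5, so value_total is about 3.0
```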

2.3 The Discriminator

Calling this a discriminator is a slight abuse of terminology: here it simply takes the image together with the keypoint information, produces Gaussian confidence maps from them, and returns those maps.
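The Gaussian confidence maps come from a helper that turns keypoint coordinates into heatmaps; the upstream repository calls it kp2gaussian. The sketch below is a simplified stand-in (the name `kp2gaussian_sketch`, the `kp_variance` default, and the [-1, 1] coordinate convention are assumptions): each keypoint becomes a Gaussian bump centered at its location.

```python
import torch

def kp2gaussian_sketch(kp, spatial_size, kp_variance=0.01):
    """Turn normalized keypoints in [-1, 1] into Gaussian heatmaps.
    kp: (batch, num_kp, 2) -> heatmaps: (batch, num_kp, H, W)."""
    h, w = spatial_size
    ys = torch.linspace(-1, 1, h)
    xs = torch.linspace(-1, 1, w)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
    grid = torch.stack([grid_x, grid_y], dim=-1)                 # (H, W, 2)
    # squared distance from every pixel to every keypoint
    diff = grid.view(1, 1, h, w, 2) - kp.view(kp.shape[0], kp.shape[1], 1, 1, 2)
    return torch.exp(-0.5 * (diff ** 2).sum(-1) / kp_variance)

# a keypoint at the image center produces a bright blob in the middle
heatmaps = kp2gaussian_sketch(torch.zeros(1, 10, 2), (64, 64))
```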

```python
class DownBlock2d_disc(nn.Module):
    """
    Simple block for processing video (encoder).
    """
    def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
        super(DownBlock2d_disc, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
        if sn:
            self.conv = nn.utils.spectral_norm(self.conv)
        if norm:
            self.norm = nn.InstanceNorm2d(out_features, affine=True)
        else:
            self.norm = None
        self.pool = pool

    # the original post is truncated at this point; the remainder below
    # follows the upstream first-order-model repository
    def forward(self, x):
        out = self.conv(x)
        if self.norm:
            out = self.norm(out)
        out = F.leaky_relu(out, 0.2)
        if self.pool:
            out = F.avg_pool2d(out, (2, 2))
        return out
```