当前位置:   article > 正文

图像风格迁移——pytorch实现_style_img = style_img.resize(content_img.size)

style_img = style_img.resize(content_img.size)

一、风格迁移的效果展示

先看一张效果图

 

image.png

 

二、风格迁移的基本原理:

1、损失函数方面:

损失函数有两部分组成:内容损失和风格损失:

图片内容:图片的主体,图片中比较突出的部分

图片风格:图片的纹理、色彩等

 

(1)内容损失content loss :原始图片的内容和生成图片的内容作欧式距离

image.png

其中,等式左侧表示在第l层中,原始图像(P)和生成图像(F)的举例,右侧是对应的最小二乘法表达式。Fij表示生成图像第 i 个feature map 的第 j 个输出值。

image

使用最小二乘法求导得出最小值,再让改的l层上生成的图片(F)逼近改层的原始图片(P)

 

(2)风格损失style loss使用类G矩阵代表图像的风格 :

 

image

当同一个维度上面的值相乘的时候原来越小相乘之后的值变得更小,原来越大相乘之后的中就变得越大;在不同维度上的关系也在相乘的表达当中表示出来。

image

因此,最终能够在保证内容的情况下,进行风格的迁移转换。

2、风格迁移的流程

输入风格图片和内容图片还有第三张图片,并改变第三张图片,使其与内容图片的内容间距和风格图片的风格间距最小化。最终得到生成的图片

三、代码部分:

  1. from __future__ import print_function
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. import torch.optim as optim
  6. from PIL import Image
  7. import matplotlib.pyplot as plt
  8. import torchvision.transforms as transforms
  9. import torchvision.models as models
  10. import copy
  11. device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 来判断是否有可用的GPU
  12. # 输出图像的所需尺寸
  13. imsize = 512 if torch.cuda.is_available() else 128 # 如果没有GPU,请使用小尺寸
  14. loader = transforms.Compose([
  15. transforms.Resize(imsize), # 缩放导入的图像
  16. transforms.ToTensor()]) # 将其转换为torch tensor
  17. def image_loader(image_name):
  18. image = Image.open(image_name)
  19. image = loader(image).unsqueeze(0) # 需要伪造的批次尺寸以适合网络的输入尺寸
  20. return image.to(device, torch.float)
  21. # style_img = image_loader("./images/picasso.jpg")
  22. # content_img = image_loader("./images/dancing.jpg")
  23. style_img = image_loader(r"C:\Users\Administrator\Desktop\styled_transfer\ori\18.jpg")
  24. content_img = image_loader(r"C:\Users\Administrator\Desktop\styled_transfer\ori\42.jpg")
  25. assert style_img.size() == content_img.size(), \
  26. "我们需要导入相同大小的样式和内容图像"
  27. unloader = transforms.ToPILImage() # 重新转换为PIL图像
  28. plt.ion()
  29. def imshow(tensor, title=None):
  30. image = tensor.cpu().clone() # 我们克隆张量不对其进行更改
  31. image = image.squeeze(0) # 删除假批次尺寸
  32. image = unloader(image)
  33. plt.imshow(image)
  34. if title is not None:
  35. plt.title(title)
  36. plt.pause(0.001) # 稍停一下,以便更新地块
  37. plt.figure()
  38. imshow(style_img, title='Style Image')
  39. plt.figure()
  40. imshow(content_img, title='Content Image')
  41. class ContentLoss(nn.Module):
  42. def __init__(self, target, ):
  43. super(ContentLoss, self).__init__()
  44. # 我们将目标内容与所使用的树“分离”
  45. # 动态计算梯度:这是一个规定值,
  46. # 不是变量。 否则,准则的前进方法
  47. # 将引发错误。
  48. self.target = target.detach()
  49. def forward(self, input):
  50. self.loss = F.mse_loss(input, self.target)
  51. return input
  52. def gram_matrix(input):
  53. a, b, c, d = input.size() # a=batch size(=1)
  54. # b=特征图数量
  55. # (c,d)=dimensions of a f. map (N=c*d)
  56. features = input.view(a * b, c * d) # 将FXML调整为\ hat FXML
  57. G = torch.matmul(features, features.t())
  58. # 我们将gram矩阵的值“规范化”
  59. # 除以每个要素图中的元素数量。
  60. return G.div(a * b * c * d)
  61. class StyleLoss(nn.Module):
  62. def __init__(self, target_feature):
  63. super(StyleLoss, self).__init__()
  64. self.target = gram_matrix(target_feature).detach()
  65. def forward(self, input):
  66. G = gram_matrix(input)
  67. self.loss = F.mse_loss(G, self.target)
  68. return input
  69. cnn = models.vgg19(pretrained=True).features.to(device).eval()
  70. cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
  71. cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
  72. # 创建一个模块来标准化输入图像,以便我们可以轻松地将其放入
  73. # nn.Sequential
  74. class Normalization(nn.Module):
  75. def __init__(self, mean, std):
  76. super(Normalization, self).__init__()
  77. # 查看均值和标准差以使其为[C x 1 x 1],以便它们可以
  78. # 直接使用形状为[B x C x H x W]的图像张量。
  79. # B是批量大小。 C是通道数。 H是高度,W是宽度。
  80. # self.mean = torch.tensor(mean).view(-1, 1, 1)
  81. # self.std = torch.tensor(std).view(-1, 1, 1)
  82. self.mean = mean.clone().detach().view(-1, 1, 1)
  83. self.std = std.clone().detach().view(-1, 1, 1)
  84. def forward(self, img):
  85. # normalize img
  86. return (img - self.mean) / self.std
  87. # 所需的深度层以计算样式/内容损失:
  88. content_layers_default = ['conv_4']
  89. style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
  90. def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
  91. style_img, content_img,
  92. content_layers=content_layers_default,
  93. style_layers=style_layers_default):
  94. cnn = copy.deepcopy(cnn)
  95. # 标准化模块
  96. normalization = Normalization(normalization_mean, normalization_std).to(device)
  97. # 只是为了获得对内容/样式的可迭代访问或列表
  98. # losses
  99. content_losses = []
  100. style_losses = []
  101. # 假设cnn是nn.Sequential,那么我们创建一个新的nn.Sequential
  102. # 放入应该顺序激活的模块
  103. model = nn.Sequential(normalization)
  104. i = 0 # 每当转换时就增加
  105. for layer in cnn.children():
  106. if isinstance(layer, nn.Conv2d): #如果对象的类型与参数二的类型(classinfo)相同则返回 True,否则返回 False
  107. i += 1
  108. name = 'conv_{}'.format(i)
  109. elif isinstance(layer, nn.ReLU):
  110. name = 'relu_{}'.format(i)
  111. # 旧版本与我们在下面插入的ContentLoss和StyleLoss不能很好地配合使用。
  112. # 因此,我们在这里替换为不适当的。
  113. layer = nn.ReLU(inplace=False)
  114. elif isinstance(layer, nn.MaxPool2d):
  115. name = 'pool_{}'.format(i)
  116. elif isinstance(layer, nn.BatchNorm2d):
  117. name = 'bn_{}'.format(i)
  118. else:
  119. raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
  120. model.add_module(name, layer)
  121. if name in content_layers:
  122. # 增加内容损失:
  123. target = model(content_img).detach()
  124. content_loss = ContentLoss(target)
  125. model.add_module("content_loss_{}".format(i), content_loss)
  126. content_losses.append(content_loss)
  127. if name in style_layers:
  128. # 增加样式损失:
  129. target_feature = model(style_img).detach()
  130. style_loss = StyleLoss(target_feature)
  131. model.add_module("style_loss_{}".format(i), style_loss)
  132. style_losses.append(style_loss)
  133. # 现在我们在最后一次内容和样式丢失后修剪图层
  134. for i in range(len(model) - 1, -1, -1):
  135. if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
  136. break
  137. model = model[:(i + 1)]
  138. return model, style_losses, content_losses
  139. input_img = content_img.clone()
  140. # 如果要使用白噪声,请取消注释以下行:
  141. # input_img = torch.randn(content_img.data.size(), device=device)
  142. # 将原始输入图像添加到图中:
  143. plt.figure()
  144. imshow(input_img, title='Input Image')
  145. def get_input_optimizer(input_img):
  146. # 此行显示输入是需要渐变的参数
  147. optimizer = optim.LBFGS([input_img.requires_grad_()])
  148. return optimizer
  149. def run_style_transfer(cnn, normalization_mean, normalization_std,
  150. content_img, style_img, input_img, num_steps=300,
  151. style_weight=1000000, content_weight=1):
  152. """Run the style transfer."""
  153. print('Building the style transfer model..')
  154. model, style_losses, content_losses = get_style_model_and_losses(cnn,
  155. normalization_mean, normalization_std, style_img,
  156. content_img)
  157. optimizer = get_input_optimizer(input_img)
  158. print('Optimizing..')
  159. run = [0]
  160. while run[0] <= num_steps:
  161. def closure():
  162. # 更正更新后的输入图像的值
  163. input_img.data.clamp_(0, 1)
  164. optimizer.zero_grad()
  165. model(input_img)
  166. style_score = 0
  167. content_score = 0
  168. for sl in style_losses:
  169. style_score += sl.loss
  170. for cl in content_losses:
  171. content_score += cl.loss
  172. style_score *= style_weight
  173. content_score *= content_weight
  174. loss = style_score + content_score
  175. loss.backward()
  176. run[0] += 1
  177. if run[0] % 50 == 0:
  178. print("run {}:".format(run))
  179. print('Style Loss : {:4f} Content Loss: {:4f}'.format(
  180. style_score.item(), content_score.item()))
  181. print()
  182. return style_score + content_score
  183. optimizer.step(closure)
  184. # 最后更正...
  185. input_img.data.clamp_(0, 1)
  186. return input_img
  187. output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
  188. content_img, style_img, input_img)
  189. plt.figure()
  190. imshow(output, title='Output Image')
  191. # sphinx_gallery_thumbnail_number = 4
  192. plt.ioff()
  193. plt.show()

参考资料:

https://blog.puuuq.cn/index.php/2019/10/03/52.html

https://www.zhihu.com/question/49805962/answer/130549737

https://www.cnblogs.com/xiaoyh/p/11932095.html

论文地址:https://arxiv.org/abs/1711.09020

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/530392
推荐阅读
相关标签
  

闽ICP备14008679号