当前位置:   article > 正文

图像恢复SwinIR: Image Restoration Using Swin Transformer

swinir: image restoration using swin transformer

目录

目录

前言

一、Introduction

二、Network Architecture

2.1 整体框架

2.2 浅层特征提取

2.3 深层特征提取

2.4 图像重建模块

三、实验结果

1、 经典图像超分(Classical image SR)

2、 轻量级图像超分(Lightweight image SR)

3、 真实世界图像超分(Real-world image SR)

四、主要代码理解

1、 SwinIR

2、 MLP

3、 Patch Embeddings

4、 Window Attention

5、残差 Swin Transformer 块(RSTB)

7、测试实例

总结


前言

论文:https://arxiv.org/pdf/2108.10257.pdf

代码:https://github.com/JingyunLiang/SwinIR

参考:图像恢复 SWinIR : 彻底理解论文和源代码 (注释详尽)_听 风、的博客-CSDN博客

           图像超分辨率:SwinIR学习笔记 - 知乎


一、Introduction

在图像超分辨率、图像去噪、压缩等图像修复(Image restoration)任务中,卷积神经网络目前仍然是主流。但是卷积神经网络有以下两个缺点:

(1)图像和卷积核之间的交互是与内容无关的;

(2)在局部处理的原则下,卷积对于长距离依赖建模是无效的。

作为卷积的一个替代操作,Transformer设计了自注意力机制来捕获全局信息,但视觉Transformer因为需要划分patch,因此具有以下两个缺点:

(1)边界像素不能利用patch之外的邻近像素进行图像恢复;

(2)恢复后的图像可能会在每个patch周围引入边界伪影,这个问题能够通过patch overlapping缓解,但会增加计算量。

Swin Transformer结合了卷积和Transformer的优势,因此本文基于Swin Transformer提出了一种图像修复模型SwinIR。和现有模型相比,SwinIR具有更少的参数,且取得了更好的效果。

二、Network Architecture

2.1 整体框架

SwinIR的网络结构主要分为3部分,分别是浅层特征提取、深层特征提取和高质量图像重建。其中前后两个模块都是基于CNN的,中间的模块则主要使用Swin Transformer。

2.2 浅层特征提取

浅层特征提取只使用一层卷积进行提取。

使用一个3x3卷积HSF提取浅层特征F_0

2.3 深层特征提取

深层特征提取模块由若干个残差Swin Transformer块(RSTB)和卷积块构成,具体结构如下图:

(1)首先将来自浅层特征提取模块的特征图分割成多个不重叠的patch embeddings;

(2)再通过多个串联的残差Swin Transformer块(RSTB);

(3)将多个不重叠的patch embeddings重新组合成与输入特征图分辨率一样;

(4)最后通过一个卷积层(1层或3层卷积)输出;

(5)在每个RSTB中都引入残差连接。

残差Swin Transformer块(RSTB)中的STL就是Swin Transformer Layer,具体结构如下图:

(1)首先通过一个归一化层LayerNorm;

(2)再通过多头自注意力(Multi-head Self Attention)模块;

(3)在多头自注意力结尾引入残差;

(4)再通过一个归一化层LayerNorm;

(5)最后通过一个多层感知机MLP;

(6)结尾同样引入残差。

2.4 图像重建模块

图像重建模块是卷积+上采样的组合。在论文中提出4种结构:

(1)经典超分(卷积 + pixelshuffle 上采样 + 卷积);

(2)轻量超分(卷积 + pixelshuffle 上采样);

(3)真实图像超分(卷积 + 卷积插值上采样 + 卷积插值上采样 + 卷积);

(4)图像去噪和JPEG压缩去伪影(卷积 + 引入残差)。

三、实验结果

部分实验结果如下所示(仅选取了图像超分辨率相关的实验结果),包括经典图像超分(Classical image SR)、轻量级图像超分(Lightweight image SR)、真实世界图像超分(Real-world image SR)。

1、 经典图像超分(Classical image SR)

作者对比了基于卷积神经网络的模型(DBPN、RCAN、RRDB、SAN、IGNN、HAN、NLSA)和最新的基于Transformer的模型(IPT)。得益于局部窗口自注意力机制和卷积操作的归纳偏置,SwinIR的参数量减少至11.8M,明显少于IPT的115.5M,甚至少于部分卷积神经网络的模型;模型的训练难度也随之减少,不再需要ImageNet那样的大数据集来训练模型。仅使用DIV2K数据集训练时,SwinIR的精度就超过了卷积神经网络模型;再加上Flickr2K数据集后,精度就超越了使用的ImageNet训练、115.5M参数的IPT模型。

2、 轻量级图像超分(Lightweight image SR)

作者对比了几个轻量级的图像超分模型(CARN、FALSR-A、IMDN、LAPAR-A、LatticeNet),如下图所示,在相似的计算量和参数量的前提下,SwinIR超越了诸多轻量级超分模型,显然SwinIR更加高效。

3、 真实世界图像超分(Real-world image SR)

图像超分辨率的最终目的是应用于真实世界。由于真实世界图像超分任务没有GT图像,因此作者对比了几种真实世界图像超分模型的可视化结果(ESRGAN、RealSR、BSRGAN、Real-ESRGAN)。SwinIR能够产生锐度高的清晰图像。

四、主要代码理解

1、 SwinIR

SwinIR主要由浅层特征提取、深层特征提取和高质量图像重建模块组成。

  1. # SWinIR
  2. class SwinIR(nn.Module):
  3. r""" SwinIR
  4. 基于 Swin Transformer 的图像恢复网络.
  5. 输入:
  6. img_size (int | tuple(int)): 输入图像的大小,默认为 64*64.
  7. patch_size (int | tuple(int)): patch 的大小,默认为 1.
  8. in_chans (int): 输入图像的通道数,默认为 3.
  9. embed_dim (int): Patch embedding 的维度,默认为 96.
  10. depths (tuple(int)): Swin Transformer 层的深度.
  11. num_heads (tuple(int)): 在不同层注意力头的个数.
  12. window_size (int): 窗口大小,默认为 7.
  13. mlp_ratio (float): MLP隐藏层特征图通道与嵌入层特征图通道的比,默认为 4.
  14. qkv_bias (bool): 给 query, key, value 添加可学习的偏置,默认为 True.
  15. qk_scale (float): 重写默认的缩放因子,默认为 None.
  16. drop_rate (float): 随机丢弃神经元,丢弃率默认为 0.
  17. attn_drop_rate (float): 注意力权重的丢弃率,默认为 0.
  18. drop_path_rate (float): 深度随机丢弃率,默认为 0.1.
  19. norm_layer (nn.Module): 归一化操作,默认为 nn.LayerNorm.
  20. ape (bool): patch embedding 添加绝对位置 embedding,默认为 False.
  21. patch_norm (bool): 在 patch embedding 后添加归一化操作,默认为 True.
  22. use_checkpoint (bool): 是否使用 checkpointing 来节省显存,默认为 False.
  23. upscale: 放大因子, 2/3/4/8 适合图像超分, 1 适合图像去噪和 JPEG 压缩去伪影
  24. img_range: 灰度值范围, 1 或者 255.
  25. upsampler: 图像重建方法的选择模块,可选择 pixelshuffle, pixelshuffledirect, nearest+conv 或 None.
  26. resi_connection: 残差连接之前的卷积块, 可选择 1conv 或 3conv.
  27. """
  28. def __init__(self, img_size=64, patch_size=1, in_chans=3,
  29. embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
  30. window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
  31. drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
  32. norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
  33. use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
  34. **kwargs):
  35. super(SwinIR, self).__init__()
  36. num_in_ch = in_chans # 输入图片通道数
  37. num_out_ch = in_chans # 输出图片通道数
  38. num_feat = 64 # 特征图通道数
  39. self.img_range = img_range # 灰度值范围:[0, 1] or [0, 255]
  40. if in_chans == 3: # 如果输入是RGB图像
  41. rgb_mean = (0.4488, 0.4371, 0.4040) # 数据集RGB均值
  42. self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) # 转为[1, 3, 1, 1]的张量
  43. else: # 否则灰度图
  44. self.mean = torch.zeros(1, 1, 1, 1) # 构造[1, 1, 1, 1]的张量
  45. self.upscale = upscale # 图像放大倍数,超分(2/3/4/8),去噪(1)
  46. self.upsampler = upsampler # 上采样方法
  47. self.window_size = window_size # 注意力窗口的大小
  48. #######################################################################################
  49. ################################### 1, 浅层特征提取 ###################################
  50. self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) # 输入卷积层
  51. ##########################################################################################
  52. ################################### 2, 深层特征提取 ######################################
  53. self.num_layers = len(depths) # Swin Transformer 层的个数
  54. self.embed_dim = embed_dim # 嵌入层特征图的通道数
  55. self.ape = ape # patch embedding 添加绝对位置 embedding,默认为 False.
  56. self.patch_norm = patch_norm # 在 patch embedding 后添加归一化操作,默认为 True.
  57. self.num_features = embed_dim # 特征图的通道数
  58. self.mlp_ratio = mlp_ratio # MLP隐藏层特征图通道与嵌入层特征图通道的比
  59. # 将图像分割成多个不重叠的patch
  60. self.patch_embed = PatchEmbed(
  61. img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
  62. norm_layer=norm_layer if self.patch_norm else None)
  63. num_patches = self.patch_embed.num_patches # 分割得到patch的个数
  64. patches_resolution = self.patch_embed.patches_resolution # 分割得到patch的分辨率
  65. self.patches_resolution = patches_resolution
  66. # 将多个不重叠的patch合并成图像
  67. self.patch_unembed = PatchUnEmbed(
  68. img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
  69. norm_layer=norm_layer if self.patch_norm else None)
  70. # 绝对位置嵌入
  71. if self.ape:
  72. # 结构为 [1,patch个数, 嵌入层特征图的通道数] 的参数
  73. self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
  74. trunc_normal_(self.absolute_pos_embed, std=.02) # 截断正态分布,限制标准差为0.02
  75. self.pos_drop = nn.Dropout(p=drop_rate) # 以drop_rate为丢弃率随机丢弃神经元,默认不丢弃
  76. # 随机深度衰减规律,默认为 [0, 0.1] 进行24等分后的列表
  77. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
  78. # Residual Swin Transformer blocks (RSTB)
  79. # 残差 Swin Transformer 块 (RSTB)
  80. self.layers = nn.ModuleList() # 创建一个ModuleList实例对象,也就是多个 RSTB
  81. for i_layer in range(self.num_layers): # 循环 Swin Transformer 层的个数次
  82. # 实例化 RSTB
  83. layer = RSTB(dim=embed_dim,
  84. input_resolution=(patches_resolution[0],
  85. patches_resolution[1]),
  86. depth=depths[i_layer],
  87. num_heads=num_heads[i_layer],
  88. window_size=window_size,
  89. mlp_ratio=self.mlp_ratio,
  90. qkv_bias=qkv_bias, qk_scale=qk_scale,
  91. drop=drop_rate, attn_drop=attn_drop_rate,
  92. drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
  93. norm_layer=norm_layer,
  94. downsample=None,
  95. use_checkpoint=use_checkpoint,
  96. img_size=img_size,
  97. patch_size=patch_size,
  98. resi_connection=resi_connection
  99. )
  100. self.layers.append(layer) # 将 RSTB 对象插入 ModuleList 中
  101. self.norm = norm_layer(self.num_features) # 归一化操作,默认 LayerNorm
  102. # 在深层特征提取网络中加入卷积块,保持特征图通道数不变
  103. if resi_connection == '1conv': # 1层卷积
  104. self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
  105. elif resi_connection == '3conv': # 3层卷积
  106. # 为了减少参数使用和节约显存,采用瓶颈结构
  107. self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), # 降维
  108. nn.LeakyReLU(negative_slope=0.2, inplace=True),
  109. nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
  110. nn.LeakyReLU(negative_slope=0.2, inplace=True),
  111. nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) # 升维
  112. # 高质量图像重建模块
  113. if self.upsampler == 'pixelshuffle': # pixelshuffle 上采样
  114. # 适合经典超分
  115. self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
  116. nn.LeakyReLU(inplace=True))
  117. self.upsample = Upsample(upscale, num_feat) # 上采样
  118. self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) # 输出卷积层
  119. elif self.upsampler == 'pixelshuffledirect': # 一步是实现既上采样也降维
  120. # 适合轻量级充分,可以减少参数量(一步是实现既上采样也降维)
  121. self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
  122. (patches_resolution[0], patches_resolution[1]))
  123. elif self.upsampler == 'nearest+conv': # 最近邻插值上采样
  124. # 适合真实图像超分
  125. assert self.upscale == 4, 'only support x4 now.' # 声明目前仅支持4倍超分重建
  126. # 上采样之前的卷积层
  127. self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
  128. nn.LeakyReLU(inplace=True))
  129. # 第一次上采样卷积(直接对输入做最近邻插值变为2倍图像)
  130. self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
  131. # 第二次上采样卷积(直接对输入做最近邻插值变为2倍图像)
  132. self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
  133. self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) # 对上采样完成的图像再做卷积
  134. self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) # 激活层
  135. self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) # 输出卷积层
  136. else:
  137. # 适合图像去噪和 JPEG 压缩去伪影
  138. self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
  139. self.apply(self._init_weights) # 初始化网络参数
  140. # 初始化网络参数
  141. def _init_weights(self, m):
  142. if isinstance(m, nn.Linear): # 判断是否为线性 Linear 层
  143. trunc_normal_(m.weight, std=.02) # 截断正态分布,限制标准差为 0.02
  144. if m.bias is not None: # 如果设置了偏置
  145. nn.init.constant_(m.bias, 0) # 初始化偏置为 0
  146. elif isinstance(m, nn.LayerNorm): # 判断是否为归一化 LayerNorm 层
  147. nn.init.constant_(m.bias, 0) # 初始化偏置为 0
  148. nn.init.constant_(m.weight, 1.0) # 初始化权重系数为 1
  149. # 检查图片(准确说是张量)的大小
  150. def check_image_size(self, x):
  151. _, _, h, w = x.size() # 张量 x 的高和宽
  152. # h 维度要填充的个数
  153. mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
  154. # w 维度要填充的个数
  155. mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
  156. # 右填充 mod_pad_w 个值,下填充 mod_pad_h 个值,模式为反射(可以理解为以 x 的维度末尾为轴对折)
  157. x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
  158. return x
  159. # 深层特征提取网络的前向传播
  160. def forward_features(self, x):
  161. x_size = (x.shape[2], x.shape[3]) # 张量 x 的高和宽
  162. x = self.patch_embed(x) # 分割 x 为多个不重叠的 patch embeddings
  163. if self.ape: # 绝对位置 embedding
  164. x = x + self.absolute_pos_embed # x 加上对应的绝对位置 embedding
  165. x = self.pos_drop(x) # 随机将x中的部分元素置 0
  166. for layer in self.layers:
  167. x = layer(x, x_size) # x 通过多个串联的 RSTB
  168. x = self.norm(x) # 对 RSTB 的输出进行归一化
  169. x = self.patch_unembed(x, x_size) # 将多个不重叠的 patch 合并成图像
  170. return x
  171. # SWinIR 的前向传播
  172. def forward(self, x):
  173. H, W = x.shape[2:] # 输入图片的高和宽
  174. x = self.check_image_size(x) # 检查图片的大小,使高宽满足 window_size 的整数倍
  175. self.mean = self.mean.type_as(x) # RGB 均值的类型同 x 一致
  176. x = (x - self.mean) * self.img_range # x 减去 RGB 均值再乘以输入的最大灰度值
  177. if self.upsampler == 'pixelshuffle': # pixelshuffle 上采样方法
  178. # 适合经典超分
  179. x = self.conv_first(x) # 输入卷积层
  180. x = self.conv_after_body(self.forward_features(x)) + x # 深度特征提取网络,引入残差
  181. x = self.conv_before_upsample(x) # 上采样前进行卷积
  182. x = self.conv_last(self.upsample(x)) # 上采样后再通过输出卷积层
  183. elif self.upsampler == 'pixelshuffledirect': # 一步是实现既上采样也降维
  184. # 适合轻量级超分
  185. x = self.conv_first(x) # 输入卷积层
  186. x = self.conv_after_body(self.forward_features(x)) + x # 深度特征提取网络,引入残差
  187. x = self.upsample(x) # 上采样并降维后输出
  188. elif self.upsampler == 'nearest+conv': # 最近邻插值上采样方法
  189. # 适合真实图像超分,只适合 4 倍超分
  190. x = self.conv_first(x) # 输入卷积层
  191. x = self.conv_after_body(self.forward_features(x)) + x # 深度特征提取网络,引入残差
  192. x = self.conv_before_upsample(x) # 上采样前进行卷积
  193. # 第一次上采样 2 倍
  194. x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
  195. # 第二次上采样 2 倍
  196. x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
  197. x = self.conv_last(self.lrelu(self.conv_hr(x))) # 输出卷积层
  198. else:
  199. # 适合图像去噪和 JPEG 压缩去伪影
  200. x_first = self.conv_first(x) # 输入卷积层
  201. res = self.conv_after_body(self.forward_features(x_first)) + x_first # 深度特征提取网络,引入残差
  202. x = x + self.conv_last(res) # 输出卷积层,引入残差
  203. x = x / self.img_range + self.mean # 最后的 x 除以灰度值范围再加上 RGB 均值
  204. return x[:, :, :H*self.upscale, :W*self.upscale] # 返回输出 x
2、 MLP

多层感知机MLP是Transformer比较基础的部分,具体原理也很简单。

  1. # 多层感知机
  2. class Mlp(nn.Module):
  3. def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
  4. super().__init__()
  5. out_features = out_features or in_features # 输入特征的维度
  6. hidden_features = hidden_features or in_features # 隐藏特征维度
  7. self.fc1 = nn.Linear(in_features, hidden_features) # 线性层
  8. self.act = act_layer() # 激活函数
  9. self.fc2 = nn.Linear(hidden_features, out_features) # 线性层
  10. self.drop = nn.Dropout(drop) # 随机丢弃神经元,丢弃率默认为 0
  11. # 定义前向传播
  12. def forward(self, x):
  13. x = self.fc1(x)
  14. x = self.act(x)
  15. x = self.drop(x)
  16. x = self.fc2(x)
  17. x = self.drop(x)
  18. return x
3、 Patch Embeddings

主要的操作就是将原始2维图像(特征图的一个plane或者说一个channel)转变为1维的patch embeddings,通过Swin Transformer学习处理之后再重新组合成与原来特征图结构一致的新特征图。

(1)将 2 维图像转变为 1 维patch embeddings。

  1. # 图像转成 Patch Embeddings
  2. class PatchEmbed(nn.Module):
  3. r""" Image to Patch Embedding
  4. 输入:
  5. img_size (int): 图像的大小,默认为 224*224.
  6. patch_size (int): Patch token 的大小,默认为 4*4.
  7. in_chans (int): 输入图像的通道数,默认为 3.
  8. embed_dim (int): 线性 projection 输出的通道数,默认为 96.
  9. norm_layer (nn.Module, optional): 归一化层, 默认为N None.
  10. """
  11. def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
  12. super().__init__()
  13. img_size = to_2tuple(img_size) # 图像的大小,默认为 224*224
  14. patch_size = to_2tuple(patch_size) # Patch token 的大小,默认为 4*4
  15. patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] # patch 的分辨率
  16. self.img_size = img_size
  17. self.patch_size = patch_size
  18. self.patches_resolution = patches_resolution
  19. self.num_patches = patches_resolution[0] * patches_resolution[1] # patch 的个数,num_patches
  20. self.in_chans = in_chans # 输入图像的通道数
  21. self.embed_dim = embed_dim # 线性 projection 输出的通道数
  22. if norm_layer is not None:
  23. self.norm = norm_layer(embed_dim) # 归一化
  24. else:
  25. self.norm = None
  26. # 定义前向传播
  27. def forward(self, x):
  28. x = x.flatten(2).transpose(1, 2) # 结构为 [B, num_patches, C]
  29. if self.norm is not None:
  30. x = self.norm(x) # 归一化
  31. return x

(2)将 1 维 patch embeddings 转变为 2 维图像。

  1. # 从 Patch Embeddings 组合图像
  2. class PatchUnEmbed(nn.Module):
  3. r""" Image to Patch Unembedding
  4. 输入:
  5. img_size (int): 图像的大小,默认为 224*224.
  6. patch_size (int): Patch token 的大小,默认为 4*4.
  7. in_chans (int): 输入图像的通道数,默认为 3.
  8. embed_dim (int): 线性 projection 输出的通道数,默认为 96.
  9. norm_layer (nn.Module, optional): 归一化层, 默认为N None.
  10. """
  11. def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
  12. super().__init__()
  13. img_size = to_2tuple(img_size) # 图像的大小,默认为 224*224
  14. patch_size = to_2tuple(patch_size) # Patch token 的大小,默认为 4*4
  15. patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] # patch 的分辨率
  16. self.img_size = img_size
  17. self.patch_size = patch_size
  18. self.patches_resolution = patches_resolution
  19. self.num_patches = patches_resolution[0] * patches_resolution[1] # patch 的个数,num_patches
  20. self.in_chans = in_chans # 输入图像的通道数
  21. self.embed_dim = embed_dim # 线性 projection 输出的通道数
  22. def forward(self, x, x_size):
  23. B, HW, C = x.shape # 输入 x 的结构
  24. x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # 输出结构为 [B, Ph*Pw, C]
  25. return x
4、 Window Attention

采用窗口注意力来减轻传统Transformer的全局注意力带来的计算负担,将注意力的计算限制在每一个窗口里,在每个窗口里其实还是原始的多头自注意力。

(1)窗口分割

  1. # 将输入分割为多个不重叠窗口
  2. def window_partition(x, window_size):
  3. """
  4. 输入:
  5. x: (B, H, W, C)
  6. window_size (int): window size # 窗口的大小
  7. 返回:
  8. windows: (num_windows*B, window_size, window_size, C) # 每一个 batch 有单独的 windows
  9. """
  10. B, H, W, C = x.shape # 输入的 batch 个数,高,宽,通道数
  11. # 将输入 x 重构为结构 [batch 个数,高方向的窗口个数,窗口大小,宽方向的窗口个数,窗口大小,通道数] 的张量
  12. x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
  13. # 交换重构后 x 的第 3和4 维度, 5和6 维度,再次重构为结构 [高和宽方向的窗口个数乘以 batch 个数,窗口大小,窗口大小,通道数] 的张量
  14. windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
  15. return windows
  16. # 这里比较有意思,不太理解的可以给个初始值,比如 x = torch.randn([1, 14, 28, 3])

(2)窗口注意力

(这里的相对位置索引可以参考:图解Swin Transformer - 知乎

  1. # 窗口注意力
  2. class WindowAttention(nn.Module):
  3. r""" 基于有相对位置偏差的多头自注意力窗口,支持移位的(shifted)或者不移位的(non-shifted)窗口.
  4. 输入:
  5. dim (int): 输入特征的维度.
  6. window_size (tuple[int]): 窗口的大小.
  7. num_heads (int): 注意力头的个数.
  8. qkv_bias (bool, optional): 给 query, key, value 添加可学习的偏置,默认为 True.
  9. qk_scale (float | None, optional): 重写默认的缩放因子 scale.
  10. attn_drop (float, optional): 注意力权重的丢弃率,默认为 0.0.
  11. proj_drop (float, optional): 输出的丢弃率,默认为 0.0.
  12. """
  13. def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
  14. super().__init__()
  15. self.dim = dim # 输入特征的维度
  16. self.window_size = window_size # 窗口的高 Wh,宽 Ww
  17. self.num_heads = num_heads # 注意力头的个数
  18. head_dim = dim // num_heads # 注意力头的维度
  19. self.scale = qk_scale or head_dim ** -0.5 # 缩放因子 scale
  20. # 定义相对位置偏移的参数表,结构为 [2*Wh-1 * 2*Ww-1, num_heads]
  21. self.relative_position_bias_table = nn.Parameter(
  22. torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
  23. # 获取窗口内每个 token 的成对的相对位置索引
  24. coords_h = torch.arange(self.window_size[0]) # 高维度上的坐标 (0, 7)
  25. coords_w = torch.arange(self.window_size[1]) # 宽维度上的坐标 (0, 7)
  26. coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 坐标,结构为 [2, Wh, Ww]
  27. coords_flatten = torch.flatten(coords, 1) # 重构张量结构为 [2, Wh*Ww]
  28. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 相对坐标,结构为 [2, Wh*Ww, Wh*Ww]
  29. relative_coords = relative_coords.permute(1, 2, 0).contiguous() # 交换维度,结构为 [Wh*Ww, Wh*Ww, 2]
  30. relative_coords[:, :, 0] += self.window_size[0] - 1 # 第1个维度移位
  31. relative_coords[:, :, 1] += self.window_size[1] - 1 # 第1个维度移位
  32. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 # 第1个维度的值乘以 2倍的 Ww,再减 1
  33. relative_position_index = relative_coords.sum(-1) # 相对位置索引,结构为 [Wh*Ww, Wh*Ww]
  34. self.register_buffer("relative_position_index", relative_position_index) # 保存数据,不再更新
  35. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) # 线性层,特征维度变为原来的 3倍
  36. self.attn_drop = nn.Dropout(attn_drop) # 随机丢弃神经元,丢弃率默认为 0.0
  37. self.proj = nn.Linear(dim, dim) # 线性层,特征维度不变
  38. self.proj_drop = nn.Dropout(proj_drop) # 随机丢弃神经元,丢弃率默认为 0.0
  39. trunc_normal_(self.relative_position_bias_table, std=.02) # 截断正态分布,限制标准差为 0.02
  40. self.softmax = nn.Softmax(dim=-1) # 激活函数 softmax
  41. # 定义前向传播
  42. def forward(self, x, mask=None):
  43. """
  44. 输入:
  45. x: 输入特征图,结构为 [num_windows*B, N, C]
  46. mask: (0/-inf) mask, 结构为 [num_windows, Wh*Ww, Wh*Ww] 或者没有 mask
  47. """
  48. B_, N, C = x.shape # 输入特征图的结构
  49. # 将特征图的通道维度按照注意力头的个数重新划分,并再做交换维度操作
  50. qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  51. q, k, v = qkv[0], qkv[1], qkv[2] # 方便后续写代码,重新赋值
  52. # q 乘以缩放因子
  53. q = q * self.scale
  54. # @ 代表常规意义上的矩阵相乘
  55. attn = (q @ k.transpose(-2, -1)) # q 和 k 相乘后并交换最后两个维度
  56. # 相对位置偏移,结构为 [Wh*Ww, Wh*Ww, num_heads]
  57. relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
  58. self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
  59. # 相对位置偏移交换维度,结构为 [num_heads, Wh*Ww, Wh*Ww]
  60. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
  61. attn = attn + relative_position_bias.unsqueeze(0) # 带相对位置偏移的注意力图
  62. if mask is not None: # 判断是否有 mask
  63. nW = mask.shape[0] # mask 的宽
  64. # 注意力图与 mask 相加
  65. attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
  66. attn = attn.view(-1, self.num_heads, N, N) # 恢复注意力图原来的结构
  67. attn = self.softmax(attn) # 激活注意力图 [0, 1] 之间
  68. else:
  69. attn = self.softmax(attn)
  70. attn = self.attn_drop(attn) # 随机设置注意力图中的部分值为 0
  71. # 注意力图与 v 相乘得到新的注意力图
  72. x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
  73. x = self.proj(x) # 通过线性层
  74. x = self.proj_drop(x) # 随机设置新注意力图中的部分值为 0
  75. return x

(3)窗口合并

  1. # 将多个不重叠窗口重新合并
  2. def window_reverse(windows, window_size, H, W):
  3. """
  4. 输入:
  5. windows: (num_windows*B, window_size, window_size, C) # 分割得到的窗口(已处理)
  6. window_size (int): Window size # 窗口大小
  7. H (int): Height of image # 原分割窗口前特征图的高
  8. W (int): Width of image # 原分割窗口前特征图的宽
  9. 返回:
  10. x: (B, H, W, C) # 返回与分割前特征图结构一样的结果
  11. """
  12. # 以下就是分割窗口的逆向操作,不多解释
  13. B = int(windows.shape[0] / (H * W / window_size / window_size))
  14. x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
  15. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
  16. return x
5、残差 Swin Transformer 块(RSTB)

SwinIR 主要是使用 Swin Transformer的思想来实现,残差 Swin Transformer块(RSTB)可以理解为:

(1)Swin Transformer 块是RSTB的基础组件;

(2)多个Swin Transformer块构成基础网络;

(3)基础网络结尾处加上卷积操作后再引入残差构成RSTB。

(1)Swin Transformer 块

  1. # Swin Transformer 块
  2. class SwinTransformerBlock(nn.Module):
  3. """
  4. 输入:
  5. dim (int): 输入特征的维度.
  6. input_resolution (tuple[int]): 输入特征图的分辨率.
  7. num_heads (int): 注意力头的个数.
  8. window_size (int): 窗口的大小.
  9. shift_size (int): SW-MSA 的移位值.
  10. mlp_ratio (float): 多层感知机隐藏层的维度和嵌入层的比.
  11. qkv_bias (bool, optional): 给 query, key, value 添加一个可学习偏置,默认为 True.
  12. qk_scale (float | None, optional): 重写默认的缩放因子 scale.
  13. drop (float, optional): 随机神经元丢弃率,默认为 0.0.
  14. attn_drop (float, optional): 注意力图随机丢弃率,默认为 0.0.
  15. drop_path (float, optional): 深度随机丢弃率,默认为 0.0.
  16. act_layer (nn.Module, optional): 激活函数,默认为 nn.GELU.
  17. norm_layer (nn.Module, optional): 归一化操作,默认为 nn.LayerNorm.
  18. """
  19. def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
  20. mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
  21. act_layer=nn.GELU, norm_layer=nn.LayerNorm):
  22. super().__init__()
  23. self.dim = dim # 输入特征的维度
  24. self.input_resolution = input_resolution # 输入特征图的分辨率
  25. self.num_heads = num_heads # 注意力头的个数
  26. self.window_size = window_size # 窗口的大小
  27. self.shift_size = shift_size # SW-MSA 的移位大小
  28. self.mlp_ratio = mlp_ratio # 多层感知机隐藏层的维度和嵌入层的比
  29. if min(self.input_resolution) <= self.window_size: # 如果输入分辨率小于等于窗口大小
  30. self.shift_size = 0 # 移位大小为 0
  31. self.window_size = min(self.input_resolution) # 窗口大小等于输入分辨率大小
  32. # 断言移位值必须小于等于窗口的大小
  33. assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
  34. self.norm1 = norm_layer(dim) # 归一化层
  35. # 窗口注意力
  36. self.attn = WindowAttention(
  37. dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
  38. qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
  39. # 如果丢弃率大于 0 则进行随机丢弃,否则进行占位(不做任何操作)
  40. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  41. self.norm2 = norm_layer(dim) # 归一化层
  42. mlp_hidden_dim = int(dim * mlp_ratio) # 多层感知机隐藏层维度
  43. # 多层感知机
  44. self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
  45. if self.shift_size > 0: # 如果移位值大于 0
  46. attn_mask = self.calculate_mask(self.input_resolution) # 计算注意力 mask
  47. else:
  48. attn_mask = None # 注意力 mask 赋空
  49. self.register_buffer("attn_mask", attn_mask) # 保存注意力 mask,不参与更新
  50. # 计算注意力 mask
  51. def calculate_mask(self, x_size):
  52. H, W = x_size # 特征图的高宽
  53. img_mask = torch.zeros((1, H, W, 1)) # 新建张量,结构为 [1, H, W, 1]
  54. # 以下两 slices 中的数据是索引,具体缘由尚未搞懂
  55. h_slices = (slice(0, -self.window_size), # 索引 0 到索引倒数第 window_size
  56. slice(-self.window_size, -self.shift_size), # 索引倒数第 window_size 到索引倒数第 shift_size
  57. slice(-self.shift_size, None)) # 索引倒数第 shift_size 后所有索引
  58. w_slices = (slice(0, -self.window_size),
  59. slice(-self.window_size, -self.shift_size),
  60. slice(-self.shift_size, None))
  61. cnt = 0
  62. for h in h_slices:
  63. for w in w_slices:
  64. img_mask[:, h, w, :] = cnt # 将 img_mask 中 h, w 对应索引范围的值置为 cnt
  65. cnt += 1 # 加 1
  66. mask_windows = window_partition(img_mask, self.window_size) # 窗口分割,返回值结构为 [nW, window_size, window_size, 1]
  67. mask_windows = mask_windows.view(-1, self.window_size * self.window_size) # 重构结构为二维张量,列数为 [window_size*window_size]
  68. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # 增加第 2 维度减去增加第 3 维度的注意力 mask
  69. # 用浮点数 -100. 填充注意力 mask 中值不为 0 的元素,再用浮点数 0. 填充注意力 mask 中值为 0 的元素
  70. attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
  71. return attn_mask
  72. # 定义前向传播
  73. def forward(self, x, x_size):
  74. H, W = x_size # 输入特征图的分辨率
  75. B, L, C = x.shape # 输入特征的 batch 个数,长度和维度
  76. # assert L == H * W, "input feature has wrong size"
  77. shortcut = x
  78. x = self.norm1(x) # 归一化
  79. x = x.view(B, H, W, C) # 重构 x 为结构 [B, H, W, C]
  80. # 循环移位
  81. if self.shift_size > 0: # 如果移位值大于 0
  82. # 第 0 维度上移 shift_size 位,第 1 维度左移 shift_size 位
  83. shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
  84. else:
  85. shifted_x = x # 不移位
  86. # 对移位操作得到的特征图分割窗口, nW 是窗口的个数
  87. x_windows = window_partition(shifted_x, self.window_size) # 结构为 [nW*B, window_size, window_size, C]
  88. x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # 结构为 [nW*B, window_size*window_size, C]
  89. # W-MSA/SW-MSA, 用在分辨率是窗口大小的整数倍的图像上进行测试
  90. if self.input_resolution == x_size: # 输入分辨率与设定一致,不需要重新计算注意力 mask
  91. attn_windows = self.attn(x_windows, mask=self.attn_mask) # 注意力窗口,结构为 [nW*B, window_size*window_size, C]
  92. else: # 输入分辨率与设定不一致,需要重新计算注意力 mask
  93. attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
  94. # 合并窗口
  95. attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # 结构为 [-1, window_size, window_size, C]
  96. shifted_x = window_reverse(attn_windows, self.window_size, H, W) # 结构为 [B, H', W', C]
  97. # 逆向循环移位
  98. if self.shift_size > 0:
  99. # 第 0 维度下移 shift_size 位,第 1 维度右移 shift_size 位
  100. x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
  101. else:
  102. x = shifted_x # 不逆向移位
  103. x = x.view(B, H * W, C) # 结构为 [B, H*W, C]
  104. # FFN
  105. x = shortcut + self.drop_path(x) # 对 x 做 dropout,引入残差
  106. x = x + self.drop_path(self.mlp(self.norm2(x))) # 归一化后通过 MLP,再做 dropout,引入残差
  107. return x

(2)基础网络

  1. # 单阶段的 SWin Transformer 基础层
  2. class BasicLayer(nn.Module):
  3. """
  4. 输入:
  5. dim (int): 输入特征的维度.
  6. input_resolution (tuple[int]): 输入分辨率.
  7. depth (int): SWin Transformer 块的个数.
  8. num_heads (int): 注意力头的个数.
  9. window_size (int): 本地(当前块中)窗口的大小.
  10. mlp_ratio (float): MLP隐藏层特征维度与嵌入层特征维度的比.
  11. qkv_bias (bool, optional): 给 query, key, value 添加一个可学习偏置,默认为 True.
  12. qk_scale (float | None, optional): 重写默认的缩放因子 scale.
  13. drop (float, optional): 随机丢弃神经元,丢弃率默认为 0.0.
  14. attn_drop (float, optional): 注意力图随机丢弃率,默认为 0.0.
  15. drop_path (float | tuple[float], optional): 深度随机丢弃率,默认为 0.0.
  16. norm_layer (nn.Module, optional): 归一化操作,默认为 nn.LayerNorm.
  17. downsample (nn.Module | None, optional): 结尾处的下采样层,默认没有.
  18. use_checkpoint (bool): 是否使用 checkpointing 来节省显存,默认为 False.
  19. """
  20. def __init__(self, dim, input_resolution, depth, num_heads, window_size,
  21. mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
  22. drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
  23. super().__init__()
  24. self.dim = dim # 输入特征的维度
  25. self.input_resolution = input_resolution # 输入分辨率
  26. self.depth = depth # SWin Transformer 块的个数
  27. self.use_checkpoint = use_checkpoint # 是否使用 checkpointing 来节省显存,默认为 False
  28. # 创建 Swin Transformer 网络
  29. self.blocks = nn.ModuleList([
  30. SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
  31. num_heads=num_heads, window_size=window_size,
  32. shift_size=0 if (i % 2 == 0) else window_size // 2,
  33. mlp_ratio=mlp_ratio,
  34. qkv_bias=qkv_bias, qk_scale=qk_scale,
  35. drop=drop, attn_drop=attn_drop,
  36. drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
  37. norm_layer=norm_layer)
  38. for i in range(depth)])
  39. # patch 合并层
  40. if downsample is not None: # 如果有下采样
  41. self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) # 下采样
  42. else:
  43. self.downsample = None # 不做下采样
  44. #定义前向传播
  45. def forward(self, x, x_size):
  46. for blk in self.blocks: # x 输入串联的 Swin Transformer 块
  47. if self.use_checkpoint:
  48. x = checkpoint.checkpoint(blk, x, x_size) # 使用 checkpoint
  49. else:
  50. x = blk(x, x_size) # 直接输入网络
  51. if self.downsample is not None:
  52. x = self.downsample(x) # 下采样
  53. return x

(3)残差 Swin Transformer块(RSTB)

  1. # 残差 Swin Transforme 块 (RSTB)
  2. class RSTB(nn.Module):
  3. """
  4. 输入:
  5. dim (int): 输入特征的维度.
  6. input_resolution (tuple[int]): 输入分辨率.
  7. depth (int): SWin Transformer 块的个数.
  8. num_heads (int): 注意力头的个数.
  9. window_size (int): 本地(当前块中)窗口的大小.
  10. mlp_ratio (float): MLP隐藏层特征维度与嵌入层特征维度的比.
  11. qkv_bias (bool, optional): 给 query, key, value 添加一个可学习偏置,默认为 True.
  12. qk_scale (float | None, optional): 重写默认的缩放因子 scale.
  13. drop (float, optional): D 随机丢弃神经元,丢弃率默认为 0.0.
  14. attn_drop (float, optional): 注意力图随机丢弃率,默认为 0.0.
  15. drop_path (float | tuple[float], optional): 深度随机丢弃率,默认为 0.0.
  16. norm_layer (nn.Module, optional): 归一化操作,默认为 nn.LayerNorm.
  17. downsample (nn.Module | None, optional): 结尾处的下采样层,默认没有.
  18. use_checkpoint (bool): 是否使用 checkpointing 来节省显存,默认为 False.
  19. img_size: 输入图片的大小.
  20. patch_size: Patch 的大小.
  21. resi_connection: 残差连接之前的卷积块.
  22. """
  23. def __init__(self, dim, input_resolution, depth, num_heads, window_size,
  24. mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
  25. drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
  26. img_size=224, patch_size=4, resi_connection='1conv'):
  27. super(RSTB, self).__init__()
  28. self.dim = dim # 输入特征的维度
  29. self.input_resolution = input_resolution # 输入分辨率
  30. # SWin Transformer 基础层
  31. self.residual_group = BasicLayer(dim=dim,
  32. input_resolution=input_resolution,
  33. depth=depth,
  34. num_heads=num_heads,
  35. window_size=window_size,
  36. mlp_ratio=mlp_ratio,
  37. qkv_bias=qkv_bias, qk_scale=qk_scale,
  38. drop=drop, attn_drop=attn_drop,
  39. drop_path=drop_path,
  40. norm_layer=norm_layer,
  41. downsample=downsample,
  42. use_checkpoint=use_checkpoint)
  43. if resi_connection == '1conv': # 结尾用 1 个卷积层
  44. self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
  45. elif resi_connection == '3conv': # 结尾用 3 个卷积层
  46. # 为了减少参数使用和节约显存,采用瓶颈结构
  47. self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
  48. nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
  49. nn.LeakyReLU(negative_slope=0.2, inplace=True),
  50. nn.Conv2d(dim // 4, dim, 3, 1, 1))
  51. # 图像转成 Patch Embeddings
  52. self.patch_embed = PatchEmbed(
  53. img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
  54. norm_layer=None)
  55. # 从 Patch Embeddings 组合图像
  56. self.patch_unembed = PatchUnEmbed(
  57. img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
  58. norm_layer=None)
  59. # 定义前向传播
  60. def forward(self, x, x_size):
  61. return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x # 引入残差

6、高质量图像重建(HQ Image Reconstruction)

高质量图像重建模块是卷积和上采样操作的组合,论文种提出4种结构。

(1)经典超分(卷积 + pixelshuffle 上采样 + 卷积);

(2)轻量超分(卷积 + pixelshuffle 上采样);

(3)真实图像超分(卷积 + 卷积插值上采样 + 卷积插值上采样 + 卷积);

(4)图像去噪和JPEG压缩去伪影(卷积 + 引入残差)。

在这里主要看一下两种上采样操作:

(1)先卷积再使用pixelshuffle上采样,特征图维度不是3

  1. # 上采样
  2. class Upsample(nn.Sequential):
  3. """
  4. 输入:
  5. scale (int): 缩放因子,支持 2^n and 3.
  6. num_feat (int): 中间特征的通道数.
  7. """
  8. def __init__(self, scale, num_feat):
  9. m = []
  10. if (scale & (scale - 1)) == 0: # 缩放因子等于 2^n
  11. for _ in range(int(math.log(scale, 2))): # 循环 n 次
  12. m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) # 卷积层
  13. m.append(nn.PixelShuffle(2)) # pixelshuffle 上采样 2 倍
  14. elif scale == 3: # 缩放因子等于 3
  15. m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) # 卷积层
  16. m.append(nn.PixelShuffle(3)) # pixelshuffle 上采样 3 倍
  17. else:
  18. # 报错,缩放因子不对
  19. raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
  20. super(Upsample, self).__init__(*m)

(2)一步既实现上采样也实现输出降维,特征图维度是3,即最后的恢复图像

  1. # 一步实现既上采样也降维
  2. class UpsampleOneStep(nn.Sequential):
  3. """一步上采样与前边上采样模块不同之处在于该模块只有一个卷积层和一个 pixelshuffle 层
  4. 输入:
  5. scale (int): 缩放因子,支持 2^n and 3.
  6. num_feat (int): 中间特征的通道数.
  7. """
  8. def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
  9. self.num_feat = num_feat # 中间特征的通道数
  10. self.input_resolution = input_resolution # 输入分辨率
  11. m = []
  12. m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) # 卷积层
  13. m.append(nn.PixelShuffle(scale)) # pixelshuffle 上采样 scale 倍
  14. super(UpsampleOneStep, self).__init__(*m)
7、测试实例

虽然SwinIR 的整体参数不大,但是计算负担比较大。

  1. upscale = 4 # 图像放大因子
  2. window_size = 8 # 窗口大小
  3. height = (1024 // upscale // window_size + 1) * window_size # 输入图像的高
  4. width = (720 // upscale // window_size + 1) * window_size # 输入图像的宽
  5. # 实例化 SWinIR
  6. model = SwinIR(upscale=2, img_size=(height, width),
  7. window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
  8. embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
  9. print(model) # 打印网络结构
  10. print(height, width) # 打印输入图像的高和宽
  11. x = torch.randn((1, 3, height, width)) #随机生成输入图像
  12. x = model(x) # 送入网络
  13. print(x.shape) # 打印网络输入的图像结构

总结

论文提出RSTB进行深度特征提取,每个RSTB由Swin变换层、卷积层和残差连接组成。大量实验表明,SwinIR在经典图像复原、轻量级图像复原、真实图像复原、灰度图像去噪、彩色图像去噪和JPEG压缩伪影减少等6种不同设置下,均取得了令人满意的效果,从而验证了SwinIR的有效性和普适性。在未来,还将把该模型扩展到其他恢复任务,如图像去模糊和去噪。

IPT仿照VIT,把Transformer运用到了图像处理任务种。Transformer在视觉领域魔改至今,Swin Transformer当属其中最优、运用最多的变体。因此SwinIR进一步把Swin Transformer中的block搬到了图像处理任务中,模型则仍然遵循目前超分网络中的head+body+tail的通用结构,改进相对较小。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/238604
推荐阅读
相关标签
  

闽ICP备14008679号