当前位置:   article > 正文

U2Net网络结构搭建_do not use torch.sigmoid for amp safe

do not use torch.sigmoid for amp safe

论文题目:U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection

论文链接:https://arxiv.org/abs/2005.09007

一、RSU-7模块

  1. class RSU(nn.Module):
  2. def __init__(self, height: int, in_ch: int, mid_ch: int, out_ch: int):
  3. super().__init__()
  4. assert height >= 2 #断言
  5. self.conv_in = ConvBNReLU(in_ch, out_ch) # 第一个ConvBNReLU
  6. # 构建列表encode_list
  7. encode_list = [DownConvBNReLU(out_ch, mid_ch, flag=False)] #第一个encode的ConvBNReLU,没有下采样
  8. # 构建列表decode_list
  9. decode_list = [UpConvBNReLU(mid_ch * 2, mid_ch, flag=False)] #第一个decode的ConvBNReLU,没有上采样
  10. for i in range(height - 2):
  11. encode_list.append(DownConvBNReLU(mid_ch, mid_ch)) #encode中有下采样的ConvBNReLU
  12. #decode中有上采样的ConvBNReLU,最后一个ConvBNReLU的输出通道数为out_ch,输入通道为mid_ch * 2是因为concat
  13. decode_list.append(UpConvBNReLU(mid_ch * 2, mid_ch if i < height - 3 else out_ch))
  14. encode_list.append(ConvBNReLU(mid_ch, mid_ch, dilation=2)) #最下面一个含膨胀卷积 encode的ConvBNReLU,将其添加进列表encode_list
  15. self.encode_modules = nn.ModuleList(encode_list)
  16. self.decode_modules = nn.ModuleList(decode_list)
  17. def forward(self, x: torch.Tensor) -> torch.Tensor:
  18. x_in = self.conv_in(x)
  19. x = x_in
  20. encode_outputs = []
  21. for m in self.encode_modules:
  22. x = m(x)
  23. encode_outputs.append(x) # 依次通过encode_modules,并把每一层的输出添加进列表encode_outputs
  24. x = encode_outputs.pop() # 最后一个encode的ConvBNReLU输出,作为decode的输入
  25. for m in self.decode_modules:
  26. x2 = encode_outputs.pop() #依次弹出
  27. x = m(x, x2) # 拼接后在传入decode_modules
  28. return x + x_in #最后的输出与第一个ConvBNReLU的结果进行add操作

 

二、RSU-4F模块

  1. class RSU4F(nn.Module):
  2. def __init__(self, in_ch: int, mid_ch: int, out_ch: int):
  3. super().__init__()
  4. self.conv_in = ConvBNReLU(in_ch, out_ch)
  5. # encode模块,包含四个ConvBNReLU
  6. self.encode_modules = nn.ModuleList([ConvBNReLU(out_ch, mid_ch),
  7. ConvBNReLU(mid_ch, mid_ch, dilation=2),
  8. ConvBNReLU(mid_ch, mid_ch, dilation=4),
  9. ConvBNReLU(mid_ch, mid_ch, dilation=8)])
  10. # decode模块,包含三个ConvBNReLU
  11. self.decode_modules = nn.ModuleList([ConvBNReLU(mid_ch * 2, mid_ch, dilation=4),
  12. ConvBNReLU(mid_ch * 2, mid_ch, dilation=2),
  13. ConvBNReLU(mid_ch * 2, out_ch)])
  14. def forward(self, x: torch.Tensor) -> torch.Tensor:
  15. x_in = self.conv_in(x)
  16. x = x_in
  17. encode_outputs = []
  18. for m in self.encode_modules:
  19. x = m(x)
  20. encode_outputs.append(x) # 依次通过encode_modules,并把每一层的输出添加进列表encode_outputs
  21. x = encode_outputs.pop() # 最后一个encode的ConvBNReLU输出,作为decode的输入
  22. for m in self.decode_modules:
  23. x2 = encode_outputs.pop()
  24. # 没有用到UpConvBNReLU模块,故需要torch.cat([x, x2]
  25. x = m(torch.cat([x, x2], dim=1))
  26. return x + x_in #最后的输出与第一个ConvBNReLU的结果进行add操作

 三、U2Net整体结构

  1. from typing import Union, List
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. class ConvBNReLU(nn.Module):
  6. def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3, dilation: int = 1): #dilation=1代表普通卷积,大于1代表膨胀卷积
  7. super().__init__()
  8. padding = kernel_size // 2 if dilation == 1 else dilation
  9. self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding, dilation=dilation, bias=False)
  10. self.bn = nn.BatchNorm2d(out_ch)
  11. self.relu = nn.ReLU(inplace=True)
  12. def forward(self, x: torch.Tensor) -> torch.Tensor:
  13. return self.relu(self.bn(self.conv(x)))
  14. class DownConvBNReLU(ConvBNReLU): #继承自ConvBNReLU
  15. def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3, dilation: int = 1, flag: bool = True): #flag表示是否启用下采样
  16. super().__init__(in_ch, out_ch, kernel_size, dilation) #调用父类的方法
  17. self.down_flag = flag
  18. def forward(self, x: torch.Tensor) -> torch.Tensor:
  19. if self.down_flag: #如果flag为True,则启用下采样
  20. x = F.max_pool2d(x, kernel_size=2, stride=2, ceil_mode=True)
  21. return self.relu(self.bn(self.conv(x)))
  22. class UpConvBNReLU(ConvBNReLU): #继承自ConvBNReLU
  23. def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3, dilation: int = 1, flag: bool = True): #flag表示是否启用上采样
  24. super().__init__(in_ch, out_ch, kernel_size, dilation)
  25. self.up_flag = flag
  26. def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
  27. if self.up_flag: #如果flag为True,则启用上采样
  28. x1 = F.interpolate(x1, size=x2.shape[2:], mode='bilinear', align_corners=False) #x1通过双线性插值上采用到x2的宽高
  29. return self.relu(self.bn(self.conv(torch.cat([x1, x2], dim=1)))) #x1, x2进行拼接,再ConvBNReLU
  30. class RSU(nn.Module):
  31. def __init__(self, height: int, in_ch: int, mid_ch: int, out_ch: int):
  32. super().__init__()
  33. assert height >= 2 #断言
  34. self.conv_in = ConvBNReLU(in_ch, out_ch) # 第一个ConvBNReLU
  35. # 构建列表encode_list
  36. encode_list = [DownConvBNReLU(out_ch, mid_ch, flag=False)] #第一个encode的ConvBNReLU,没有下采样
  37. # 构建列表decode_list
  38. decode_list = [UpConvBNReLU(mid_ch * 2, mid_ch, flag=False)] #第一个decode的ConvBNReLU,没有上采样
  39. for i in range(height - 2):
  40. encode_list.append(DownConvBNReLU(mid_ch, mid_ch)) #encode中有下采样的ConvBNReLU
  41. #decode中有上采样的ConvBNReLU,最后一个ConvBNReLU的输出通道数为out_ch,输入通道为mid_ch * 2是因为concat
  42. decode_list.append(UpConvBNReLU(mid_ch * 2, mid_ch if i < height - 3 else out_ch))
  43. encode_list.append(ConvBNReLU(mid_ch, mid_ch, dilation=2)) #最下面一个含膨胀卷积 encode的ConvBNReLU,将其添加进列表encode_list
  44. self.encode_modules = nn.ModuleList(encode_list)
  45. self.decode_modules = nn.ModuleList(decode_list)
  46. def forward(self, x: torch.Tensor) -> torch.Tensor:
  47. x_in = self.conv_in(x)
  48. x = x_in
  49. encode_outputs = []
  50. for m in self.encode_modules:
  51. x = m(x)
  52. encode_outputs.append(x) # 依次通过encode_modules,并把每一层的输出添加进列表encode_outputs
  53. x = encode_outputs.pop() # 最后一个encode的ConvBNReLU输出,作为decode的输入
  54. for m in self.decode_modules:
  55. x2 = encode_outputs.pop() #依次弹出
  56. # UpConvBNReLU已经包含torch.cat([x1, x2]
  57. x = m(x, x2) # 拼接后在传入decode_modules
  58. return x + x_in #最后的输出与第一个ConvBNReLU的结果进行add操作
  59. class RSU4F(nn.Module):
  60. def __init__(self, in_ch: int, mid_ch: int, out_ch: int):
  61. super().__init__()
  62. self.conv_in = ConvBNReLU(in_ch, out_ch)
  63. # encode模块,包含四个ConvBNReLU
  64. self.encode_modules = nn.ModuleList([ConvBNReLU(out_ch, mid_ch),
  65. ConvBNReLU(mid_ch, mid_ch, dilation=2),
  66. ConvBNReLU(mid_ch, mid_ch, dilation=4),
  67. ConvBNReLU(mid_ch, mid_ch, dilation=8)])
  68. # decode模块,包含三个ConvBNReLU
  69. self.decode_modules = nn.ModuleList([ConvBNReLU(mid_ch * 2, mid_ch, dilation=4),
  70. ConvBNReLU(mid_ch * 2, mid_ch, dilation=2),
  71. ConvBNReLU(mid_ch * 2, out_ch)])
  72. def forward(self, x: torch.Tensor) -> torch.Tensor:
  73. x_in = self.conv_in(x)
  74. x = x_in
  75. encode_outputs = []
  76. for m in self.encode_modules:
  77. x = m(x)
  78. encode_outputs.append(x) # 依次通过encode_modules,并把每一层的输出添加进列表encode_outputs
  79. x = encode_outputs.pop() # 最后一个encode的ConvBNReLU输出,作为decode的输入
  80. for m in self.decode_modules:
  81. x2 = encode_outputs.pop()
  82. # 没有用到UpConvBNReLU模块,故需要torch.cat([x, x2]
  83. x = m(torch.cat([x, x2], dim=1))
  84. return x + x_in #最后的输出与第一个ConvBNReLU的结果进行add操作
  85. class U2Net(nn.Module):
  86. def __init__(self, cfg: dict, out_ch: int = 1): #此处的out_ch:是为1的
  87. super().__init__()
  88. assert "encode" in cfg
  89. assert "decode" in cfg
  90. self.encode_num = len(cfg["encode"]) #6
  91. encode_list = [] # 存储每个encode模块的输出
  92. side_list = [] # 收集每个decode和最后一个encode的输出
  93. for c in cfg["encode"]: #遍历encode
  94. # c: [height, in_ch, mid_ch, out_ch, RSU4F, side]
  95. assert len(c) == 6 # c中有六个参数
  96. # 如果是RSU,则传入height, in_ch, mid_ch, out_ch 如果是RSU4F,则传入in_ch, mid_ch, out_ch
  97. encode_list.append(RSU(*c[:4]) if c[4] is False else RSU4F(*c[1:4]))
  98. if c[5] is True: #有1个
  99. side_list.append(nn.Conv2d(c[3], out_ch, kernel_size=3, padding=1)) #每层的输出后接一个3*3的卷积,此处的out_ch:是为1的
  100. self.encode_modules = nn.ModuleList(encode_list)
  101. decode_list = []
  102. for c in cfg["decode"]: #遍历decode
  103. # c: [height, in_ch, mid_ch, out_ch, RSU4F, side]
  104. assert len(c) == 6
  105. # 如果是RSU,则传入height, in_ch, mid_ch, out_ch 如果是RSU4F,则传入in_ch, mid_ch, out_ch
  106. decode_list.append(RSU(*c[:4]) if c[4] is False else RSU4F(*c[1:4]))
  107. if c[5] is True: #有5个
  108. side_list.append(nn.Conv2d(c[3], out_ch, kernel_size=3, padding=1)) #每层的输出后接一个3*3的卷积,此处的out_ch:是为1的
  109. self.decode_modules = nn.ModuleList(decode_list)
  110. self.side_modules = nn.ModuleList(side_list) # encode和decode并在一起的
  111. self.out_conv = nn.Conv2d(self.encode_num * out_ch, out_ch, kernel_size=1) #concat之后跟的1*1卷积,输入通道为6,输出通道为1
  112. def forward(self, x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
  113. _, _, h, w = x.shape #获取输入图片的高宽
  114. # collect encode outputs
  115. encode_outputs = [] # 收集每个encode的输出
  116. for i, m in enumerate(self.encode_modules):
  117. x = m(x)
  118. encode_outputs.append(x) #放入encode_outputs列表中的是下采样之前的
  119. if i != self.encode_num - 1: #前五个encode层后都会进行下采样,最后一个encode后不会进行下采样
  120. x = F.max_pool2d(x, kernel_size=2, stride=2, ceil_mode=True) #下采样之后再输入下一层
  121. # collect decode outputs
  122. x = encode_outputs.pop()
  123. decode_outputs = [x] #先收集En_6的输出
  124. for m in self.decode_modules:
  125. x2 = encode_outputs.pop() #弹出En_5的输出
  126. x = F.interpolate(x, size=x2.shape[2:], mode='bilinear', align_corners=False) #最后一个encode的输出x,通过双线性插值上采用到x2的宽高
  127. x = m(torch.concat([x, x2], dim=1))
  128. decode_outputs.insert(0, x) # 将x插入在列表decode_outputs的最前面
  129. #最终列表decode_outputs = [De_1,De_2,De_3,De_4,De_5,En_6]
  130. # collect side outputs
  131. side_outputs = [] #用于存储通过3*3卷积之后的输出
  132. #side_modules构建的顺序是从En_6,De_5,De_4,De_3,De_2,De_1
  133. for m in self.side_modules:
  134. x = decode_outputs.pop()
  135. x = F.interpolate(m(x), size=[h, w], mode='bilinear', align_corners=False) #经过3*卷积核,通过双线性插值将其还原回输入图片的h, w
  136. side_outputs.insert(0, x) # 将x插入在列表side_outputs的最前面
  137. # 最终列表side_outputs = [Sup1,Sup2,Sup3,Sup4,Sup5,Sup6]
  138. x = self.out_conv(torch.concat(side_outputs, dim=1)) #六个特征图经concat拼接之后,在经过1*1卷积
  139. if self.training:
  140. # do not use torch.sigmoid for amp safe
  141. return [x] + side_outputs #训练模式,返回x和[Sup1,Sup2,Sup3,Sup4,Sup5,Sup6],在算损失的时候会用到
  142. else:
  143. return torch.sigmoid(x) #预测模式,直接sigmoid到 0~1
  144. def u2net_full(out_ch: int = 1):
  145. cfg = {
  146. # height为深度,side为是否收集其输出
  147. # height, in_ch, mid_ch, out_ch, RSU4F, side
  148. "encode": [[7, 3, 32, 64, False, False], # En1
  149. [6, 64, 32, 128, False, False], # En2
  150. [5, 128, 64, 256, False, False], # En3
  151. [4, 256, 128, 512, False, False], # En4
  152. [4, 512, 256, 512, True, False], # En5
  153. [4, 512, 256, 512, True, True]], # En6
  154. # height, in_ch, mid_ch, out_ch, RSU4F, side
  155. "decode": [[4, 1024, 256, 512, True, True], # De5
  156. [4, 1024, 128, 256, False, True], # De4
  157. [5, 512, 64, 128, False, True], # De3
  158. [6, 256, 32, 64, False, True], # De2
  159. [7, 128, 16, 64, False, True]] # De1
  160. }
  161. return U2Net(cfg, out_ch)
  162. def u2net_lite(out_ch: int = 1):
  163. cfg = {
  164. # height, in_ch, mid_ch, out_ch, RSU4F, side
  165. "encode": [[7, 3, 16, 64, False, False], # En1
  166. [6, 64, 16, 64, False, False], # En2
  167. [5, 64, 16, 64, False, False], # En3
  168. [4, 64, 16, 64, False, False], # En4
  169. [4, 64, 16, 64, True, False], # En5
  170. [4, 64, 16, 64, True, True]], # En6
  171. # height, in_ch, mid_ch, out_ch, RSU4F, side
  172. "decode": [[4, 128, 16, 64, True, True], # De5
  173. [4, 128, 16, 64, False, True], # De4
  174. [5, 128, 16, 64, False, True], # De3
  175. [6, 128, 16, 64, False, True], # De2
  176. [7, 128, 16, 64, False, True]] # De1
  177. }
  178. return U2Net(cfg, out_ch)
  179. def convert_onnx(m, save_path):
  180. m.eval()
  181. x = torch.rand(1, 3, 288, 288, requires_grad=True)
  182. # export the model
  183. torch.onnx.export(m, # model being run
  184. x, # model input (or a tuple for multiple inputs)
  185. save_path, # where to save the model (can be a file or file-like object)
  186. export_params=True,
  187. opset_version=11)
  188. if __name__ == '__main__':
  189. # n_m = RSU(height=7, in_ch=3, mid_ch=12, out_ch=3)
  190. # convert_onnx(n_m, "RSU7.onnx")
  191. #
  192. # n_m = RSU4F(in_ch=3, mid_ch=12, out_ch=3)
  193. # convert_onnx(n_m, "RSU4F.onnx")
  194. u2net = u2net_full()
  195. convert_onnx(u2net, "u2net_full.onnx")

bulalalalalala~~~作者代码写的太妙了,看完源码后真的很佩服,点赞点赞!!!

bulalalalalala~~~费了好大劲才搞定!!!

reference

【U2Net源码解析(Pytorch)】 https://www.bilibili.com/video/BV1Kt4y137iS?p=2&share_source=copy_web&vd_source=95705b32f23f70b32dfa1721628d5874

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/716006
推荐阅读
相关标签
  

闽ICP备14008679号