赞
踩
论文题目:U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection
论文链接:https://arxiv.org/abs/2005.09007
class RSU(nn.Module):
    """Residual U-block (RSU-L) from U^2-Net.

    An input ConvBNReLU maps ``in_ch`` -> ``out_ch``. A small U-Net of depth
    ``height`` then runs on that result:
    - the encoder downsamples with max-pooling and ends in a dilated bottom layer;
    - the decoder upsamples bilinearly and concatenates encoder skips;
    - it returns ``out_ch`` channels.
    The U-Net output is added to the input-conv output.

    Args:
        height: number of encoder layers; must be >= 2.
        in_ch: channels of the input tensor.
        mid_ch: channels of the inner U-Net layers.
        out_ch: channels of the output tensor.

    Raises:
        ValueError: if ``height < 2``.
    """

    def __init__(self, height: int, in_ch: int, mid_ch: int, out_ch: int):
        super().__init__()

        if height < 2:
            raise ValueError(f"height must be >= 2, got {height}")
        self.conv_in = ConvBNReLU(in_ch, out_ch)  # first ConvBNReLU, in_ch -> out_ch

        # First encoder layer: no downsampling.
        encode_list = [DownConvBNReLU(out_ch, mid_ch, flag=False)]
        # First decoder layer: no upsampling, because the dilated bottom layer and
        # the layer above it share a resolution. When height == 2 it is also the
        # LAST decoder layer, so it must emit out_ch channels for the residual
        # add in forward() (previously it always emitted mid_ch).
        decode_list = [UpConvBNReLU(mid_ch * 2, mid_ch if height > 2 else out_ch, flag=False)]
        # Modules are created interleaved (encoder, decoder, ...) exactly as
        # before, so seeded weight initialisation is unchanged.
        for i in range(height - 2):
            encode_list.append(DownConvBNReLU(mid_ch, mid_ch))  # downsampling encoder layer
            # in_ch is mid_ch * 2 because of the skip concatenation; the final
            # decoder layer produces out_ch channels.
            is_last = i == height - 3
            decode_list.append(UpConvBNReLU(mid_ch * 2, out_ch if is_last else mid_ch))

        # Bottom encoder layer: dilated convolution, no downsampling.
        encode_list.append(ConvBNReLU(mid_ch, mid_ch, dilation=2))
        self.encode_modules = nn.ModuleList(encode_list)
        self.decode_modules = nn.ModuleList(decode_list)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_in = self.conv_in(x)

        x = x_in
        encode_outputs = []
        for m in self.encode_modules:
            x = m(x)
            encode_outputs.append(x)  # keep every encoder output for the skips

        # The bottom output seeds the decoder; the remaining encoder outputs are
        # consumed deepest-first as skip connections.
        x = encode_outputs.pop()
        for m in self.decode_modules:
            x2 = encode_outputs.pop()
            x = m(x, x2)  # UpConvBNReLU resizes (if enabled), concatenates, convolves

        return x + x_in  # residual connection with the first ConvBNReLU output
class RSU4F(nn.Module):
    """Dilated residual U-block (RSU-4F) that never resamples.

    The four encoder layers use dilation rates 1, 2, 4 and 8 instead of
    pooling. The three decoder layers use dilations 4, 2 and 1. Each decoder
    layer consumes the channel concatenation of the previous output and the
    matching encoder output. The decoder result is added to the output of the
    input ConvBNReLU.

    Args:
        in_ch: channels of the input tensor.
        mid_ch: channels of the inner layers.
        out_ch: channels of the output tensor.
    """

    def __init__(self, in_ch: int, mid_ch: int, out_ch: int):
        super().__init__()
        self.conv_in = ConvBNReLU(in_ch, out_ch)

        # Encoder: one plain conv, then dilation 2, 4, 8 (built in that order).
        encoders = [ConvBNReLU(out_ch, mid_ch)]
        for rate in (2, 4, 8):
            encoders.append(ConvBNReLU(mid_ch, mid_ch, dilation=rate))
        self.encode_modules = nn.ModuleList(encoders)

        # Decoder: inputs are concatenations, hence mid_ch * 2 channels.
        decoders = [ConvBNReLU(mid_ch * 2, mid_ch, dilation=rate) for rate in (4, 2)]
        decoders.append(ConvBNReLU(mid_ch * 2, out_ch))
        self.decode_modules = nn.ModuleList(decoders)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_in = self.conv_in(x)

        skips = []
        feat = x_in
        for layer in self.encode_modules:
            feat = layer(feat)
            skips.append(feat)

        # Deepest encoder output seeds the decoder.
        feat = skips.pop()
        for layer in self.decode_modules:
            # Plain ConvBNReLU does not concatenate, so the skip is joined here.
            feat = layer(torch.cat([feat, skips.pop()], dim=1))

        return feat + x_in
- from typing import Union, List
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
-
class ConvBNReLU(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        kernel_size: square convolution kernel size.
        dilation: 1 for an ordinary convolution, > 1 for a dilated convolution.

    With stride 1 and an odd ``kernel_size``, the padding keeps the output
    spatial size equal to the input size, for any dilation.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3, dilation: int = 1):
        super().__init__()

        if dilation == 1:
            # Unchanged from the original; equals (kernel_size - 1) // 2 for odd kernels.
            padding = kernel_size // 2
        else:
            # A dilated kernel spans dilation * (kernel_size - 1) + 1 pixels, so
            # "same" padding is dilation * (kernel_size - 1) // 2. The previous
            # value, `dilation`, was only correct for kernel_size == 3.
            padding = dilation * (kernel_size - 1) // 2
        # bias=False because the following BatchNorm2d supplies its own shift.
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.bn(self.conv(x)))
-
-
class DownConvBNReLU(ConvBNReLU):
    """ConvBNReLU optionally preceded by a 2x2 max-pool (stride 2, ceil mode).

    Args:
        in_ch, out_ch, kernel_size, dilation: forwarded to ConvBNReLU.
        flag: if True (default), the input is max-pooled before the
            convolution; if False, the input is passed to ConvBNReLU unchanged.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3, dilation: int = 1, flag: bool = True):
        super().__init__(in_ch, out_ch, kernel_size, dilation)
        self.down_flag = flag

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pooled = F.max_pool2d(x, kernel_size=2, stride=2, ceil_mode=True) if self.down_flag else x
        return super().forward(pooled)
-
-
class UpConvBNReLU(ConvBNReLU):
    """ConvBNReLU applied to the channel concatenation ``[x1, x2]``.

    If ``flag`` is True (default), ``x1`` is first bilinearly resized to the
    spatial size of ``x2``. ``in_ch`` must therefore equal the channel count of
    ``x1`` plus that of ``x2``.

    Args:
        in_ch, out_ch, kernel_size, dilation: forwarded to ConvBNReLU.
        flag: whether to resize ``x1`` before concatenating.
    """

    def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3, dilation: int = 1, flag: bool = True):
        super().__init__(in_ch, out_ch, kernel_size, dilation)
        self.up_flag = flag

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        if not self.up_flag:
            merged = torch.cat([x1, x2], dim=1)
        else:
            resized = F.interpolate(x1, size=x2.shape[2:], mode='bilinear', align_corners=False)
            merged = torch.cat([resized, x2], dim=1)
        return super().forward(merged)
-
-
class RSU(nn.Module):
    """Residual U-block (RSU-L) from U^2-Net.

    An input ConvBNReLU maps ``in_ch`` -> ``out_ch``. A small U-Net of depth
    ``height`` then runs on that result:
    - the encoder downsamples with max-pooling and ends in a dilated bottom layer;
    - the decoder upsamples bilinearly and concatenates encoder skips;
    - it returns ``out_ch`` channels.
    The U-Net output is added to the input-conv output.

    Args:
        height: number of encoder layers; must be >= 2.
        in_ch: channels of the input tensor.
        mid_ch: channels of the inner U-Net layers.
        out_ch: channels of the output tensor.

    Raises:
        ValueError: if ``height < 2``.
    """

    def __init__(self, height: int, in_ch: int, mid_ch: int, out_ch: int):
        super().__init__()

        if height < 2:
            raise ValueError(f"height must be >= 2, got {height}")
        self.conv_in = ConvBNReLU(in_ch, out_ch)  # first ConvBNReLU, in_ch -> out_ch

        # First encoder layer: no downsampling.
        encode_list = [DownConvBNReLU(out_ch, mid_ch, flag=False)]
        # First decoder layer: no upsampling, because the dilated bottom layer and
        # the layer above it share a resolution. When height == 2 it is also the
        # LAST decoder layer, so it must emit out_ch channels for the residual
        # add in forward() (previously it always emitted mid_ch).
        decode_list = [UpConvBNReLU(mid_ch * 2, mid_ch if height > 2 else out_ch, flag=False)]
        # Modules are created interleaved (encoder, decoder, ...) exactly as
        # before, so seeded weight initialisation is unchanged.
        for i in range(height - 2):
            encode_list.append(DownConvBNReLU(mid_ch, mid_ch))  # downsampling encoder layer
            # in_ch is mid_ch * 2 because of the skip concatenation; the final
            # decoder layer produces out_ch channels.
            is_last = i == height - 3
            decode_list.append(UpConvBNReLU(mid_ch * 2, out_ch if is_last else mid_ch))

        # Bottom encoder layer: dilated convolution, no downsampling.
        encode_list.append(ConvBNReLU(mid_ch, mid_ch, dilation=2))
        self.encode_modules = nn.ModuleList(encode_list)
        self.decode_modules = nn.ModuleList(decode_list)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_in = self.conv_in(x)

        x = x_in
        encode_outputs = []
        for m in self.encode_modules:
            x = m(x)
            encode_outputs.append(x)  # keep every encoder output for the skips

        # The bottom output seeds the decoder; the remaining encoder outputs are
        # consumed deepest-first as skip connections.
        x = encode_outputs.pop()
        for m in self.decode_modules:
            x2 = encode_outputs.pop()
            x = m(x, x2)  # UpConvBNReLU resizes (if enabled), concatenates, convolves

        return x + x_in  # residual connection with the first ConvBNReLU output
-
-
class RSU4F(nn.Module):
    """Dilated residual U-block (RSU-4F) that never resamples.

    The four encoder layers use dilation rates 1, 2, 4 and 8 instead of
    pooling. The three decoder layers use dilations 4, 2 and 1. Each decoder
    layer consumes the channel concatenation of the previous output and the
    matching encoder output. The decoder result is added to the output of the
    input ConvBNReLU.

    Args:
        in_ch: channels of the input tensor.
        mid_ch: channels of the inner layers.
        out_ch: channels of the output tensor.
    """

    def __init__(self, in_ch: int, mid_ch: int, out_ch: int):
        super().__init__()
        self.conv_in = ConvBNReLU(in_ch, out_ch)

        # Encoder: one plain conv, then dilation 2, 4, 8 (built in that order).
        encoders = [ConvBNReLU(out_ch, mid_ch)]
        for rate in (2, 4, 8):
            encoders.append(ConvBNReLU(mid_ch, mid_ch, dilation=rate))
        self.encode_modules = nn.ModuleList(encoders)

        # Decoder: inputs are concatenations, hence mid_ch * 2 channels.
        decoders = [ConvBNReLU(mid_ch * 2, mid_ch, dilation=rate) for rate in (4, 2)]
        decoders.append(ConvBNReLU(mid_ch * 2, out_ch))
        self.decode_modules = nn.ModuleList(decoders)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_in = self.conv_in(x)

        skips = []
        feat = x_in
        for layer in self.encode_modules:
            feat = layer(feat)
            skips.append(feat)

        # Deepest encoder output seeds the decoder.
        feat = skips.pop()
        for layer in self.decode_modules:
            # Plain ConvBNReLU does not concatenate, so the skip is joined here.
            feat = layer(torch.cat([feat, skips.pop()], dim=1))

        return feat + x_in
-
-
class U2Net(nn.Module):
    """U^2-Net: a U-Net whose stages are themselves residual U-blocks.

    Args:
        cfg: dict with keys ``"encode"`` and ``"decode"``. Each value is a list
            of stage specs ``[height, in_ch, mid_ch, out_ch, RSU4F, side]``:
            - a truthy ``RSU4F`` builds ``RSU4F(in_ch, mid_ch, out_ch)``;
            - otherwise the stage is ``RSU(height, in_ch, mid_ch, out_ch)``;
            - a truthy ``side`` adds a 3x3 side-output conv for that stage.
            Decoder specs are consumed deepest-first: the first decoder is
            paired with the second-to-last encoder output.
        out_ch: channels of each side output and of the fused output.

    forward() behaves differently by mode:
        - training: returns ``[fused, side_1, ..., side_N]``, all raw logits
          resized to the input height/width;
        - eval: returns ``sigmoid(fused)``.

    Raises:
        ValueError: if ``cfg`` lacks a key or a stage spec does not have 6 entries.

    NOTE(review): forward() pairs side convs with stage outputs by position.
    This is only correct when side outputs are enabled on exactly the last
    encoder and every decoder, as in u2net_full / u2net_lite.
    """

    def __init__(self, cfg: dict, out_ch: int = 1):
        super().__init__()
        for key in ("encode", "decode"):
            if key not in cfg:
                raise ValueError(f"cfg is missing the '{key}' entry")
        self.encode_num = len(cfg["encode"])

        # Side convs from encoder and decoder stages, in construction order.
        side_list = []
        self.encode_modules = self._build_stages(cfg["encode"], out_ch, side_list)
        self.decode_modules = self._build_stages(cfg["decode"], out_ch, side_list)
        self.side_modules = nn.ModuleList(side_list)
        # 1x1 fusion conv over the concatenated side outputs. Its input width
        # follows the actual number of side convs instead of assuming one per
        # encoder stage (identical for the shipped configs).
        self.out_conv = nn.Conv2d(len(side_list) * out_ch, out_ch, kernel_size=1)

    @staticmethod
    def _build_stages(stages, out_ch: int, side_list: list) -> nn.ModuleList:
        """Build one RSU/RSU4F per spec, appending side convs to ``side_list``."""
        modules = []
        for c in stages:
            # c: [height, in_ch, mid_ch, out_ch, RSU4F, side]
            if len(c) != 6:
                raise ValueError(f"stage spec must have 6 entries, got {len(c)}: {c!r}")
            # Truthiness rather than `is False` / `is True`, so 0/1 or numpy
            # booleans in the config select the right module.
            modules.append(RSU4F(*c[1:4]) if c[4] else RSU(*c[:4]))
            if c[5]:
                # 3x3 side conv mapping the stage output to out_ch channels.
                side_list.append(nn.Conv2d(c[3], out_ch, kernel_size=3, padding=1))
        return nn.ModuleList(modules)

    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
        _, _, h, w = x.shape  # input height/width, used to resize side outputs

        # Encoder: store each stage output (before pooling). Every stage except
        # the last is followed by a 2x2 max-pool.
        encode_outputs = []
        for i, m in enumerate(self.encode_modules):
            x = m(x)
            encode_outputs.append(x)
            if i != self.encode_num - 1:
                x = F.max_pool2d(x, kernel_size=2, stride=2, ceil_mode=True)

        # Decoder: start from the deepest encoder output. Each stage upsamples
        # to the next skip's size, concatenates the skip and applies the stage.
        x = encode_outputs.pop()
        decode_outputs = [x]
        for m in self.decode_modules:
            x2 = encode_outputs.pop()
            x = F.interpolate(x, size=x2.shape[2:], mode='bilinear', align_corners=False)
            x = m(torch.cat([x, x2], dim=1))
            decode_outputs.insert(0, x)
        # decode_outputs is now shallowest-first, e.g. [De1, ..., De5, En6].

        # Side outputs: pop stage outputs deepest-first and pair each with the
        # next side conv. With the shipped configs the side convs are ordered
        # En6, De5, ..., De1. Results are resized to (h, w) and stored
        # shallowest-first.
        side_outputs = []
        for m in self.side_modules:
            x = decode_outputs.pop()
            x = F.interpolate(m(x), size=[h, w], mode='bilinear', align_corners=False)
            side_outputs.insert(0, x)

        x = self.out_conv(torch.cat(side_outputs, dim=1))

        if self.training:
            # do not use torch.sigmoid for amp safe: raw logits are returned and
            # the loss presumably applies a with-logits BCE (confirm in training code).
            return [x] + side_outputs
        return torch.sigmoid(x)  # inference: probabilities in [0, 1]
-
-
def u2net_full(out_ch: int = 1):
    """Build the full-size U^2-Net configuration.

    Each stage row is [height, in_ch, mid_ch, out_ch, use_RSU4F, has_side_output].
    """
    encoder_stages = [
        [7, 3, 32, 64, False, False],      # En1
        [6, 64, 32, 128, False, False],    # En2
        [5, 128, 64, 256, False, False],   # En3
        [4, 256, 128, 512, False, False],  # En4
        [4, 512, 256, 512, True, False],   # En5
        [4, 512, 256, 512, True, True],    # En6
    ]
    decoder_stages = [
        [4, 1024, 256, 512, True, True],   # De5
        [4, 1024, 128, 256, False, True],  # De4
        [5, 512, 64, 128, False, True],    # De3
        [6, 256, 32, 64, False, True],     # De2
        [7, 128, 16, 64, False, True],     # De1
    ]
    return U2Net({"encode": encoder_stages, "decode": decoder_stages}, out_ch)
-
-
def u2net_lite(out_ch: int = 1):
    """Build the lightweight U^2-Net configuration (narrow 16/64-channel stages).

    Each stage row is [height, in_ch, mid_ch, out_ch, use_RSU4F, has_side_output].
    """
    encoder_stages = [
        [7, 3, 16, 64, False, False],   # En1
        [6, 64, 16, 64, False, False],  # En2
        [5, 64, 16, 64, False, False],  # En3
        [4, 64, 16, 64, False, False],  # En4
        [4, 64, 16, 64, True, False],   # En5
        [4, 64, 16, 64, True, True],    # En6
    ]
    decoder_stages = [
        [4, 128, 16, 64, True, True],   # De5
        [4, 128, 16, 64, False, True],  # De4
        [5, 128, 16, 64, False, True],  # De3
        [6, 128, 16, 64, False, True],  # De2
        [7, 128, 16, 64, False, True],  # De1
    ]
    return U2Net({"encode": encoder_stages, "decode": decoder_stages}, out_ch)
-
-
def convert_onnx(m, save_path, input_shape=(1, 3, 288, 288), opset_version: int = 11):
    """Export module ``m`` to ONNX.

    ``m`` is switched to eval mode (and left in it) and traced with a random
    dummy input of shape ``input_shape``.

    Args:
        m: the torch.nn.Module to export.
        save_path: destination; a file path or a file-like object.
        input_shape: shape of the dummy tracing input. The default matches the
            previously hard-coded 1x3x288x288.
        opset_version: ONNX opset to target (default 11, as before).
    """
    m.eval()
    # Gradients play no role in export, so the dummy input does not require them.
    x = torch.rand(*input_shape)

    torch.onnx.export(m,             # model being run
                      x,             # model input (or a tuple for multiple inputs)
                      save_path,     # where to save the model (file or file-like object)
                      export_params=True,
                      opset_version=opset_version)
-
-
if __name__ == '__main__':
    # Individual blocks can be exported the same way for inspection, e.g.:
    # convert_onnx(RSU(height=7, in_ch=3, mid_ch=12, out_ch=3), "RSU7.onnx")
    # convert_onnx(RSU4F(in_ch=3, mid_ch=12, out_ch=3), "RSU4F.onnx")
    convert_onnx(u2net_full(), "u2net_full.onnx")
bulalalalalala~~~作者的代码写得太妙了,看完源码后真的很佩服,点赞点赞!!!
bulalalalalala~~~费了好大劲才搞定!!!
【U2Net源码解析(Pytorch)】 https://www.bilibili.com/video/BV1Kt4y137iS?p=2&share_source=copy_web&vd_source=95705b32f23f70b32dfa1721628d5874
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。