赞
踩
参考:
[1] B站霹导
整个框架可以分为以下几个模块:
DoubleConv :
接收的参数包括 (in_c, out_c, mid_c = None)
; 首先需要判断是否有mid_c
,如果没有,则令 mid_c = out_c
,一般来说,下采样部分是没有mid_c
的,而上采样部分有mid_c
。
主要包括的层有:
(1) nn.Conv2d(in_c, mid_c, kernel_size=3, padding=1, bias=False)
(2) nn.BatchNorm2d(mid_c)
(3) nn.ReLU()
(4) nn.Conv2d(mid_c,out_c,kernel_size=3,padding=1,bias=False)
(5) nn.BatchNorm2d(out_c)
(6) nn.ReLU()
Down :
主要包括一个nn.MaxPool2d(kernel_size=2,stride=2)
,高宽缩小一半,深度不变;
Up :
使用双线性插值代替原论文中的转置卷积,并且DoubleConv
的参数为(in_c,out_c,in_c//2)
,上采样之后先进行concat操作,然后再通过doubleconv,定义和前向过程如下:
简洁版....
def __init__(self,in_c,out_c):
self.up = nn.Upsample(scale_factor=2, mode='bilinear',align_corners=True)
self.conv = DoubleConv(in_c,out_c,in_c//2)
def forward(self, x_1, x_2):
x_1 = self.up(x_1)
x_1 = torch.cat([x_1,x_2],dim=1) # 深度维拼接
x = self.conv(x_1)
return x
输出层 Out 使用 1×1 卷积,输入维为 in_c
,输出维为类别数# This Python file uses the following encoding: utf-8 import torch import torch. nn as nn class DoubleConv(nn.Sequential): def __init__(self, in_c, out_c, mid_c=None): if mid_c is None: mid_c = out_c super().__init__( nn.Conv2d(in_c,mid_c,kernel_size=3,padding=1,bias=False), nn.BatchNorm(mid_c), nn.ReLU(), nn.Conv2d(mid_c,out_c,kernel_size=3,padding=1,bias=False), nn.BatchNorm(out_c), nn.ReLU() ) class Down(nn.Sequential): def __init__(self,in_c,out_c): super().__init__( nn.MaxPool2d(kernel_size=2,stride=2), DoubleConv(in_c,out_c) ) class Up(nn.Module): def __init__(self,in_c,out_c): super().__init__() self.up = nn.UpSample(scale_factor=2,mode='bilinear',align_corner=True) self.conv = DoubleConv(in_c,out_c,in_c//2) def forward(self,x1,x2): x1 = self.up(x1) x1 = torch.cat([x1,x2],dim=1) x = self.conv(x1) return x class Out(nn.Module): def __init__(self,in_c,num_c): super().__init__() self.out = nn.Conv2d(in_c,num_c,kernel_size=1) def forward(self,x): return self.out(x) class UNet(nn.Module): def __init__(self,in_c,num_c,base_c=64): self.in_conv = DoubleConv(in_c,base_c) # 3->64 self.down1 = Down(base_c,base_c*2) # 64->128 self.down2 = Down(base_c*2,base_c*4) # 128->256 self.down3 = Down(base_c*4,base_c*8) # 256->512 self.dowm4 = Down(base_c*8,base_c*8) # 512 still self.up1 = Up(base_c*16,base_c*4) # 输入的维度为concat之后的维度,既1024 -> 512 -> 256 self.up2 = Up(base_c*8, base_c*2) # 512-> 256-> 128 self.up3 = Up(base_c*4, base_c*1) # 256->128->64 self.up4 = Up(base_c*2, base_c) # 128->64->64 self.out = Out(base_c,num_c) def forward(self,x): x1 = self.in_conv(x) x2 = self.down1(x1) x3 = self.down2(x2) x4 = self.down3(x3) x5 = self.down4(x4) x = self.up1(x5,x4) x = self.up2(x,x3) x = self.up3(x,x2) x = self.up4(x,x1) x = self.out(x) return x model = UNet(in_c=3,num_c=5,base_c=64) x = torch.rand(16,3,480,480) out = model(x) print(out.shape)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。