DexiNed consists of a Dense Extreme Inception network (Dexi) and upsampling blocks (UB). The upsampling block is the key component DexiNed uses for edge refinement: the output of every Dexi block is fed into a UB.
The Dexi backbone is an encoder made up of six main blocks, with a design that borrows from Xception. Each blue block in the architecture figure consists of two convolutional layers (3×3 kernel, followed by BN and ReLU), and max-pooling uses a 3×3 kernel with stride 2. The overall structure is multi-scale, and each scale has a corresponding upsampling path.
A UB is built from conditionally stacked sub-blocks. Each sub-block has two layers, one convolutional and one deconvolutional (transposed convolution); there are two types of sub-blocks.
MindSpore implementation:
"""DexiNed network architecture."""
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common import initializer as init
def weight_init(net):
for name, param in net.parameters_and_names():
if 'weight' in name:
param.set_data(
init.initializer(
init.XavierNormal(),
param.shape,
param.dtype))
if 'bias' in name:
param.set_data(init.initializer('zeros', param.shape, param.dtype))
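# CoFusion: attention-based fusion that turns the stacked side outputs into a single edge map
# (kept here for reference; the DexiNed cell below fuses with a plain 1x1 convolution instead).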
class CoFusion(nn.Cell):
def __init__(self, in_ch, out_ch):
super().__init__()
self.conv1 = nn.Conv2d(
in_ch, 64, kernel_size=3,
stride=1, padding=1, has_bias=True,
pad_mode="pad", weight_init=init.XavierNormal())
self.conv2 = nn.Conv2d(
64, 64, kernel_size=3,
stride=1, padding=1, has_bias=True,
pad_mode="pad", weight_init=init.XavierNormal())
self.conv3 = nn.Conv2d(
64, out_ch, kernel_size=3,
stride=1, padding=1, has_bias=True,
pad_mode="pad", weight_init=init.XavierNormal())
self.relu = nn.ReLU()
self.norm_layer1 = nn.GroupNorm(4, 64)
self.norm_layer2 = nn.GroupNorm(4, 64)
def construct(self, x):
attn = self.relu(self.norm_layer1(self.conv1(x)))
attn = self.relu(self.norm_layer2(self.conv2(attn)))
attn = ops.softmax(self.conv3(attn), axis=1)
return ((x * attn).sum(1)).expand_dims(1)
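# _DenseLayer: two 3x3 conv + BN stages; takes a (features, skip) pair and returns
# (0.5 * (new_features + skip), skip), so the skip tensor is carried through unchanged.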
class _DenseLayer(nn.Cell):
def __init__(self, input_features, out_features):
super(_DenseLayer, self).__init__()
self.conv1 = nn.Conv2d(
input_features, out_features, kernel_size=3,
stride=1, padding=2, pad_mode="pad",
has_bias=True, weight_init=init.XavierNormal())
self.norm1 = nn.BatchNorm2d(out_features)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2d(
out_features, out_features, kernel_size=3,
stride=1, pad_mode="pad", has_bias=True,
weight_init=init.XavierNormal())
self.norm2 = nn.BatchNorm2d(out_features)
self.relu = ops.ReLU()
def construct(self, x):
x1, x2 = x
x1 = self.conv1(self.relu(x1))
x1 = self.norm1(x1)
x1 = self.relu1(x1)
x1 = self.conv2(x1)
new_features = self.norm2(x1)
return 0.5 * (new_features + x2), x2
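# _DenseBlock: stacks two or three _DenseLayers and threads the same skip tensor through each.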
class _DenseBlock(nn.Cell):
def __init__(self, num_layers, input_features, out_features):
super(_DenseBlock, self).__init__()
self.denselayer1 = _DenseLayer(input_features, out_features)
input_features = out_features
self.denselayer2 = _DenseLayer(input_features, out_features)
if num_layers == 3:
self.denselayer3 = _DenseLayer(input_features, out_features)
self.layers = nn.SequentialCell(
[self.denselayer1, self.denselayer2, self.denselayer3])
else:
self.layers = nn.SequentialCell(
[self.denselayer1, self.denselayer2])
def construct(self, x):
x = self.layers(x)
return x
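# UpConvBlock (UB): refines a side output back to input resolution; each of the up_scale stages
# is a 1x1 conv + ReLU followed by a stride-2 transposed convolution, ending in a 1-channel map.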
class UpConvBlock(nn.Cell):
def __init__(self, in_features, up_scale):
super(UpConvBlock, self).__init__()
self.up_factor = 2
self.constant_features = 16
layers = self.make_deconv_layers(in_features, up_scale)
assert layers is not None, layers
self.features = nn.SequentialCell(*layers)
def make_deconv_layers(self, in_features, up_scale):
layers = []
all_pads = [0, 0, 1, 3, 7]
for i in range(up_scale):
kernel_size = 2 ** up_scale
pad = all_pads[up_scale] # kernel_size-1
out_features = self.compute_out_features(i, up_scale)
layers.append(nn.Conv2d(
in_features, out_features,
1, has_bias=True))
layers.append(nn.ReLU())
layers.append(nn.Conv2dTranspose(
out_features, out_features, kernel_size,
stride=2, padding=pad, pad_mode="pad",
has_bias=True, weight_init=init.XavierNormal()))
in_features = out_features
return layers
def compute_out_features(self, idx, up_scale):
return 1 if idx == up_scale - 1 else self.constant_features
def construct(self, x):
return self.features(x)
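# SingleConvBlock: 1x1 convolution (optionally followed by BatchNorm) used for the side
# and skip connections.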
class SingleConvBlock(nn.Cell):
def __init__(self, in_features, out_features, stride,
use_bs=True
):
super().__init__()
self.use_bn = use_bs
self.conv = nn.Conv2d(
in_features,
out_features,
1,
stride=stride,
pad_mode="pad",
has_bias=True,
weight_init=init.XavierNormal())
self.bn = nn.BatchNorm2d(out_features)
def construct(self, x):
x = self.conv(x)
if self.use_bn:
x = self.bn(x)
return x
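# DoubleConvBlock: two 3x3 conv + BN layers with ReLU activations; the final ReLU is optional.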
class DoubleConvBlock(nn.Cell):
def __init__(self, in_features, mid_features,
out_features=None,
stride=1,
use_act=True):
super(DoubleConvBlock, self).__init__()
self.use_act = use_act
if out_features is None:
out_features = mid_features
self.conv1 = nn.Conv2d(
in_features,
mid_features,
3,
padding=1,
stride=stride,
pad_mode="pad",
has_bias=True,
weight_init=init.XavierNormal())
self.bn1 = nn.BatchNorm2d(mid_features)
self.conv2 = nn.Conv2d(
mid_features,
out_features,
3,
padding=1,
pad_mode="pad",
has_bias=True,
weight_init=init.XavierNormal())
self.bn2 = nn.BatchNorm2d(out_features)
self.relu = nn.ReLU()
def construct(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
if self.use_act:
x = self.relu(x)
return x
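# maxpooling: 3x3 max-pool with stride 2; symmetric padding of 1 on each side halves the
# spatial size for even inputs.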
class maxpooling(nn.Cell):
def __init__(self):
super(maxpooling, self).__init__()
self.pad = nn.Pad(((0,0),(0,0),(1,1),(1,1)), mode="SYMMETRIC")
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid')
def construct(self, x):
x = self.pad(x)
x = self.maxpool(x)
return x
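# DexiNed: six encoder blocks connected by side and skip paths; every block output is
# upsampled by an UpConvBlock, and the six maps are concatenated and fused into a final edge map.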
class DexiNed(nn.Cell):
def __init__(self):
super(DexiNed, self).__init__()
self.block_1 = DoubleConvBlock(3, 32, 64, stride=2,)
self.block_2 = DoubleConvBlock(64, 128, use_act=False)
self.dblock_3 = _DenseBlock(2, 128, 256) # [128,256,100,100]
self.dblock_4 = _DenseBlock(3, 256, 512)
self.dblock_5 = _DenseBlock(3, 512, 512)
self.dblock_6 = _DenseBlock(3, 512, 256)
self.maxpool = maxpooling()
self.side_1 = SingleConvBlock(64, 128, 2)
self.side_2 = SingleConvBlock(128, 256, 2)
self.side_3 = SingleConvBlock(256, 512, 2)
self.side_4 = SingleConvBlock(512, 512, 1)
        self.side_5 = SingleConvBlock(512, 256, 1)  # defined but not used in construct below
        # right-side skip connections (see the figure in the journal paper)
self.pre_dense_2 = SingleConvBlock(128, 256, 2)
self.pre_dense_3 = SingleConvBlock(128, 256, 1)
self.pre_dense_4 = SingleConvBlock(256, 512, 1)
self.pre_dense_5 = SingleConvBlock(512, 512, 1)
self.pre_dense_6 = SingleConvBlock(512, 256, 1)
self.up_block_1 = UpConvBlock(64, 1)
self.up_block_2 = UpConvBlock(128, 1)
self.up_block_3 = UpConvBlock(256, 2)
self.up_block_4 = UpConvBlock(512, 3)
self.up_block_5 = UpConvBlock(512, 4)
self.up_block_6 = UpConvBlock(256, 4)
self.block_cat = SingleConvBlock(6, 1, stride=1, use_bs=False)
def slice(self, tensor, slice_shape):
t_shape = tensor.shape
height, width = slice_shape
if t_shape[-1] != slice_shape[-1]:
new_tensor = ops.interpolate(
tensor,
sizes=(height, width),
mode='bilinear',
coordinate_transformation_mode="half_pixel")
else:
new_tensor = tensor
return new_tensor
def construct(self, x):
assert x.ndim == 4, x.shape
# Block 1
block_1 = self.block_1(x)
block_1_side = self.side_1(block_1)
# Block 2
block_2 = self.block_2(block_1)
block_2_down = self.maxpool(block_2)
block_2_add = block_2_down + block_1_side
block_2_side = self.side_2(block_2_add)
# Block 3
block_3_pre_dense = self.pre_dense_3(block_2_down)
block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense])
block_3_down = self.maxpool(block_3) # [128,256,50,50]
block_3_add = block_3_down + block_2_side
block_3_side = self.side_3(block_3_add)
# Block 4
block_2_resize_half = self.pre_dense_2(block_2_down)
block_4_pre_dense = self.pre_dense_4(
block_3_down + block_2_resize_half)
block_4, _ = self.dblock_4([block_3_add, block_4_pre_dense])
block_4_down = self.maxpool(block_4)
block_4_add = block_4_down + block_3_side
block_4_side = self.side_4(block_4_add)
# Block 5
block_5_pre_dense = self.pre_dense_5(
block_4_down) # block_5_pre_dense_512 +block_4_down
block_5, _ = self.dblock_5([block_4_add, block_5_pre_dense])
block_5_add = block_5 + block_4_side
# Block 6
block_6_pre_dense = self.pre_dense_6(block_5)
block_6, _ = self.dblock_6([block_5_add, block_6_pre_dense])
# upsampling blocks
out_1 = self.up_block_1(block_1)
out_2 = self.up_block_2(block_2)
out_3 = self.up_block_3(block_3)
out_4 = self.up_block_4(block_4)
out_5 = self.up_block_5(block_5)
out_6 = self.up_block_6(block_6)
results = [out_1, out_2, out_3, out_4, out_5, out_6]
# concatenate multiscale outputs
op = ops.Concat(1)
block_cat = op(results)
block_cat = self.block_cat(block_cat) # Bx1xHxW
results.append(block_cat)
return results
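As a quick sanity check, the DexiNed cell above can be instantiated and run on a dummy input. The snippet below is a minimal sketch: the 1×3×352×352 input shape is only an assumed example (any spatial size divisible by 16 should work), and it relies on the imports at the top of the file.

import numpy as np
import mindspore as ms

# PyNative mode keeps the list of outputs simple to inspect.
ms.set_context(mode=ms.PYNATIVE_MODE)

net = DexiNed()
weight_init(net)
# Dummy NCHW input; spatial size is assumed to be divisible by 16.
x = ms.Tensor(np.random.rand(1, 3, 352, 352).astype(np.float32))
preds = net(x)
print(len(preds))      # 7: six side outputs plus the fused map
print(preds[0].shape)  # (1, 1, 352, 352)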