赞
踩
卷积神经网络被大规模的应用在分类任务中,输出的结果是整个图像的类标签。但是UNet是像素级分类,输出的则是每个像素点的类别,且不同类别的像素会显示不同颜色,UNet常常用在生物医学图像上,而该任务中图片数据往往较少。所以,Ciresan等人训练了一个卷积神经网络,用滑动窗口提供像素的周围区域(patch)作为输入来预测每个像素的类标签。这个网络有两个优点:(1)输出结果可以定位出目标类别的位置;(2)由于输入的训练数据是patches,这样就相当于进行了数据增强,从而解决了生物医学图像数量少的问题。
(1)UNet采用全卷积神经网络。
(2)左边网络为特征提取网络:使用conv和pooling
(3)右边网络为特征融合网络:使用上采样产生的特征图与左侧特征图进行concatenate操作。(pooling层会丢失图像信息和降低图像分辨率且是永久性的,对于图像分割任务有一些影响,对图像分类任务的影响不大,为什么要做上采样呢?上采样可以让包含高级抽象特征低分辨率图片在保留高级抽象特征的同时变为高分辨率,然后再与左边低级表层特征高分辨率图片进行concatenate操作)
(4)最后再经过两次卷积操作,生成特征图,再用两个卷积核大小为1*1的卷积做分类得到最后的两张heatmap,例如第一张表示第一类的得分,第二张表示第二类的得分heatmap,然后作为softmax函数的输入,算出概率比较大的softmax,然后再进行loss,反向传播计算。
def get_unet():
inputs = Input((img_rows, img_cols, 1))
conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
# pool1 = Dropout(0.25)(pool1)
# pool1 = BatchNormalization()(pool1)
conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
conv2 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv2)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
# pool2 = Dropout(0.5)(pool2)
# pool2 = BatchNormalization()(pool2)
conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool2)
conv3 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv3)
pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
# pool3 = Dropout(0.5)(pool3)
# pool3 = BatchNormalization()(pool3)
conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool3)
conv4 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv4)
pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)
# pool4 = Dropout(0.5)(pool4)
# pool4 = BatchNormalization()(pool4)
conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool4)
conv5 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv5)
up6 = concatenate([Conv2DTranspose(256, (2, 2), strides=(
2, 2), padding='same')(conv5), conv4], axis=3)
# up6 = Dropout(0.5)(up6)
# up6 = BatchNormalization()(up6)
conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(up6)
conv6 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv6)
up7 = concatenate([Conv2DTranspose(128, (2, 2), strides=(
2, 2), padding='same')(conv6), conv3], axis=3)
# up7 = Dropout(0.5)(up7)
# up7 = BatchNormalization()(up7)
conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(up7)
conv7 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv7)
up8 = concatenate([Conv2DTranspose(64, (2, 2), strides=(
2, 2), padding='same')(conv7), conv2], axis=3)
# up8 = Dropout(0.5)(up8)
# up8 = BatchNormalization()(up8)
conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(up8)
conv8 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv8)
up9 = concatenate([Conv2DTranspose(32, (2, 2), strides=(
2, 2), padding='same')(conv8), conv1], axis=3)
# up9 = Dropout(0.5)(up9)
# up9 = BatchNormalization()(up9)
conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(up9)
conv9 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv9)
# conv9 = Dropout(0.5)(conv9)
conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)
model = Model(inputs=[inputs], outputs=[conv10])
model.compile(optimizer=Adam(lr=1e-5),
loss=dice_coef_loss, metrics=[dice_coef])
return model
"""
这是根据UNet模型搭建出的一个基本网络结构
输入和输出大小是一样的,可以根据需求进行修改
"""
import torch
import torch.nn as nn
from torch.nn import functional as F
# 基本卷积块
class Conv(nn.Module):
def __init__(self, C_in, C_out):
super(Conv, self).__init__()
self.layer = nn.Sequential(
nn.Conv2d(C_in, C_out, 3, 1, 1),
nn.BatchNorm2d(C_out),
# 防止过拟合
nn.Dropout(0.3),
nn.LeakyReLU(),
nn.Conv2d(C_out, C_out, 3, 1, 1),
nn.BatchNorm2d(C_out),
# 防止过拟合
nn.Dropout(0.4),
nn.LeakyReLU(),
)
def forward(self, x):
return self.layer(x)
# 下采样模块
class DownSampling(nn.Module):
def __init__(self, C):
super(DownSampling, self).__init__()
self.Down = nn.Sequential(
# 使用卷积进行2倍的下采样,通道数不变
nn.Conv2d(C, C, 3, 2, 1),
nn.LeakyReLU()
)
def forward(self, x):
return self.Down(x)
# 上采样模块
class UpSampling(nn.Module):
def __init__(self, C):
super(UpSampling, self).__init__()
# 特征图大小扩大2倍,通道数减半
self.Up = nn.Conv2d(C, C // 2, 1, 1)
def forward(self, x, r):
# 使用邻近插值进行下采样
up = F.interpolate(x, scale_factor=2, mode="nearest")
x = self.Up(up)
# 拼接,当前上采样的,和之前下采样过程中的
return torch.cat((x, r), 1)
# 主干网络
class UNet(nn.Module):
def __init__(self):
super(UNet, self).__init__()
# 4次下采样
self.C1 = Conv(3, 64)
self.D1 = DownSampling(64)
self.C2 = Conv(64, 128)
self.D2 = DownSampling(128)
self.C3 = Conv(128, 256)
self.D3 = DownSampling(256)
self.C4 = Conv(256, 512)
self.D4 = DownSampling(512)
self.C5 = Conv(512, 1024)
# 4次上采样
self.U1 = UpSampling(1024)
self.C6 = Conv(1024, 512)
self.U2 = UpSampling(512)
self.C7 = Conv(512, 256)
self.U3 = UpSampling(256)
self.C8 = Conv(256, 128)
self.U4 = UpSampling(128)
self.C9 = Conv(128, 64)
self.Th = torch.nn.Sigmoid()
self.pred = torch.nn.Conv2d(64, 3, 3, 1, 1)
def forward(self, x):
# 下采样部分
R1 = self.C1(x)
R2 = self.C2(self.D1(R1))
R3 = self.C3(self.D2(R2))
R4 = self.C4(self.D3(R3))
Y1 = self.C5(self.D4(R4))
# 上采样部分
# 上采样的时候需要拼接起来
O1 = self.C6(self.U1(Y1, R4))
O2 = self.C7(self.U2(O1, R3))
O3 = self.C8(self.U3(O2, R2))
O4 = self.C9(self.U4(O3, R1))
# 输出预测,这里大小跟输入是一致的
# 可以把下采样时的中间抠出来再进行拼接,这样修改后输出就会更小
return self.Th(self.pred(O4))
if __name__ == '__main__':
a = torch.randn(2, 3, 256, 256)
net = UNet()
print(net(a).shape)
# -*-coding: utf-8 -*-
import tensorflow as tf
import tensorflow.contrib.slim as slim
def lrelu(x):
return tf.maximum(x * 0.2, x)
activation_fn=lrelu
def UNet(inputs, reg): # Unet
conv1 = slim.conv2d(inputs, 32, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv1_1', weights_regularizer=reg)
conv1 = slim.conv2d(conv1, 32, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv1_2',weights_regularizer=reg)
pool1 = slim.max_pool2d(conv1, [2, 2], padding='SAME')
conv2 = slim.conv2d(pool1, 64, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv2_1',weights_regularizer=reg)
conv2 = slim.conv2d(conv2, 64, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv2_2',weights_regularizer=reg)
pool2 = slim.max_pool2d(conv2, [2, 2], padding='SAME')
conv3 = slim.conv2d(pool2, 128, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv3_1',weights_regularizer=reg)
conv3 = slim.conv2d(conv3, 128, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv3_2',weights_regularizer=reg)
pool3 = slim.max_pool2d(conv3, [2, 2], padding='SAME')
conv4 = slim.conv2d(pool3, 256, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv4_1',weights_regularizer=reg)
conv4 = slim.conv2d(conv4, 256, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv4_2',weights_regularizer=reg)
pool4 = slim.max_pool2d(conv4, [2, 2], padding='SAME')
conv5 = slim.conv2d(pool4, 512, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv5_1',weights_regularizer=reg)
conv5 = slim.conv2d(conv5, 512, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv5_2',weights_regularizer=reg)
up6 = upsample_and_concat(conv5, conv4, 256, 512)
conv6 = slim.conv2d(up6, 256, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv6_1',weights_regularizer=reg)
conv6 = slim.conv2d(conv6, 256, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv6_2',weights_regularizer=reg)
up7 = upsample_and_concat(conv6, conv3, 128, 256)
conv7 = slim.conv2d(up7, 128, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv7_1',weights_regularizer=reg)
conv7 = slim.conv2d(conv7, 128, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv7_2',weights_regularizer=reg)
up8 = upsample_and_concat(conv7, conv2, 64, 128)
conv8 = slim.conv2d(up8, 64, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv8_1',weights_regularizer=reg)
conv8 = slim.conv2d(conv8, 64, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv8_2',weights_regularizer=reg)
up9 = upsample_and_concat(conv8, conv1, 32, 64)
conv9 = slim.conv2d(up9, 32, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv9_1', weights_regularizer=reg)
conv9 = slim.conv2d(conv9, 32, [3, 3], rate=1, activation_fn=activation_fn, scope='g_conv9_2',weights_regularizer=reg)
print("conv9.shape:{}".format(conv9.get_shape()))
type='UNet_1X'
with tf.variable_scope(name_or_scope="output"):
if type=='UNet_3X':#UNet放大三倍
conv10 = slim.conv2d(conv9, 27, [1, 1], rate=1, activation_fn=None, scope='g_conv10',weights_regularizer=reg)
out = tf.depth_to_space(conv10, 3)
if type=='UNet_1X':#输入输出维度相同
out = slim.conv2d(conv9, 6, [1, 1], rate=1, activation_fn=None, scope='g_conv10',weights_regularizer=reg)
return out
def upsample_and_concat(x1, x2, output_channels, in_channels):
pool_size = 2
deconv_filter = tf.Variable(tf.truncated_normal([pool_size, pool_size, output_channels, in_channels], stddev=0.02))
deconv = tf.nn.conv2d_transpose(x1, deconv_filter, tf.shape(x2), strides=[1, pool_size, pool_size, 1])
deconv_output = tf.concat([deconv, x2], 3)
deconv_output.set_shape([None, None, None, output_channels * 2])
return deconv_output
if __name__=="__main__":
weight_decay=0.001
reg = slim.l2_regularizer(scale=weight_decay)
inputs = tf.ones(shape=[4, 256, 256, 3])
out=UNet(inputs,reg)
print("net1.shape:{}".format(inputs.get_shape()))
print("out.shape:{}".format(out.get_shape()))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。