赞
踩
U-Net 是一个语义分割模型,整体流程是先下采样、再上采样:首先利用卷积进行下采样,提取多个不同下采样阶段的特征;然后逐级上采样,每个阶段的上采样结果都与编码器中相同大小的特征图进行拼接(跳跃连接);最后得到逐像素的类别预测图。
原始 U-Net 论文中模型的输入 shape:572,572,3(本文的实现默认输入为 416,416,3)
原始 U-Net 论文中模型的输出 shape:388,388,n_classes(本文的实现输出为 208*208,n_classes,即宽高各为输入的一半并展平)
Encoder部分用于特征提取,一般对特征进行四次压缩,每次压缩后特征图的大小都减小一半。Encoder的主干网络使用MobileNet。
深度可分离卷积在 tensorflow2 中有两种实现方法:一种是 DepthwiseConv2D + 1x1 Conv2D 的组合实现,另一种是直接使用 SeparableConv2D 实现。
from tensorflow.keras.layers import *


def conv_block(inputs, filters, kernel, strides):
    """Apply zero-padding(1) -> Conv2D (no bias) -> BatchNorm -> ReLU6.

    Args:
        inputs: Input Keras tensor.
        filters: Number of output channels of the convolution.
        kernel: Convolution kernel size.
        strides: Convolution strides.

    Returns:
        The activated output tensor.
    """
    x = ZeroPadding2D(1)(inputs)
    x = Conv2D(filters, kernel, strides, padding='valid', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = ReLU(max_value=6)(x)
    return x


def dw_pw_block(inputs, dw_strides, pw_filters, name):
    """Depthwise-separable unit built from DepthwiseConv2D + 1x1 Conv2D.

    Each of the two convolutions is followed by BatchNorm and ReLU6.

    Args:
        inputs: Input Keras tensor.
        dw_strides: Strides of the 3x3 depthwise convolution.
        pw_filters: Number of output channels of the 1x1 pointwise convolution.
        name: Layer name given to the depthwise convolution layer.

    Returns:
        The output tensor of the pointwise stage.
    """
    x = ZeroPadding2D(1)(inputs)
    # Depthwise stage.
    x = DepthwiseConv2D((3, 3), dw_strides, padding='valid', use_bias=False, name=name)(x)
    x = BatchNormalization()(x)
    x = ReLU(max_value=6)(x)
    # Pointwise stage.
    x = Conv2D(pw_filters, (1, 1), 1, padding='valid', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = ReLU(max_value=6)(x)
    return x


def encoder_MobilenetV1_1(height=416, width=416):
    """MobileNet-V1 style encoder (DepthwiseConv2D + Conv1x1) exposing every stage.

    This is the entry point the U-Net builder imports: it returns the output
    of all five down-sampling stages so the decoder can use them as skip
    connections.

    Args:
        height: Input image height.
        width: Input image width.

    Returns:
        A tuple ``(img_input, [out1, out2, out3, out4, out5])``. For the
        default 416x416 input the stage outputs are 208x208x64, 104x104x128,
        52x52x256, 26x26x512 and 13x13x1024 respectively.
    """
    img_input = Input(shape=(height, width, 3))

    # block1: conv + dw_pw_1      416,416,3 -> 208,208,32 -> 208,208,64
    x = conv_block(img_input, 32, (3, 3), (2, 2))
    x = dw_pw_block(x, 1, 64, 'dw_pw_1')
    out1 = x

    # block2                      208,208,64 -> 104,104,128
    x = dw_pw_block(x, 2, 128, 'dw_pw_2_1')
    x = dw_pw_block(x, 1, 128, 'dw_pw_2_2')
    out2 = x

    # block3                      104,104,128 -> 52,52,256
    x = dw_pw_block(x, 2, 256, 'dw_pw_3_1')
    x = dw_pw_block(x, 1, 256, 'dw_pw_3_2')
    out3 = x

    # block4                      52,52,256 -> 26,26,512
    x = dw_pw_block(x, 2, 512, 'dw_pw_4_1')
    for i in range(5):
        x = dw_pw_block(x, 1, 512, 'dw_pw_4_' + str(i + 2))
    out4 = x

    # block5                      26,26,512 -> 13,13,1024
    x = dw_pw_block(x, 2, 1024, 'dw_pw_5_1')
    x = dw_pw_block(x, 1, 1024, 'dw_pw_5_2')
    out5 = x

    return img_input, [out1, out2, out3, out4, out5]


# MobileNet-based SegNet encoder (DepthwiseConv2D + Conv1x1 implementation).
def segnet_encoder_MobilenetV1_1(height=416, width=416):
    """Return the input tensor and the block-4 feature map only.

    Kept for backward compatibility with callers that expect the original
    ``(img_input, out4)`` return value. The block-5 layers are still created,
    as before, but are not part of the returned tensor.

    Args:
        height: Input image height.
        width: Input image width.

    Returns:
        A tuple ``(img_input, out4)``.
    """
    img_input, stage_outputs = encoder_MobilenetV1_1(height, width)
    return img_input, stage_outputs[3]
from tensorflow.keras.layers import *


def sp_block(x, dw_strides, pw_filters, name):
    """Depthwise-separable unit built from a single SeparableConv2D layer.

    Applies zero-padding(1) -> SeparableConv2D 3x3 (no bias) -> BatchNorm ->
    ReLU6. There is no normalisation or activation between the depthwise and
    pointwise parts, because both live inside the one SeparableConv2D layer.

    Args:
        x: Input Keras tensor.
        dw_strides: Strides of the separable convolution.
        pw_filters: Number of output channels.
        name: Layer name given to the SeparableConv2D layer.

    Returns:
        The activated output tensor.
    """
    x = ZeroPadding2D(1)(x)
    x = SeparableConv2D(pw_filters, (3, 3), dw_strides, use_bias=False, name=name)(x)
    x = BatchNormalization()(x)
    x = ReLU(max_value=6)(x)
    return x


def encoder_MobilenetV1_2(height=416, width=416):
    """MobileNet-V1 style encoder (SeparableConv2D) exposing every stage.

    This is the entry point the U-Net builder imports: it returns the output
    of all five down-sampling stages so the decoder can use them as skip
    connections.

    NOTE(review): relies on ``conv_block`` being defined elsewhere in this
    module -- it is not defined in this section.

    Args:
        height: Input image height.
        width: Input image width.

    Returns:
        A tuple ``(img_input, [out1, out2, out3, out4, out5])``. For the
        default 416x416 input the stage outputs are 208x208x64, 104x104x128,
        52x52x256, 26x26x512 and 13x13x1024 respectively.
    """
    img_input = Input(shape=(height, width, 3))

    # block1: conv + sp_1         416,416,3 -> 208,208,32 -> 208,208,64
    x = conv_block(img_input, 32, (3, 3), (2, 2))
    x = sp_block(x, 1, 64, 'dw_pw_1')
    out1 = x

    # block2                      208,208,64 -> 104,104,128
    x = sp_block(x, 2, 128, 'dw_pw_2_1')
    x = sp_block(x, 1, 128, 'dw_pw_2_2')
    out2 = x

    # block3                      104,104,128 -> 52,52,256
    x = sp_block(x, 2, 256, 'dw_pw_3_1')
    x = sp_block(x, 1, 256, 'dw_pw_3_2')
    out3 = x

    # block4                      52,52,256 -> 26,26,512
    x = sp_block(x, 2, 512, 'dw_pw_4_1')
    for i in range(5):
        x = sp_block(x, 1, 512, 'dw_pw_4_' + str(i + 2))
    out4 = x

    # block5                      26,26,512 -> 13,13,1024
    x = sp_block(x, 2, 1024, 'dw_pw_5_1')
    x = sp_block(x, 1, 1024, 'dw_pw_5_2')
    out5 = x

    return img_input, [out1, out2, out3, out4, out5]


# MobileNet-based SegNet encoder (SeparableConv2D implementation).
def segnet_encoder_MobilenetV1_2(height=416, width=416):
    """Return the input tensor and the block-4 feature map only.

    Kept for backward compatibility with callers that expect the original
    ``(img_input, out4)`` return value. The block-5 layers are still created,
    as before, but are not part of the returned tensor.

    Args:
        height: Input image height.
        width: Input image width.

    Returns:
        A tuple ``(img_input, out4)``.
    """
    img_input, stage_outputs = encoder_MobilenetV1_2(height, width)
    return img_input, stage_outputs[3]
from tensorflow.keras.layers import *
from encoders import encoder_MobilenetV1_1, encoder_MobilenetV1_2
from tensorflow.keras.models import Model


def zero_conv_bn(input, filters):
    """Apply zero-padding(1) -> 3x3 Conv2D -> BatchNorm (no activation).

    The padding of 1 followed by a 'valid' 3x3 convolution keeps the spatial
    size unchanged.

    Args:
        input: Input Keras tensor. (The name shadows the builtin but is kept
            so keyword callers continue to work.)
        filters: Number of output channels.

    Returns:
        The normalised output tensor.
    """
    x = ZeroPadding2D(1)(input)
    x = Conv2D(filters, 3)(x)
    x = BatchNormalization()(x)
    return x


def build_unet(n_classes, input_height=416, input_width=416, encoder_type='MobilenetV1_1'):
    """Build a U-Net style segmentation model on top of a MobileNet encoder.

    The decoder starts from the encoder's 1/16-resolution feature map and
    up-samples three times, concatenating the matching encoder feature map
    after each up-sampling. The prediction is therefore made at half the
    input resolution and flattened to ``(H/2 * W/2, n_classes)`` with a
    softmax over the class axis.

    Args:
        n_classes: Number of classes (including background).
        input_height: Input image height; must be a positive multiple of 16.
        input_width: Input image width; must be a positive multiple of 16.
        encoder_type: ``'MobilenetV1_1'`` or ``'MobilenetV1_2'``.

    Returns:
        A ``tensorflow.keras.models.Model`` mapping an image batch to
        per-pixel class probabilities.

    Raises:
        RuntimeError: If ``encoder_type`` is not a known encoder name.
        ValueError: If the input size is not a positive multiple of 16.
    """
    # Resolve the encoder first so an unknown name fails before any layers
    # are created.
    if encoder_type == 'MobilenetV1_1':
        build_encoder = encoder_MobilenetV1_1
    elif encoder_type == 'MobilenetV1_2':
        build_encoder = encoder_MobilenetV1_2
    else:
        raise RuntimeError('unet encoder name is error')

    # Four stride-2 stages followed by three x2 up-samplings only line up
    # with the skip connections (and with the final reshape) when both sides
    # are multiples of 16; fail early with a readable message otherwise.
    for label, size in (('input_height', input_height), ('input_width', input_width)):
        if size <= 0 or size % 16 != 0:
            raise ValueError(
                '%s must be a positive multiple of 16, got %r' % (label, size))

    # 1. Encoder outputs (416,416,3 -> ... -> 26,26,512 for the default size).
    # out5 is returned by the encoder but not used by this decoder.
    img_input, [out1, out2, out3, out4, out5] = build_encoder(input_height, input_width)

    # 26,26,512 -> 26,26,512
    x = zero_conv_bn(out4, 512)

    # 26,26,512 -> 52,52,512
    x = UpSampling2D((2, 2))(x)
    # 52,52,512 + 52,52,256 -> 52,52,768
    x = Concatenate()([x, out3])
    # 52,52,768 -> 52,52,256
    x = zero_conv_bn(x, 256)

    # 52,52,256 -> 104,104,256
    x = UpSampling2D((2, 2))(x)
    # 104,104,256 + 104,104,128 -> 104,104,384
    x = Concatenate()([x, out2])
    # 104,104,384 -> 104,104,128
    x = zero_conv_bn(x, 128)

    # 104,104,128 -> 208,208,128
    x = UpSampling2D((2, 2))(x)
    # 208,208,128 + 208,208,64 -> 208,208,192
    x = Concatenate()([x, out1])
    # 208,208,192 -> 208,208,64
    x = zero_conv_bn(x, 64)

    # 208,208,64 -> 208,208,n_classes
    x = Conv2D(n_classes, 3, padding='same')(x)

    # Flatten the spatial dimensions so the softmax/loss work per pixel.
    out = Reshape(((input_height // 2) * (input_width // 2), -1))(x)
    out = Softmax()(out)

    model = Model(img_input, out)
    return model
from unet import build_unet
from tensorflow.keras.callbacks import ModelCheckpoint,ReduceLROnPlateau,EarlyStopping
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy,CategoricalCrossentropy
import numpy as np
from PIL import Image
import os
import argparse


def parse_opt():
    """Parse the command-line options of the training script."""
    parse = argparse.ArgumentParser()
    parse.add_argument('--datasets_path', type=str, default='../../datasets/banmaxian', help='数据集路径')
    parse.add_argument('--n_classes', type=int, default=2, help='标签种类(含背景)')
    parse.add_argument('--height', type=int, default=416, help='图片高度')
    parse.add_argument('--width', type=int, default=416, help='图片宽度')
    parse.add_argument('--batch_size', type=int, default=2)
    parse.add_argument('--lr', type=float, default=0.0001)
    parse.add_argument('--epochs', type=int, default=50)
    parse.add_argument('--encoder_type', type=str, default='MobilenetV1_2',
                       help='unet模型编码器的类型[MobilenetV1_1,MobilenetV1_2]')
    opt = parse.parse_args()
    return opt


def get_data_from_file(opt):
    """Load the training images and one-hot label maps listed in ``train.txt``.

    Each non-empty line of ``<datasets_path>/train.txt`` has the form
    ``image_name;label_name``. Images are read from ``<datasets_path>/jpg``
    and label masks from ``<datasets_path>/png``.

    Args:
        opt: Parsed options providing ``datasets_path``, ``height``,
            ``width`` and ``n_classes``.

    Returns:
        A tuple ``(X, Y)`` of numpy arrays. ``X`` has shape
        ``(N, height, width, 3)`` with values in [0, 1]. ``Y`` has shape
        ``(N, (height // 2) * (width // 2), n_classes)`` and holds a one-hot
        encoding of the label mask at half resolution.

    Raises:
        ValueError: If a non-empty line does not contain ``;``.
    """
    datasets_path, height, width, n_classes = opt.datasets_path, opt.height, opt.width, opt.n_classes
    label_h, label_w = height // 2, width // 2

    with open(os.path.join(datasets_path, 'train.txt')) as f:
        # strip() also drops '\r' from Windows line endings.
        lines = [line.strip() for line in f]

    X = []
    Y = []
    for line_no, line in enumerate(lines, start=1):
        if not line:
            continue  # tolerate blank lines, e.g. a trailing newline
        names = line.split(';')
        if len(names) < 2:
            raise ValueError(
                "train.txt line %d is not of the form 'image;label': %r" % (line_no, line))
        real_name = names[0]   # xx.jpg
        label_name = names[1]  # xx.png

        # Read the real image. PIL's resize takes (width, height); forcing
        # RGB guarantees 3 channels for grayscale / RGBA inputs.
        with Image.open(os.path.join(datasets_path, 'jpg', real_name)) as im:
            real_img = im.convert('RGB').resize((width, height))
        X.append(np.array(real_img) / 255)  # (height, width, 3), values in [0, 1]

        # Read the label mask; each pixel value is a class id (0 = background).
        # NEAREST is required here: an interpolating filter would blend
        # neighbouring class ids into values that are not valid classes.
        with Image.open(os.path.join(datasets_path, 'png', label_name)) as im:
            label_img = np.array(im.resize((label_w, label_h), Image.NEAREST))
        if label_img.ndim == 3:
            # Multi-channel mask: every channel carries the same ids, use the first.
            label_img = label_img[:, :, 0]

        # One-hot encode: channel c is 1 where the pixel belongs to class c.
        seg_labels = (label_img[:, :, None] == np.arange(n_classes)).astype(float)
        # Flatten the spatial dimensions to match the model output.
        Y.append(seg_labels.reshape(-1, n_classes))

    return np.array(X), np.array(Y)


def main():
    """Train the U-Net and save its weights."""
    # 1. Parse options.
    opt = parse_opt()

    # 2. Load the dataset.
    X, Y = get_data_from_file(opt)
    if len(X) == 0:
        raise SystemExit('no training samples found in ' + os.path.join(opt.datasets_path, 'train.txt'))

    # 3. Build the model.
    weight_path = 'weights/unet_' + opt.encoder_type + '_weight/'
    model = build_unet(opt.n_classes, opt.height, opt.width, opt.encoder_type)
    os.makedirs(weight_path, exist_ok=True)

    # Checkpoint every 5 epochs, keeping only improvements in val_loss.
    # NOTE(review): `period` is deprecated in newer TensorFlow releases in
    # favour of `save_freq` -- confirm against the installed version.
    checkpoint = ModelCheckpoint(
        filepath=weight_path + 'acc{accuracy:.4f}-ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5',
        monitor='val_loss',
        verbose=1,
        save_best_only=True,
        save_weights_only=True,
        period=5,
    )
    lr_sh = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, verbose=1)
    es = EarlyStopping(monitor='val_loss', patience=10, verbose=1)

    model.compile(loss=CategoricalCrossentropy(), optimizer=Adam(opt.lr), metrics=['accuracy'])

    # 4. Train.
    model.fit(
        x=X, y=Y,
        batch_size=opt.batch_size,
        epochs=opt.epochs,
        callbacks=[checkpoint, lr_sh, es],
        verbose=1,
        validation_split=0.1,
        shuffle=True,
    )

    # 5. Save the final weights.
    model.save_weights(os.path.join(weight_path, 'last.h5'))


if __name__ == '__main__':
    main()
from unet import build_unet
from PIL import Image
import numpy as np
import copy
import os
import argparse


def parse_opt():
    """Parse the command-line options of the prediction script."""
    parse = argparse.ArgumentParser()
    parse.add_argument('--test_imgs', type=str, default='test_imgs', help='测试数据集')
    # The help text used to be a copy of --test_imgs; this is the output directory.
    parse.add_argument('--test_out', type=str, default='test_res', help='测试结果输出目录')
    parse.add_argument('--n_classes', type=int, default=2, help='标签种类(含背景)')
    parse.add_argument('--height', type=int, default=416, help='输入模型的图片高度')
    parse.add_argument('--width', type=int, default=416, help='输入模型的图片宽度')
    parse.add_argument('--encoder_type', type=str, default='MobilenetV1_1',
                       help='unet模型编码器的类型[MobilenetV1_1,MobilenetV1_2]')
    opt = parse.parse_args()
    return opt


def resize_img(path, real_width, real_height):
    """Resize every image in ``path`` to the given size, overwriting the files.

    This is destructive; the prediction entry point below no longer calls it
    and resizes in memory instead. It is kept for callers that still want
    the in-place behaviour.

    Args:
        path: Directory containing only image files.
        real_width: Target width in pixels.
        real_height: Target height in pixels.
    """
    for img_name in os.listdir(path):
        img_path = os.path.join(path, img_name)
        with Image.open(img_path) as img:
            resized = img.resize((real_width, real_height))
        resized.save(img_path)


def main():
    """Run segmentation on every test image and save colour overlays."""
    # 1. Parse options.
    opt = parse_opt()

    # One RGB colour per class; must be adapted to the dataset's labels.
    class_colors = [[0, 0, 0], [0, 255, 0]]
    if opt.n_classes > len(class_colors):
        raise ValueError(
            'class_colors defines %d colours but n_classes is %d; add one RGB colour per class'
            % (len(class_colors), opt.n_classes))

    # Size of the saved overlay images.
    real_width, real_height = 1280, 720

    # 2. Load the test images. Resizing happens in memory so the files in
    # the test directory are left untouched.
    jpg_names = [name for name in os.listdir(opt.test_imgs)
                 if os.path.isfile(os.path.join(opt.test_imgs, name))]
    if not jpg_names:
        raise SystemExit('no test images found in ' + opt.test_imgs)

    imgs_test = []
    imgs_init = []
    for jpg_name in jpg_names:
        with Image.open(os.path.join(opt.test_imgs, jpg_name)) as im:
            # Force RGB: Image.blend needs both images in the same mode, and
            # the model expects 3 channels.
            img_init = im.convert('RGB').resize((real_width, real_height))
        imgs_init.append(img_init)
        # PIL's resize takes (width, height).
        img = img_init.resize((opt.width, opt.height))
        imgs_test.append(np.array(img) / 255)  # (height, width, 3), values in [0, 1]
    imgs_test = np.array(imgs_test)  # (N, height, width, 3)

    # 3. Build the model and load the trained weights.
    weight_path = 'weights/unet_' + opt.encoder_type + '_weight/'
    weight_file = os.path.join(weight_path, 'last.h5')
    if not os.path.isfile(weight_file):
        raise FileNotFoundError(
            weight_file + ' not found; pass the --encoder_type that was used for training')
    model = build_unet(opt.n_classes, opt.height, opt.width, opt.encoder_type)
    model.load_weights(weight_file)

    # 4. Predict: (N, H/2 * W/2, n_classes) -> (N, H/2, W/2, n_classes).
    prs = model.predict(imgs_test)
    prs = prs.reshape(-1, opt.height // 2, opt.width // 2, opt.n_classes)
    # The argmax index along the class axis is the predicted class id.
    prs = prs.argmax(axis=-1)  # (N, H/2, W/2)

    # 5. Build the colour maps: look every class id up in the palette.
    palette = np.array(class_colors, dtype=np.uint8)
    imgs_seg = palette[prs]  # (N, H/2, W/2, 3)

    # 6. Blend each colour map over its image and save the result.
    save_path = os.path.join(opt.test_out, opt.encoder_type)
    os.makedirs(save_path, exist_ok=True)
    for img_init, img_seg, img_name in zip(imgs_init, imgs_seg, jpg_names):
        overlay = Image.fromarray(img_seg).resize((real_width, real_height))
        images = Image.blend(img_init, overlay, 0.3)
        images.save(os.path.join(save_path, img_name))


if __name__ == '__main__':
    main()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。