当前位置:   article > 正文

[深度学习从入门到女装]keras实战-UNet2d(VOC2012)_unet2dmodel

unet2dmodel

本文使用keras实现U-Net2d网络,并在VOC2012数据集上进行语义分割

 

UNet2d

首先是Unet2d网络的搭建

  1. import keras.backend as K
  2. from keras.engine import Input,Model
  3. import keras
  4. from keras.optimizers import Adam
  5. from keras.layers import BatchNormalization,Activation,Conv2D,MaxPooling2D,Conv2DTranspose,UpSampling2D
  6. import metrics as m
  7. from keras.layers.core import Lambda
  8. import numpy as np
  9. def up_and_concate(down_layer, layer):
  10. in_channel = down_layer.get_shape().as_list()[3]
  11. out_channel = in_channel // 2
  12. up = Conv2DTranspose(out_channel,[2,2],strides=[2,2])(down_layer)
  13. print("--------------")
  14. print(str(up.get_shape()))
  15. print(str(layer.get_shape()))
  16. print("--------------")
  17. my_concat = Lambda(lambda x: K.concatenate([x[0], x[1]], axis=3))
  18. concate = my_concat([up, layer])
  19. # must use lambda
  20. #concate=K.concatenate([up, layer], 3)
  21. return concate
  22. def attention_block_2d(x, g, inter_channel):
  23. '''
  24. :param x: x input from down_sampling same layer output (?,x_height,x_width,x_channel)
  25. :param g: gate input from up_sampling layer last output (?,g_height,g_width,g_channel)
  26. g_height,g_width=x_height/2,x_width/2
  27. :return:
  28. '''
  29. print('attention_block:')
  30. # theta_x(?,g_height,g_width,inter_channel)
  31. theta_x = Conv2D(inter_channel, [2, 2], strides=[2, 2])(x)
  32. print(str(theta_x.get_shape()))
  33. # phi_g(?,g_height,g_width,inter_channel)
  34. phi_g = Conv2D(inter_channel, [1, 1], strides=[1, 1])(g)
  35. print(str(phi_g.get_shape()))
  36. # f(?,g_height,g_width,inter_channel)
  37. f = Activation('relu')(keras.layers.add([theta_x, phi_g]))
  38. print(str(f.get_shape()))
  39. # psi_f(?,g_height,g_width,1)
  40. psi_f = Conv2D(1, [1, 1], strides=[1, 1])(f)
  41. print(str(psi_f.get_shape()))
  42. # sigm_psi_f(?,g_height,g_width)
  43. sigm_psi_f = Activation('sigmoid')(psi_f)
  44. print(str(sigm_psi_f.get_shape()))
  45. # rate(?,x_height,x_width)
  46. rate = UpSampling2D(size=[2, 2])(sigm_psi_f)
  47. print(str(rate.get_shape()))
  48. # att_x(?,x_height,x_width,x_channel)
  49. att_x = keras.layers.multiply([x, rate])
  50. print(str(att_x.get_shape()))
  51. print('-----------------')
  52. return att_x
  53. def unet_model_2d_attention(input_shape,n_labels,batch_normalization=False,initial_learning_rate=0.00001,metrics=m.dice_coefficient):
  54. """
  55. input_shape:without batch_size,(img_height,img_width,img_depth)
  56. metrics:
  57. """
  58. inputs=Input(input_shape)
  59. down_layer=[]
  60. layer=inputs
  61. #down_layer_1
  62. layer=res_block_v2(layer,64,batch_normalization=batch_normalization)
  63. down_layer.append(layer)
  64. layer=MaxPooling2D(pool_size=[2,2],strides=[2,2])(layer)
  65. print(str(layer.get_shape()))
  66. # down_layer_2
  67. layer = res_block_v2(layer, 128, batch_normalization=batch_normalization)
  68. down_layer.append(layer)
  69. layer = MaxPooling2D(pool_size=[2, 2], strides=[2, 2])(layer)
  70. print(str(layer.get_shape()))
  71. # down_layer_3
  72. layer = res_block_v2(layer, 256, batch_normalization=batch_normalization)
  73. down_layer.append(layer)
  74. layer = MaxPooling2D(pool_size=[2, 2], strides=[2, 2])(layer)
  75. print(str(layer.get_shape()))
  76. # down_layer_4
  77. layer = res_block_v2(layer, 512, batch_normalization=batch_normalization)
  78. down_layer.append(layer)
  79. layer = MaxPooling2D(pool_size=[2, 2], strides=[2, 2])(layer)
  80. print(str(layer.get_shape()))
  81. # bottle_layer
  82. layer = res_block_v2(layer, 1024, batch_normalization=batch_normalization)
  83. print(str(layer.get_shape()))
  84. # up_layer_4
  85. layer = attention_block_2d( down_layer[3],layer,256)
  86. layer = res_block_v2(layer, 512,batch_normalization=batch_normalization)
  87. print(str(layer.get_shape()))
  88. # up_layer_3
  89. layer = attention_block_2d( down_layer[2],layer,128)
  90. layer = res_block_v2(layer, 256, batch_normalization=batch_normalization)
  91. print(str(layer.get_shape()))
  92. # up_layer_2
  93. layer = attention_block_2d( down_layer[1],layer,64)
  94. layer = res_block_v2(layer, 128, batch_normalization=batch_normalization)
  95. print(str(layer.get_shape()))
  96. # up_layer_1
  97. layer = attention_block_2d( down_layer[0],layer,32)
  98. layer = res_block_v2(layer, 64, batch_normalization=batch_normalization)
  99. print(str(layer.get_shape()))
  100. # score_layer
  101. layer = Conv2D(n_labels,[1,1],strides=[1,1])(layer)
  102. print(str(layer.get_shape()))
  103. # softmax
  104. layer = Activation('softmax')(layer)
  105. print(str(layer.get_shape()))
  106. outputs=layer
  107. model=Model(inputs=inputs,outputs=outputs)
  108. metrics=[metrics]
  109. model.compile(optimizer=Adam(lr=initial_learning_rate), loss=m.dice_coefficient_loss, metrics=metrics)
  110. return model
  111. def unet_model_2d(input_shape,n_labels,batch_normalization=False,initial_learning_rate=0.00001,metrics=m.dice_coefficient):
  112. """
  113. input_shape:without batch_size,(img_height,img_width,img_depth)
  114. metrics:
  115. """
  116. inputs=Input(input_shape)
  117. down_layer=[]
  118. layer=inputs
  119. #down_layer_1
  120. layer=res_block_v2(layer,64,batch_normalization=batch_normalization)
  121. down_layer.append(layer)
  122. layer=MaxPooling2D(pool_size=[2,2],strides=[2,2])(layer)
  123. print(str(layer.get_shape()))
  124. # down_layer_2
  125. layer = res_block_v2(layer, 128, batch_normalization=batch_normalization)
  126. down_layer.append(layer)
  127. layer = MaxPooling2D(pool_size=[2, 2], strides=[2, 2])(layer)
  128. print(str(layer.get_shape()))
  129. # down_layer_3
  130. layer = res_block_v2(layer, 256, batch_normalization=batch_normalization)
  131. down_layer.append(layer)
  132. layer = MaxPooling2D(pool_size=[2, 2], strides=[2, 2])(layer)
  133. print(str(layer.get_shape()))
  134. # down_layer_4
  135. layer = res_block_v2(layer, 512, batch_normalization=batch_normalization)
  136. down_layer.append(layer)
  137. layer = MaxPooling2D(pool_size=[2, 2], strides=[2, 2])(layer)
  138. print(str(layer.get_shape()))
  139. # bottle_layer
  140. layer = res_block_v2(layer, 1024, batch_normalization=batch_normalization)
  141. print(str(layer.get_shape()))
  142. # up_layer_4
  143. layer = up_and_concate(layer, down_layer[3])
  144. layer = res_block_v2(layer, 512,batch_normalization=batch_normalization)
  145. print(str(layer.get_shape()))
  146. # up_layer_3
  147. layer = up_and_concate(layer, down_layer[2])
  148. layer = res_block_v2(layer, 256, batch_normalization=batch_normalization)
  149. print(str(layer.get_shape()))
  150. # up_layer_2
  151. layer = up_and_concate(layer, down_layer[1])
  152. layer = res_block_v2(layer, 128, batch_normalization=batch_normalization)
  153. print(str(layer.get_shape()))
  154. # up_layer_1
  155. layer = up_and_concate(layer, down_layer[0])
  156. layer = res_block_v2(layer, 64, batch_normalization=batch_normalization)
  157. print(str(layer.get_shape()))
  158. # score_layer
  159. layer = Conv2D(n_labels,[1,1],strides=[1,1])(layer)
  160. print(str(layer.get_shape()))
  161. # softmax
  162. layer = Activation('softmax')(layer)
  163. print(str(layer.get_shape()))
  164. outputs=layer
  165. model=Model(inputs=inputs,outputs=outputs)
  166. metrics=[metrics]
  167. model.compile(optimizer=Adam(lr=initial_learning_rate), loss=m.dice_coefficient_loss, metrics=metrics)
  168. return model
  169. def res_block_v2(input_layer,out_n_filters,batch_normalization=False,kernel_size=[3,3],stride=[1,1],padding='same'):
  170. input_n_filters = input_layer.get_shape().as_list()[3]
  171. print(str(input_layer.get_shape()))
  172. layer=input_layer
  173. for i in range(2):
  174. if batch_normalization:
  175. layer=BatchNormalization()(layer)
  176. layer=Activation('relu')(layer)
  177. layer=Conv2D(out_n_filters,kernel_size,strides=stride,padding=padding)(layer)
  178. if out_n_filters!=input_n_filters:
  179. skip_layer=Conv2D(out_n_filters,[1,1],strides=stride,padding=padding)(input_layer)
  180. else:
  181. skip_layer=input_layer
  182. out_layer=keras.layers.add([layer,skip_layer])
  183. return out_layer

使用Keras中的Model类,首先使用Input(input_shape),注意这里的input_shape是不带batch_size这一维的,在这里就是(img_height,img_width,img_depth)

metrics为评判标准

  1. def up_and_concate(down_layer, layer):
  2. in_channel = down_layer.get_shape().as_list()[3]
  3. out_channel = in_channel // 2
  4. up = Conv2DTranspose(out_channel,[2,2],strides=[2,2])(down_layer)
  5. print("--------------")
  6. print(str(up.get_shape()))
  7. print(str(layer.get_shape()))
  8. print("--------------")
  9. my_concat = Lambda(lambda x: K.concatenate([x[0], x[1]], axis=3))
  10. concate = my_concat([up, layer])
  11. # must use lambda
  12. #concate=K.concatenate([up, layer], 3)
  13. return concate

以上为skip_connection的函数,down_layer是上一层上采样层的输出,layer为同层下采样层的输出

注意这里不能直接用K.concatenate,会报错说使用的tensor不是keras里边的tensor,必须使用Lambda

  1. def res_block_v2(input_layer,out_n_filters,batch_normalization=False,kernel_size=[3,3],stride=[1,1],padding='same'):
  2. input_n_filters = input_layer.get_shape().as_list()[3]
  3. print(str(input_layer.get_shape()))
  4. layer=input_layer
  5. for i in range(2):
  6. if batch_normalization:
  7. layer=BatchNormalization()(layer)
  8. layer=Activation('relu')(layer)
  9. layer=Conv2D(out_n_filters,kernel_size,strides=stride,padding=padding)(layer)
  10. if out_n_filters!=input_n_filters:
  11. skip_layer=Conv2D(out_n_filters,[1,1],strides=stride,padding=padding)(input_layer)
  12. else:
  13. skip_layer=input_layer
  14. out_layer=keras.layers.add([layer,skip_layer])
  15. return out_layer

 

VOC读取

接下来是voc数据集的读取,voc数据集的目录如下:

Annotations为存放每个图片的描述文件(.xml),类别,检测框什么的,在分割上用不到

ImageSets存放了各个任务的train/val划分所需图片的名称列表

JPEGImages为存放所有图片(.jpg)

SegmentationClass是用于语义分割的标签(.png)

SegmentationObject是用于实例分割的标签(.png)

本文只用到了ImageSets中的Segmentation(训练/验证的文件名列表)、JPEGImages(原始图片)和SegmentationClass(语义分割标签)中的文件

文件的读取如下:

  1. import tensorflow as tf
  2. from PIL import Image
  3. import PIL
  4. import scipy.misc as misc
  5. import numpy as np
  6. def make_one_hot(x,n):
  7. '''
  8. print(x.shape)
  9. one_hot=np.zeros([x.shape[0],x.shape[1],n])
  10. print(one_hot.shape)
  11. for i in range(n):
  12. #print(x==i)
  13. print(one_hot[x==i])
  14. one_hot[x==i][i]=1
  15. '''
  16. one_hot = np.zeros([x.shape[0], x.shape[1], n])
  17. for i in range(x.shape[0]):
  18. for j in range(x.shape[1]):
  19. one_hot[i,j,x[i,j]]=1
  20. return one_hot
  21. class voc_reader:
  22. def __init__(self,resize_width,resize_height,train_batch_size,val_batch_size):
  23. self.train_file_name_list=self.load_file_name_list(file_path="D:\\pyproject\\data\\VOCtrainval_11-May-2012\\VOCdevkit\\VOC2012\\ImageSets\\Segmentation\\train.txt")
  24. self.val_file_name_list=self.load_file_name_list(file_path="D:\\pyproject\\data\\VOCtrainval_11-May-2012\\VOCdevkit\\VOC2012\\ImageSets\\Segmentation\\val.txt")
  25. self.row_file_path="D:\\pyproject\\data\\VOCtrainval_11-May-2012\\VOCdevkit\\VOC2012\\JPEGImages\\"
  26. self.label_file_path="D:\\pyproject\\data\\VOCtrainval_11-May-2012\\VOCdevkit\\VOC2012\\SegmentationClass\\"
  27. self.train_batch_index=0
  28. self.val_batch_index=0
  29. self.resize_width=resize_width
  30. self.resize_height=resize_height
  31. self.n_train_file=len(self.train_file_name_list)
  32. self.n_val_file=len(self.val_file_name_list)
  33. self.train_batch_size=train_batch_size
  34. self.val_batch_size=val_batch_size
  35. print(self.n_train_file)
  36. print(self.n_val_file)
  37. self.n_train_steps_per_epoch=self.n_train_file//self.train_batch_size
  38. self.n_val_steps_per_epoch=self.n_val_file//self.val_batch_size
  39. def load_file_name_list(self,file_path):
  40. file_name_list=[]
  41. with open(file_path, 'r') as file_to_read:
  42. while True:
  43. lines = file_to_read.readline().strip() # 整行读取数据
  44. if not lines:
  45. break
  46. pass
  47. file_name_list.append(lines)
  48. pass
  49. return file_name_list
  50. def next_train_batch(self):
  51. train_imgs=np.zeros((self.train_batch_size,self.resize_height,self.resize_width,3))
  52. train_labels=np.zeros([self.train_batch_size,self.resize_height,self.resize_width,21])
  53. if self.train_batch_index>=self.n_train_steps_per_epoch:
  54. print("next epoch")
  55. self.train_batch_index=0
  56. print('------------------')
  57. print(self.train_batch_index)
  58. for i in range(self.train_batch_size):
  59. index=self.train_batch_size*self.train_batch_index+i
  60. print('index'+str(index))
  61. img = Image.open(self.row_file_path+self.train_file_name_list[index]+'.jpg')
  62. img=img.resize((self.resize_height,self.resize_width),Image.NEAREST)
  63. img=np.array(img)
  64. train_imgs[i]=img
  65. #print(img.shape)
  66. np.set_printoptions(threshold=np.inf)
  67. label=Image.open(self.label_file_path+self.train_file_name_list[index]+'.png')
  68. label=label.resize((self.resize_height,self.resize_width),Image.NEAREST)
  69. label=np.array(label, dtype=np.int32)
  70. #print(label[label>20])
  71. #label[label == 255] = -1
  72. label[label==255]=0
  73. #print(label)
  74. #print(label.shape)
  75. one_hot_label=make_one_hot(label,21)
  76. train_labels[i]=one_hot_label
  77. #print(one_hot_label.shape)
  78. #print(label)
  79. #print(label)
  80. self.train_batch_index+=1
  81. print('------------------')
  82. return train_imgs,train_labels
  83. def next_val_batch(self):
  84. val_imgs = np.zeros((self.val_batch_size, self.resize_height, self.resize_width, 3))
  85. val_labels = np.zeros([self.val_batch_size, self.resize_height, self.resize_width, 21])
  86. if self.val_batch_index>=self.n_val_steps_per_epoch:
  87. print("next epoch")
  88. self.val_batch_index=0
  89. print('------------------')
  90. print(self.val_batch_index)
  91. for i in range(self.val_batch_size):
  92. index=self.val_batch_size*self.val_batch_index+i
  93. print('index'+str(index))
  94. img=Image.open(self.row_file_path+self.val_file_name_list[index]+'.jpg')
  95. img = img.resize((self.resize_height, self.resize_width), Image.NEAREST)
  96. img = np.array(img)
  97. val_imgs[i]=img
  98. label = Image.open(self.label_file_path + self.val_file_name_list[index] + '.png')
  99. label = label.resize((self.resize_height, self.resize_width), Image.NEAREST)
  100. label = np.array(label, dtype=np.int32)
  101. # print(label[label>20])
  102. # label[label == 255] = -1
  103. label[label == 255] = 0
  104. # print(label)
  105. # print(label.shape)
  106. one_hot_label = make_one_hot(label, 21)
  107. val_labels[i]=one_hot_label
  108. print('------------------')
  109. self.val_batch_index+=1
  110. return val_imgs,val_labels

 

Train

先构造train和val的generator,用于等会的fit_generator

  1. def train_generator_data(batch_size,voc_reader):
  2. while True:
  3. x,y=voc_reader.next_train_batch(batch_size)
  4. yield (x,y)
  5. def val_generator_data(batch_size,voc_reader):
  6. while True:
  7. x,y=voc_reader.next_val_batch(batch_size)
  8. yield (x,y)

随后定义callback

  1. def get_callbacks(model_file,initial_learning_rate=0.0001,learning_rate_drop=0.5,learning_rate_epochs=None,
  2. learning_rate_patience=50,logging_file="training.log",verbosity=1,early_stopping_patience=None):
  3. callbacks=list()
  4. callbacks.append(ModelCheckpoint(model_file,save_best_only=True))
  5. callbacks.append(CSVLogger(logging_file,append=True))
  6. callbacks.append(TensorBoard())
  7. if learning_rate_epochs:
  8. callbacks.append(LearningRateScheduler(partial(step_decay, initial_lrate=initial_learning_rate,
  9. drop=learning_rate_drop, epochs_drop=learning_rate_epochs)))
  10. else:
  11. callbacks.append(ReduceLROnPlateau(factor=learning_rate_drop, patience=learning_rate_patience,
  12. verbose=verbosity))
  13. if early_stopping_patience:
  14. callbacks.append(EarlyStopping(verbose=verbosity, patience=early_stopping_patience))
  15. return callbacks

加载已经训练过的模型

  1. def load_old_model(model_file):
  2. print("Loading pre-trained model")
  3. custom_objects = {'dice_coefficient_loss': dice_coefficient_loss, 'dice_coefficient': dice_coefficient}
  4. try:
  5. from keras_contrib.layers import InstanceNormalization
  6. custom_objects["InstanceNormalization"] = InstanceNormalization
  7. except ImportError:
  8. pass
  9. try:
  10. return load_model(model_file, custom_objects=custom_objects)
  11. except ValueError as error:
  12. if 'InstanceNormalization' in str(error):
  13. raise ValueError(str(error) + "\n\nPlease install keras-contrib to use InstanceNormalization:\n"
  14. "'pip install git+https://www.github.com/keras-team/keras-contrib.git'")
  15. else:
  16. raise error

train_model函数

  1. def train_model(model, model_file, training_generator, validation_generator, steps_per_epoch, validation_steps,
  2. initial_learning_rate=0.001, learning_rate_drop=0.5, learning_rate_epochs=None, n_epochs=500,
  3. learning_rate_patience=20, early_stopping_patience=None):
  4. """
  5. Train a Keras model.
  6. :param early_stopping_patience: If set, training will end early if the validation loss does not improve after the
  7. specified number of epochs.
  8. :param learning_rate_patience: If learning_rate_epochs is not set, the learning rate will decrease if the validation
  9. loss does not improve after the specified number of epochs. (default is 20)
  10. :param model: Keras model that will be trained.
  11. :param model_file: Where to save the Keras model.
  12. :param training_generator: Generator that iterates through the training data.
  13. :param validation_generator: Generator that iterates through the validation data.
  14. :param steps_per_epoch: Number of batches that the training generator will provide during a given epoch.
  15. :param validation_steps: Number of batches that the validation generator will provide during a given epoch.
  16. :param initial_learning_rate: Learning rate at the beginning of training.
  17. :param learning_rate_drop: How much at which to the learning rate will decay.
  18. :param learning_rate_epochs: Number of epochs after which the learning rate will drop.
  19. :param n_epochs: Total number of epochs to train the model.
  20. :return:
  21. """
  22. model.fit_generator(generator=training_generator,
  23. steps_per_epoch=steps_per_epoch,
  24. epochs=n_epochs,
  25. validation_data=validation_generator,
  26. validation_steps=validation_steps,
  27. callbacks=get_callbacks(model_file,
  28. initial_learning_rate=initial_learning_rate,
  29. learning_rate_drop=learning_rate_drop,
  30. learning_rate_epochs=learning_rate_epochs,
  31. learning_rate_patience=learning_rate_patience,
  32. early_stopping_patience=early_stopping_patience))

 

github地址:https://github.com/panxiaobai/voc_keras

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号