当前位置:   article > 正文

移动端unet人像分割模型--1_unet移植到手机

unet移植到手机

  个人对移动端神经网络开发一直饶有兴致。去年腾讯开源NCNN框架之后,我就一直在关注。近期利用别人训练好的mtcnn和mobilefacenet模型,成功制作了一个iOS平台上的Swift版人脸识别demo。希望能把maskrcnn移植到ncnn,在手机端实现一些有趣的应用。因为unet模型比较简单,干脆就从它入手。

  基本的网络基于keras版本: https://github.com/TianzhongSong/Person-Segmentation-Keras

  不过keras模型没办法直接转成ncnn模型。我研究过以onnx模型作为中间跳板,试用了一些开源的转换工具,也是问题一堆。NCNN支持从caffe/mxnet/pytorch这几个训练框架转换模型,ncnn的GitHub上有一篇issue里nihui推荐采用mxnet,因此mxnet也就成了我的首选。

  利用Person-Segmentation-Keras项目的数据集,同时基于https://github.com/milesial/Pytorch-UNet/tree/master/unet这个项目捣鼓了几段代码。训练完成后,用它来测试ncnn转换,基本可用。

  转换过程中发现了许多问题。一是调用ncnn的extract会crash,经过调查,发现mxnet2ncnn工具也有bug:blob个数算错了;二是input层的one_blob_only标志,按我的理解应该是false,但不知道什么原因,转换过来的模型里是true,导致forward_layer函数里访问bottoms变量时出现异常。后来把各层逐一extract出来,打印输出的channel/width/height进行调查,又发现我把unet.py里name为pool5的层写成了pool4(文章中的代码已经纠正),前面的crash说不定也跟这个致命错误有关。只好重新训练模型,又是几个小时的漫长等待,剩下部分下周再写。部分代码已经更新,请参考: https://github.com/xuduo35/unet_mxnet2ncnn

unetdataiter.py

  1. #!/usr/bin/env python
  2. # coding=utf8
  3. import os
  4. import sys
  5. import random
  6. import cv2
  7. import mxnet as mx
  8. import numpy as np
  9. from mxnet.io import DataIter, DataBatch
  10. sys.path.append('../')
  def get_batch(items, root_path, nClasses, height, width):
      """Load, letterbox and one-hot encode one batch of (image, label) pairs.

      Each entry of *items* is a list-file line of the form
      ``"<image path> <label path>"``. Both paths are joined to *root_path* by
      plain string concatenation. Every image is resized to fit inside
      ``height x width`` while keeping its aspect ratio, and centred on a grey
      (128) canvas. The label image gets the same geometry on a zero
      (class 0) canvas.

      Args:
          items: list of "image label" lines.
          root_path: prefix prepended to both paths.
          nClasses: number of segmentation classes.
          height: network input height in pixels.
          width: network input width in pixels.

      Returns:
          ``(data, label)`` as ``mx.nd.array`` objects with shapes
          ``(len(items), 3, height, width)`` and
          ``(len(items), nClasses, height * width)``.

      Raises:
          IOError: if an image or label file cannot be read.
      """
      x = []
      y = []
      for item in items:
          fields = item.split(' ')
          image_path = root_path + fields[0]
          label_path = root_path + fields[-1].strip()
          img = cv2.imread(image_path, 1)
          label_img = cv2.imread(label_path, 1)
          # cv2.imread reports failure by returning None instead of raising.
          if img is None:
              raise IOError('cannot read image: ' + image_path)
          if label_img is None:
              raise IOError('cannot read label: ' + label_path)

          # numpy images are (rows, cols, channels) == (height, width, 3).
          im = np.full((height, width, 3), 128, dtype='uint8')
          lim = np.zeros((height, width, 3), dtype='uint8')

          # Fit the limiting side exactly. The integer cross-multiplication
          # avoids float rounding; for a square target it reduces to
          # comparing src_h >= src_w.
          src_h, src_w = img.shape[:2]
          if src_h * width >= src_w * height:
              new_height = height
              new_width = max(1, int(src_w / (src_h / height)))
          else:
              new_width = width
              new_height = max(1, int(src_h / (src_w / width)))
          top = (height - new_height) // 2
          left = (width - new_width) // 2

          img = cv2.resize(img, (new_width, new_height))
          # Nearest-neighbour keeps label values discrete (no blended class ids).
          label_img = cv2.resize(label_img, (new_width, new_height),
                                 interpolation=cv2.INTER_NEAREST)
          im[top:top + new_height, left:left + new_width, :] = img
          lim[top:top + new_height, left:left + new_width, :] = label_img

          # Class ids are read from the first channel of the label image.
          lim = lim[:, :, 0]
          seg_labels = np.zeros((height, width, nClasses))
          for c in range(nClasses):
              seg_labels[:, :, c] = (lim == c).astype(int)

          im = np.float32(im) / 127.5 - 1  # scale pixels to [-1, 1]
          seg_labels = np.reshape(seg_labels, (width * height, nClasses))
          x.append(im.transpose((2, 0, 1)))       # HWC -> CHW
          y.append(seg_labels.transpose((1, 0)))  # -> (nClasses, H*W)
      return mx.nd.array(x), mx.nd.array(y)
  class UnetDataIter(mx.io.DataIter):
      """MXNet data iterator over a ``"<image> <label>"`` list file.

      Each batch is produced by ``get_batch``. Its data has shape
      ``(batch_size, 3, input_height, input_width)`` and its label has shape
      ``(batch_size, n_classes, input_width * input_height)``. The trailing
      partial batch is dropped, so every batch matches the bound module shape.
      When *train* is true, the items are reshuffled on every ``reset()``.
      """

      def __init__(self, root_path, path_file, batch_size, n_classes, input_width, input_height, train=True):
          super().__init__(batch_size)
          with open(path_file, 'r') as f:
              self.items = f.readlines()
          # (height, width) order matches the CHW arrays built by get_batch.
          self._provide_data = [['data', (batch_size, 3, input_height, input_width)]]
          self._provide_label = [['softmax_label', (batch_size, n_classes, input_width * input_height)]]
          self.root_path = root_path
          self.batch_size = batch_size
          self.num_batches = len(self.items) // batch_size
          self.n_classes = n_classes
          self.input_height = input_height
          self.input_width = input_width
          self.train = train
          self.reset()

      def __iter__(self):
          return self

      def reset(self):
          """Rewind to the first batch, reshuffling the items in training mode."""
          self.cur_batch = 0
          self.shuffled_items = list(self.items)
          if self.train:
              random.shuffle(self.shuffled_items)

      def __next__(self):
          return self.next()

      @property
      def provide_data(self):
          return self._provide_data

      @property
      def provide_label(self):
          return self._provide_label

      def next(self):
          """Return the next ``DataBatch``; raise StopIteration after the last full batch."""
          if self.cur_batch == 0:
              print("")
          # "\033[K" is the ANSI "erase to end of line" control sequence.
          print("\r\033[K" + ("Training " if self.train else "Validating ") + str(self.cur_batch) + "/" + str(self.num_batches),
                end=' ', flush=True)
          if self.cur_batch >= self.num_batches:
              raise StopIteration
          start = self.cur_batch * self.batch_size
          data, label = get_batch(self.shuffled_items[start:start + self.batch_size],
                                  self.root_path, self.n_classes, self.input_height, self.input_width)
          self.cur_batch += 1
          return mx.io.DataBatch([data], [label])
  if __name__ == '__main__':
      # Smoke test: walk once over the training list, showing the progress line.
      root_path = '/datasets/'
      train_file = './data/seg_train.txt'
      batch_size = 16
      n_classes = 2
      img_width = 256
      img_height = 256
      trainiter = UnetDataIter(root_path, train_file, batch_size, n_classes, img_width, img_height, True)
      # A for-loop stops cleanly on StopIteration instead of dying with a traceback.
      for _ in trainiter:
          pass
      print("")

unet.py

  1. import os
  2. os.environ["MXNET_BACKWARD_DO_MIRROR"] = "1"
  3. os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
  4. import mxnet as mx
  5. from mxnet import ndarray as F
  6. from skimage.transform import resize
  7. from skimage.io import imsave
  8. import numpy as np
  9. from unetdataiter import UnetDataIter
  10. import matplotlib.pyplot as plt
  def dice_coef(y_true, y_pred):
      """Smoothed Dice coefficient, reduced over axes 1, 2 and 3.

      Returns ``(2 * sum(y_true * y_pred) + 1) / (sum(y_true) + sum(y_pred) + 1)``
      for every index along axis 0. The axes imply symbols of rank >= 4.
      """
      reduce_axes = (1, 2, 3)
      overlap = mx.sym.sum(mx.sym.broadcast_mul(y_true, y_pred), axis=reduce_axes)
      numerator = 2. * overlap + 1.
      denominator = (mx.sym.sum(y_true, axis=reduce_axes)
                     + mx.sym.sum(y_pred, axis=reduce_axes)
                     + 1.)
      return mx.sym.broadcast_div(numerator, denominator)
  def dice_coef_loss(y_true, y_pred, axis=1):
      """Negative smoothed Dice coefficient, used as the training loss.

      Returns ``-(2 * sum(y_true * y_pred) + 1) / (sum(y_true) + sum(y_pred) + 1)``
      with every sum taken along *axis*.

      With the ``(batch, n_classes, H * W)`` tensors that ``build_unet`` feeds
      in, the default ``axis=1`` reduces over the class axis. That yields a
      per-pixel score of shape ``(batch, H * W)``. Pass ``axis=2`` to get the
      conventional per-image, per-class Dice over all pixels, with shape
      ``(batch, n_classes)``.
      """
      intersection = mx.sym.sum(mx.sym.broadcast_mul(y_true, y_pred), axis=axis)
      numerator = 2. * intersection + 1.
      denominator = mx.sym.broadcast_add(mx.sym.sum(y_true, axis=axis),
                                         mx.sym.sum(y_pred, axis=axis)) + 1.
      return -mx.sym.broadcast_div(numerator, denominator)
  def _conv_block(data, num_filter, stage):
      """Apply two (3x3 conv -> BatchNorm -> ReLU) layers named conv/bn/relu<stage>_<k>.

      Padding 1 keeps the spatial size. The layer names are identical to the
      hand-written graph, so parameter names in checkpoints are unchanged.
      """
      net = data
      for k in (1, 2):
          suffix = '%d_%d' % (stage, k)
          net = mx.sym.Convolution(net, num_filter=num_filter, kernel=(3, 3), pad=(1, 1), name='conv' + suffix)
          net = mx.sym.BatchNorm(net, name='bn' + suffix)
          net = mx.sym.Activation(net, act_type='relu', name='relu' + suffix)
      return net


  def _down(data, stage):
      """Apply 2x2 max pooling with stride 2 (halves H and W), named pool<stage>."""
      # MXNet's Pooling stride defaults to 1, which only trims one pixel;
      # stride=(2, 2) is required for real downsampling.
      return mx.sym.Pooling(data, kernel=(2, 2), stride=(2, 2), pool_type='max', name='pool%d' % stage)


  def _up_block(data, skip, num_filter, stage):
      """Upsample 2x (2x2 transposed conv, stride 2), concat *skip*, then a conv block."""
      up = mx.sym.Deconvolution(data, num_filter=num_filter, kernel=(2, 2), stride=(2, 2),
                                no_bias=True, name='trans_conv%d' % stage)
      up = mx.sym.concat(up, skip, dim=1, name='concat%d' % stage)
      return _conv_block(up, num_filter, stage)


  def build_unet(batch_size, input_width, input_height, train=True):
      """Build the 2-class U-Net symbol.

      Encoder: five conv blocks (64, 128, 256, 256, 256 filters). Each block
      is followed by 2x2/stride-2 max pooling.
      Decoder: five up blocks (256, 256, 256, 128, 64). Each block upsamples
      2x and concatenates the matching encoder feature map.
      Because of the five poolings, *input_width* and *input_height* must be
      multiples of 32 for the skip connections to line up.

      Args:
          batch_size: batch size baked into the training-time Reshape.
          input_width: input width in pixels.
          input_height: input height in pixels.
          train: if True, return ``Group([loss, mask])``.

      Returns:
          In training mode, ``Group([loss, mask])``: ``loss`` is the Dice
          loss against the ``softmax_label`` variable and ``mask`` is the
          gradient-blocked prediction. Otherwise, ``Group([conv11])``, the
          prediction map only.
      """
      data = mx.sym.Variable(name='data')
      label = mx.sym.Variable(name='softmax_label')

      # encode (sizes shown for a 256x256 input)
      conv1 = _conv_block(data, 64, 1)              # 256x256
      conv2 = _conv_block(_down(conv1, 1), 128, 2)  # 128x128
      conv3 = _conv_block(_down(conv2, 2), 256, 3)  # 64x64
      conv4 = _conv_block(_down(conv3, 3), 256, 4)  # 32x32
      conv5 = _conv_block(_down(conv4, 4), 256, 5)  # 16x16
      pool5 = _down(conv5, 5)                       # 8x8

      # decode
      conv6 = _up_block(pool5, conv5, 256, 6)   # 16x16
      conv7 = _up_block(conv6, conv4, 256, 7)   # 32x32
      conv8 = _up_block(conv7, conv3, 256, 8)   # 64x64
      conv9 = _up_block(conv8, conv2, 128, 9)   # 128x128
      conv10 = _up_block(conv9, conv1, 64, 10)  # 256x256

      # 1x1 conv to two class maps. NOTE(review): the activation is an
      # element-wise sigmoid even though the layer is named 'softmax'.
      conv11 = mx.sym.Convolution(conv10, num_filter=2, kernel=(1, 1), name='conv11_1')
      conv11 = mx.sym.sigmoid(conv11, name='softmax')

      if train:
          # Flatten spatial dims to the (batch, 2, H*W) label layout.
          net = mx.sym.Reshape(conv11, (batch_size, 2, input_width * input_height))
          loss = mx.sym.MakeLoss(dice_coef_loss(label, net), normalization='batch')
          mask_output = mx.sym.BlockGrad(conv11, 'mask')
          return mx.sym.Group([loss, mask_output])
      return mx.sym.Group([conv11])

trainunet.py

  1. import os
  2. os.environ["MXNET_BACKWARD_DO_MIRROR"] = "1"
  3. os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
  4. import mxnet as mx
  5. from mxnet import ndarray as F
  6. from skimage.transform import resize
  7. from skimage.io import imsave
  8. import numpy as np
  9. from unetdataiter import UnetDataIter
  10. import matplotlib.pyplot as plt
  11. from unet import build_unet
  def main():
      """Train the U-Net on the person-segmentation lists and checkpoint every epoch."""
      root_path = '../datasets/'
      train_file = './data/seg_train.txt'
      val_file = './data/seg_test.txt'
      batch_size = 16
      n_classes = 2
      # Input size (256x256 was used in earlier experiments).
      img_width = 96
      img_height = 96
      train_iter = UnetDataIter(root_path, train_file, batch_size, n_classes, img_width, img_height, True)
      val_iter = UnetDataIter(root_path, val_file, batch_size, n_classes, img_width, img_height, False)
      ctx = [mx.gpu(0)]
      unet_sym = build_unet(batch_size, img_width, img_height)
      unet = mx.mod.Module(unet_sym, context=ctx, data_names=('data',), label_names=('softmax_label',))
      # NCHW: spatial dims are (height, width).
      unet.bind(data_shapes=[['data', (batch_size, 3, img_height, img_width)]],
                label_shapes=[['softmax_label', (batch_size, n_classes, img_width * img_height)]])
      unet.init_params(mx.initializer.Xavier(magnitude=6))
      unet.init_optimizer(optimizer='adam',
                          optimizer_params=(('learning_rate', 1E-4),
                                            ('beta1', 0.9),
                                            ('beta2', 0.99)))
      epochs = 20
      smoothing_constant = .01
      moving_loss = None  # exponential moving average of the batch training loss
      for e in range(epochs):
          for batch in train_iter:
              unet.forward_backward(batch)
              loss = unet.get_outputs()[0]
              unet.update()
              curr_loss = F.mean(loss).asscalar()
              if moving_loss is None:
                  moving_loss = curr_loss
              else:
                  moving_loss = (1 - smoothing_constant) * moving_loss + smoothing_constant * curr_loss
          train_iter.reset()

          val_losses = []
          for batch in val_iter:
              # is_train=False: BatchNorm must use, not update, its moving statistics.
              unet.forward(batch, is_train=False)
              loss = unet.get_outputs()[0]
              val_losses.append(F.mean(loss).asscalar())
          val_iter.reset()
          val_loss = np.mean(val_losses)
          shown_loss = moving_loss if moving_loss is not None else float('nan')
          print("\nEpoch %i: Moving Training Loss %0.5f, Validation Loss %0.5f" % (e, shown_loss, val_loss))
          unet.save_checkpoint('./unet_person_segmentation', e)


  if __name__ == '__main__':
      main()

  以上是训练代码。

  预测代码如下(predict.py):

  1. import os
  2. os.environ["MXNET_BACKWARD_DO_MIRROR"] = "1"
  3. os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
  4. import sys
  5. import cv2
  6. import mxnet as mx
  7. from mxnet import ndarray as F
  8. from skimage.transform import resize
  9. from skimage.io import imsave
  10. import numpy as np
  11. from unetdataiter import UnetDataIter
  12. import matplotlib.pyplot as plt
  13. from unet import build_unet
  def post_process_mask(label, img_cols, img_rows, n_classes, p=0.5):
      """Convert one sample's class scores into a displayable 0/255 mask.

      Args:
          label: array (``mx.nd.NDArray`` or numpy) holding
              ``n_classes * img_rows * img_cols`` scores. The layout is
              class-major, then rows, then columns.
          img_cols: mask width in pixels.
          img_rows: mask height in pixels.
          n_classes: number of class channels.
          p: unused; kept for backward compatibility.

      Returns:
          ``uint8`` numpy array of shape ``(img_rows, img_cols)``. Pixels whose
          arg-max class is 0 are 0; all other pixels are 255.
      """
      # Rows (height) come before columns (width), as in NCHW tensors.
      classes = label.reshape((n_classes, img_rows, img_cols)).argmax(axis=0)
      if hasattr(classes, 'asnumpy'):  # mx.nd.NDArray -> numpy
          classes = classes.asnumpy()
      mask = np.asarray(classes) * 255
      return np.clip(mask, 0, 255).astype(np.uint8)
  def load_image(img, width, height):
      """Letterbox a 3-channel image to ``height x width`` and normalise it.

      The image is resized to fit inside the target while keeping its aspect
      ratio, then centred on a grey (128) canvas. It is scaled to [-1, 1] and
      converted to CHW layout. This mirrors the image preprocessing in
      unetdataiter.get_batch.

      Args:
          img: ``(rows, cols, 3)`` uint8 image as returned by ``cv2.imread``.
          width: target width in pixels.
          height: target height in pixels.

      Returns:
          A one-element list holding a float32 array of shape ``(3, height, width)``.
      """
      im = np.full((height, width, 3), 128, dtype='uint8')
      src_h, src_w = img.shape[:2]
      # Fit the limiting side exactly; for a square target this is the same
      # as comparing src_h >= src_w.
      if src_h * width >= src_w * height:
          new_height = height
          new_width = max(1, int(src_w / (src_h / height)))
      else:
          new_width = width
          new_height = max(1, int(src_h / (src_w / width)))
      top = (height - new_height) // 2
      left = (width - new_width) // 2
      im[top:top + new_height, left:left + new_width, :] = cv2.resize(img, (new_width, new_height))
      im = np.float32(im) / 127.5 - 1
      return [im.transpose((2, 0, 1))]
  def main():
      """Segment the image given as argv[1] with the exported model and display the mask."""
      batch_size = 1  # a single image is predicted
      n_classes = 2
      img_width = 96
      img_height = 96
      ctx = [mx.gpu(0)]
      # Inference checkpoint as written by train2infer.py (prefix below, epoch 0).
      sym, arg_params, aux_params = mx.model.load_checkpoint('unet_person_segmentation', 0)
      unet = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
      # NCHW: spatial dims are (height, width).
      unet.bind(for_training=False, data_shapes=[['data', (batch_size, 3, img_height, img_width)]],
                label_shapes=None)
      unet.set_params(arg_params, aux_params, allow_missing=True)
      testimg = cv2.imread(sys.argv[1], 1)
      if testimg is None:  # cv2.imread returns None instead of raising
          print("cannot read image: " + sys.argv[1])
          sys.exit(1)
      img = load_image(testimg, img_width, img_height)
      unet.predict(mx.io.NDArrayIter(data=[img], batch_size=batch_size))
      outputs = unet.get_outputs()[0]
      cv2.imshow('test', testimg)
      cv2.imshow('mask', post_process_mask(outputs[0], img_width, img_height, n_classes))
      cv2.waitKey()


  if __name__ == '__main__':
      if len(sys.argv) < 2:
          print("illegal parameters")
          print("usage: python predict.py <image>")
          sys.exit(1)
      main()

  剥离softmax并保存参数,用于ncnn模型转换(train2infer.py):

  1. import os
  2. os.environ["MXNET_BACKWARD_DO_MIRROR"] = "1"
  3. os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
  4. import sys
  5. import cv2
  6. import mxnet as mx
  7. from mxnet import ndarray as F
  8. from skimage.transform import resize
  9. from skimage.io import imsave
  10. import numpy as np
  11. from unetdataiter import UnetDataIter
  12. import matplotlib.pyplot as plt
  13. from unet import build_unet
  def main():
      """Re-save a training checkpoint with the inference-only U-Net graph.

      Command line:
          argv[1]: prefix of the training checkpoint to read.
          argv[2]: its epoch number.
          argv[3] (optional): output prefix, default './unet_person_segmentation'.

      The result is written as epoch 0 of the output prefix.
      NOTE(review): if the output prefix equals argv[1], the training
      checkpoint's symbol file is overwritten by the inference graph.
      """
      batch_size = 16
      img_width = 96
      img_height = 96
      out_prefix = sys.argv[3] if len(sys.argv) > 3 else './unet_person_segmentation'
      ctx = [mx.cpu()]  # only re-saving parameters, no GPU needed
      # Only the trained parameters are reused; the graph is rebuilt without the loss head.
      _, arg_params, aux_params = mx.model.load_checkpoint(sys.argv[1], int(sys.argv[2]))
      unet_sym = build_unet(batch_size, img_width, img_height, False)
      unet = mx.mod.Module(symbol=unet_sym, context=ctx, label_names=None)
      # NCHW: spatial dims are (height, width).
      unet.bind(for_training=False, data_shapes=[['data', (batch_size, 3, img_height, img_width)]],
                label_shapes=None)
      unet.set_params(arg_params, aux_params, allow_missing=True)
      unet.save_checkpoint(out_prefix, 0)


  if __name__ == '__main__':
      if len(sys.argv) < 3:
          print("illegal parameters")
          print("usage: python train2infer.py <checkpoint prefix> <epoch> [output prefix]")
          sys.exit(1)
      main()

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/284976
推荐阅读
相关标签
  

闽ICP备14008679号