当前位置:   article > 正文

深度学习YOLOv4模型简单部署,flask框架搭建以及web开发_self.font = imagefont.truetype(font='simhei.ttf',

self.font = imagefont.truetype(font='simhei.ttf', size=np.floor(3e-2 * self.

一.深度学习模型的部署

相信很多人在训练深度学习模型之后不知道如何将训练好的模型部署到服务端,这里我们提出一个想法,用Flask来搭建一个框架,Flask就不在这里进行详细讲解了,本文的模型皆以Tensorflow2.X为基础。我们这里就着重于Flask的搭建。服务器和域名我们就不多讲解。

二.本地预测代码部分

首先我们需要有一个训练好的模型的权重.h5文件,首先我们需要实现的就是在本地能够检测图片的功能。

  1. class YOLO(object):
  2. _defaults = {
  3. "model_path" : '.h5',
  4. "classes_path" : 'model_data/tomato.txt',
  5. "anchors_path" : 'model_data/yolo_anchors.txt',
  6. "anchors_mask" : [[6, 7, 8], [3, 4, 5], [0, 1, 2]],
  7. "input_shape" : [416, 416],
  8. "confidence" : 0.2,
  9. "nms_iou" : 0.3,
  10. "max_boxes" : 100,
  11. "letterbox_image" : True,
  12. }
  13. @classmethod
  14. def get_defaults(cls, n):
  15. if n in cls._defaults:
  16. return cls._defaults[n]
  17. else:
  18. return "Unrecognized attribute name '" + n + "'"
  19. def __init__(self, **kwargs):
  20. self.__dict__.update(self._defaults)
  21. for name, value in kwargs.items():
  22. setattr(self, name, value)
  23. self.class_names, self.num_classes = get_classes(self.classes_path)
  24. self.anchors, self.num_anchors = get_anchors(self.anchors_path)
  25. hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
  26. self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
  27. self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))
  28. self.generate()
  29. def generate(self):
  30. model_path = os.path.expanduser(self.model_path)
  31. assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'
  32. self.yolo_model = yolo_body([None, None, 3], self.anchors_mask, self.num_classes)
  33. self.yolo_model.load_weights(self.model_path)
  34. print('{} model, anchors, and classes loaded.'.format(model_path))
  35. self.input_image_shape = Input([2,],batch_size=1)
  36. inputs = [*self.yolo_model.output, self.input_image_shape]
  37. outputs = Lambda(
  38. DecodeBox,
  39. output_shape = (1,),
  40. name = 'yolo_eval',
  41. arguments = {
  42. 'anchors' : self.anchors,
  43. 'num_classes' : self.num_classes,
  44. 'input_shape' : self.input_shape,
  45. 'anchor_mask' : self.anchors_mask,
  46. 'confidence' : self.confidence,
  47. 'nms_iou' : self.nms_iou,
  48. 'max_boxes' : self.max_boxes,
  49. 'letterbox_image' : self.letterbox_image
  50. }
  51. )(inputs)
  52. self.yolo_model = Model([self.yolo_model.input, self.input_image_shape], outputs)
  53. @tf.function
  54. def get_pred(self, image_data, input_image_shape):
  55. out_boxes, out_scores, out_classes = self.yolo_model([image_data, input_image_shape], training=False)
  56. return out_boxes, out_scores, out_classes
  57. def detect_image(self, image):
  58. image = cvtColor(image)
  59. image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
  60. image_data = np.expand_dims(preprocess_input(np.array(image_data, dtype='float32')), 0)
  61. input_image_shape = np.expand_dims(np.array([image.size[1], image.size[0]], dtype='float32'), 0)
  62. out_boxes, out_scores, out_classes = self.get_pred(image_data, input_image_shape)
  63. print('Found {} boxes for {}'.format(len(out_boxes), 'img'))
  64. font = ImageFont.truetype(font='model_data/simhei.ttf', size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32'))
  65. thickness = int(max((image.size[0] + image.size[1]) // np.mean(self.input_shape), 1))
  66. for i, c in list(enumerate(out_classes)):
  67. predicted_class = self.class_names[int(c)]
  68. box = out_boxes[i]
  69. score = out_scores[i]
  70. top, left, bottom, right = box
  71. top = max(0, np.floor(top).astype('int32'))
  72. left = max(0, np.floor(left).astype('int32'))
  73. bottom = min(image.size[1], np.floor(bottom).astype('int32'))
  74. right = min(image.size[0], np.floor(right).astype('int32'))
  75. label = '{} {:.2f}'.format(predicted_class, score)
  76. draw = ImageDraw.Draw(image)
  77. label_size = draw.textsize(label, font)
  78. label = label.encode('utf-8')
  79. print(label, top, left, bottom, right)
  80. if top - label_size[1] >= 0:
  81. text_origin = np.array([left, top - label_size[1]])
  82. else:
  83. text_origin = np.array([left, top + 1])
  84. for i in range(thickness):
  85. draw.rectangle([left + i, top + i, right - i, bottom - i], outline=self.colors[c])
  86. draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=self.colors[c])
  87. draw.text(text_origin, str(label,'UTF-8'), fill=(0, 0, 0), font=font)
  88. del draw
  89. return image

def detect_image(self, image): 

就是我们检测图片的一个函数,传入需要检测的image,return的便是已经框选好的图片了。

那么在本地想要进行预测,只需要以下代码

  1. if __name__ == "__main__":
  2. yolo = YOLO()
  3. image = yolo.detect_image(image)
  4. image.show()

三.搭建Flask框架代码部分

实例化Flask服务对象,赋值给变量server,设置开启web服务后,如果更新html文件,可以使得更新立即生效。

  1. server = Flask(
  2. __name__,
  3. static_folder='../resources/received_images',
  4. )
  5. server.jinja_env.auto_reload = True
  6. server.config['TEMPLATES_AUTO_RELOAD'] = True

然后我们需要写一个html后缀文件,使得我们的有一个前端页面进行交互 

  1. <!DOCTYPE html>
  2. <html lang="en">
  3. <head>
  4. <meta charset="UTF-8">
  5. <title>前端页面</title>
  6. <link rel="icon" href="data:;base64,=">
  7. <script src="https://code.jquery.com/jquery-3.3.1.min.js"></script>
  8. <script>
  9. $(document).ready(function(){
  10. $("#image_file").change(function(){
  11. var file = $(this)[0].files[0];
  12. $("img#image_1").attr("src", URL.createObjectURL(file));
  13. });
  14. $("button#button_1").click(function(){
  15. var formData = new FormData($("#upload_form")[0]);
  16. $.ajax({
  17. url: "/get_drawedImage",
  18. type: 'POST',
  19. data: formData,
  20. processData: false,
  21. contentType: false,
  22. success: function(return_data){
  23. $("img#image_2").attr("src", return_data['src'])
  24. },
  25. error: function(return_data){
  26. alert("上传失败!")
  27. }
  28. })
  29. });
  30. });
  31. </script>
  32. </head>
  33. <body>
  34. <form id="upload_form" enctype="multipart/form-data">
  35. <input type="file" name="input_image" id="image_file"/>
  36. </form>
  37. <div>
  38. <p>原始图片<p>
  39. <img src="" id="image_1"/>
  40. </div>
  41. <div>
  42. <p>目标检测结果图<p>
  43. <img src="" id="image_2"/>
  44. </div>
  45. <button type="button" id="button_1">上传图片并检测</button>
  46. </body>
  47. </html>

 随后在部署代码里写入调用html的回调函数

  1. @server.route('/')
  2. def index():
  3. htmlFileName = '_10_yolov3_2.html'
  4. return render_template(htmlFileName)

部署代码中的检测图像的回调函数 

  1. @server.route('/get_drawedImage', methods=['POST'])
  2. def anyname_you_like():
  3. startTime = time.time()
  4. received_file = request.files['input_image']
  5. imageFileName = received_file.filename
  6. if received_file:
  7. # 保存接收的图片到指定文件夹
  8. received_dirPath = '../resources/received_images'
  9. if not os.path.isdir(received_dirPath):
  10. os.makedirs(received_dirPath)
  11. imageFilePath = os.path.join(received_dirPath, imageFileName)
  12. received_file.save(imageFilePath)
  13. print('接收图片文件保存到此路径:%s' % imageFilePath)
  14. usedTime = time.time() - startTime
  15. print('接收图片并保存,总共耗时%.2f秒' % usedTime)
  16. # 对指定图片路径的图片做目标检测,并打印耗时
  17. image = Image.open(imageFilePath)
  18. yolo = YOLO()
  19. drawed_image = yolo.detect_image(image)
  20. # 把目标检测结果图保存在服务端指定路径,返回指定路径对应的图片经过base64编码后的字符串
  21. drawed_imageFileName = 'drawed_' + os.path.splitext(imageFileName)[0] + '.jpg'
  22. drawed_imageFilePath = os.path.join(received_dirPath, drawed_imageFileName)
  23. drawed_image.save(drawed_imageFilePath)
  24. image_source_url = url_for('static', filename=drawed_imageFileName)
  25. return jsonify(src=image_source_url)

部署代码的主函数

  1. if __name__ == '__main__':
  2. server.run('127.0.0.1', port=5000)

运行之后,打开网站,我们可以得到以下结果。 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/83249
推荐阅读
  

闽ICP备14008679号