class YOLO(object):
    """Keras YOLOv3 detector that draws labelled boxes onto PIL images.

    Settings come from ``_defaults``; every keyword argument passed to the
    constructor overrides (or adds) the attribute of the same name.  The model
    is built and its weights loaded once, in ``__init__`` via ``generate()``.
    """

    _defaults = {
        # Keras weights file; generate() requires the path to end in '.h5'.
        # NOTE(review): '.h5' looks like a placeholder — point it at real weights.
        "model_path": '.h5',
        "classes_path": 'model_data/tomato.txt',
        "anchors_path": 'model_data/yolo_anchors.txt',
        # Anchor indices per output head, passed to yolo_body and DecodeBox.
        "anchors_mask": [[6, 7, 8], [3, 4, 5], [0, 1, 2]],
        # Network input size; resize_image receives (input_shape[1], input_shape[0]).
        "input_shape": [416, 416],
        # Thresholds/limits forwarded verbatim to DecodeBox.
        "confidence": 0.2,
        "nms_iou": 0.3,
        "max_boxes": 100,
        # Forwarded to both resize_image and DecodeBox.
        "letterbox_image": True,
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default value for setting *n*, or an error string if unknown."""
        if n in cls._defaults:
            return cls._defaults[n]
        return "Unrecognized attribute name '" + n + "'"

    def __init__(self, **kwargs):
        """Apply defaults and overrides, load class/anchor files, and build the model."""
        import copy

        # Deep-copy so mutating a list-valued setting on one instance cannot
        # leak into _defaults (and therefore into every other instance).
        self.__dict__.update(copy.deepcopy(self._defaults))
        for name, value in kwargs.items():
            setattr(self, name, value)
        self.class_names, self.num_classes = get_classes(self.classes_path)
        self.anchors, self.num_anchors = get_anchors(self.anchors_path)
        # One evenly spaced hue per class, converted to 0-255 RGB tuples for drawing.
        hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
        self.colors = [
            tuple(int(channel * 255) for channel in colorsys.hsv_to_rgb(*hsv))
            for hsv in hsv_tuples
        ]
        self.generate()

    def generate(self):
        """Build the YOLO body, load weights, and append the DecodeBox post-processing layer."""
        model_path = os.path.expanduser(self.model_path)
        assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'
        self.yolo_model = yolo_body([None, None, 3], self.anchors_mask, self.num_classes)
        # Load from the user-expanded path; passing the raw self.model_path made
        # paths such as '~/weights.h5' fail even though they passed the check above.
        self.yolo_model.load_weights(model_path)
        print('{} model, anchors, and classes loaded.'.format(model_path))

        # Second model input: the original image size, one row per batch (batch of 1).
        self.input_image_shape = Input([2,], batch_size=1)
        inputs = [*self.yolo_model.output, self.input_image_shape]
        outputs = Lambda(
            DecodeBox,
            output_shape=(1,),
            name='yolo_eval',
            arguments={
                'anchors': self.anchors,
                'num_classes': self.num_classes,
                'input_shape': self.input_shape,
                'anchor_mask': self.anchors_mask,
                'confidence': self.confidence,
                'nms_iou': self.nms_iou,
                'max_boxes': self.max_boxes,
                'letterbox_image': self.letterbox_image,
            },
        )(inputs)
        self.yolo_model = Model([self.yolo_model.input, self.input_image_shape], outputs)

    @tf.function
    def get_pred(self, image_data, input_image_shape):
        """Run the combined model and return (boxes, scores, classes)."""
        out_boxes, out_scores, out_classes = self.yolo_model([image_data, input_image_shape], training=False)
        return out_boxes, out_scores, out_classes

    @staticmethod
    def _text_size(draw, text, font):
        """Return the (width, height) of *text* drawn at the origin with *font*."""
        # ImageDraw.textsize was removed in Pillow 10; textbbox (Pillow >= 8.0)
        # replaces it.  Measuring from the origin keeps the glyph offset, so a
        # background of this size fully covers text drawn at the same origin.
        if hasattr(draw, 'textbbox'):
            _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
            return (right, bottom)
        return tuple(draw.textsize(text, font))

    def detect_image(self, image):
        """Detect objects in PIL *image* and return it with labelled boxes drawn on it."""
        image = cvtColor(image)
        image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        image_data = np.expand_dims(preprocess_input(np.array(image_data, dtype='float32')), 0)
        # Original image size as [height, width], with a leading batch axis.
        input_image_shape = np.expand_dims(np.array([image.size[1], image.size[0]], dtype='float32'), 0)
        out_boxes, out_scores, out_classes = self.get_pred(image_data, input_image_shape)
        print('Found {} boxes for {}'.format(len(out_boxes), 'img'))

        # Font and outline thickness scale with the image size.
        font = ImageFont.truetype(font='model_data/simhei.ttf', size=int(np.floor(3e-2 * image.size[1] + 0.5)))
        thickness = int(max((image.size[0] + image.size[1]) // np.mean(self.input_shape), 1))
        draw = ImageDraw.Draw(image)  # one drawing context for all boxes

        for box, score, c in zip(out_boxes, out_scores, out_classes):
            class_index = int(c)
            predicted_class = self.class_names[class_index]
            color = self.colors[class_index]

            # Clamp the box to the image bounds.
            top, left, bottom, right = box
            top = max(0, np.floor(top).astype('int32'))
            left = max(0, np.floor(left).astype('int32'))
            bottom = min(image.size[1], np.floor(bottom).astype('int32'))
            right = min(image.size[0], np.floor(right).astype('int32'))

            label = '{} {:.2f}'.format(predicted_class, score)
            label_size = self._text_size(draw, label, font)
            print(label, top, left, bottom, right)

            # Place the label above the box if it fits, otherwise just inside its top edge.
            if top - label_size[1] >= 0:
                text_origin = np.array([left, top - label_size[1]])
            else:
                text_origin = np.array([left, top + 1])

            # Nested 1-px rectangles emulate a thick outline.
            for offset in range(thickness):
                draw.rectangle([left + offset, top + offset, right - offset, bottom - offset], outline=color)
            draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=color)
            draw.text(text_origin, label, fill=(0, 0, 0), font=font)
        return image

import sys

# Optional standalone demo: `python <this file> path/to/image.jpg` runs
# YOLO.detect_image on that one image and opens the annotated result.
# The original snippet had a body-less `def detect_image(self, image):` header
# (a syntax error) and passed an undefined `image`; it also ran unconditionally
# before the server start below.  It now runs only when an image path is given.
if __name__ == "__main__" and len(sys.argv) > 1:
    yolo = YOLO()
    image = yolo.detect_image(Image.open(sys.argv[1]))
    image.show()
# Flask application object; files under ../resources/received_images
# (resolved by Flask, not by the working directory) are served as /static.
server = Flask(__name__, static_folder='../resources/received_images')

# Re-read templates from disk when they change, so HTML edits show up
# without restarting the server.
server.jinja_env.auto_reload = True
server.config.update(TEMPLATES_AUTO_RELOAD=True)
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <title>前端页面</title>
    <!-- Empty data-URI icon so the browser does not request /favicon.ico. -->
    <link rel="icon" href="data:;base64,=">
    <script src="https://code.jquery.com/jquery-3.3.1.min.js"></script>
    <script>
        $(document).ready(function () {
            var previewUrl = null;

            // Show a local preview of the selected file before it is uploaded.
            $("#image_file").change(function () {
                var file = this.files[0];
                if (!file) {
                    // Selection was cancelled: URL.createObjectURL(undefined) would throw.
                    return;
                }
                if (previewUrl) {
                    // Release the blob URL of the previous preview.
                    URL.revokeObjectURL(previewUrl);
                }
                previewUrl = URL.createObjectURL(file);
                $("img#image_1").attr("src", previewUrl);
            });

            // POST the form as multipart/form-data and display the annotated image
            // whose URL the server returns in the JSON field "src".
            $("button#button_1").click(function () {
                if ($("#image_file")[0].files.length === 0) {
                    alert("请先选择图片!");
                    return;
                }
                var formData = new FormData($("#upload_form")[0]);
                $.ajax({
                    url: "/get_drawedImage",
                    type: "POST",
                    data: formData,
                    // Let the browser build the multipart body and boundary itself.
                    processData: false,
                    contentType: false,
                    success: function (return_data) {
                        // Cache-buster: re-uploading a same-named file yields the same
                        // result URL, which the browser would otherwise serve from cache.
                        $("img#image_2").attr("src", return_data["src"] + "?t=" + Date.now());
                    },
                    error: function () {
                        alert("上传失败!");
                    }
                });
            });
        });
    </script>
</head>
<body>
    <form id="upload_form" enctype="multipart/form-data">
        <input type="file" name="input_image" id="image_file" accept="image/*"/>
    </form>
    <div>
        <p>原始图片</p>
        <!-- No src until a file is chosen: an empty src makes some browsers re-request the page. -->
        <img id="image_1" alt=""/>
    </div>
    <div>
        <p>目标检测结果图</p>
        <img id="image_2" alt=""/>
    </div>
    <button type="button" id="button_1">上传图片并检测</button>
</body>
</html>

@server.route('/')
def index():
    """Render the upload page (the `_10_yolov3_2.html` template)."""
    return render_template('_10_yolov3_2.html')
import threading

# Shared detector, built on first use: constructing YOLO() builds the network and
# loads its weights, which previously happened on every single request.
_yolo_instance = None
_yolo_lock = threading.Lock()


def _get_yolo():
    """Return the process-wide YOLO detector, constructing it on first call."""
    global _yolo_instance
    # Flask's development server handles requests on multiple threads by default;
    # the lock keeps two concurrent first requests from building two models.
    with _yolo_lock:
        if _yolo_instance is None:
            _yolo_instance = YOLO()
    return _yolo_instance


def _safe_filename(filename):
    """Reduce a client-supplied filename to a bare name with no directory parts."""
    import uuid

    # Normalise Windows separators and keep only the last path component, so a
    # name like '../../app.py' cannot write outside the upload folder.  Leading
    # dots are stripped so '..' or hidden-file names cannot slip through.
    name = os.path.basename((filename or '').replace('\\', '/')).lstrip('.')
    return name or uuid.uuid4().hex + '.jpg'


@server.route('/get_drawedImage', methods=['POST'])
def anyname_you_like():
    """Save an uploaded image, run YOLO detection on it, and return the result's URL.

    Expects a multipart form field named 'input_image'.  On success responds with
    JSON ``{"src": <static URL of the annotated JPEG>}``; responds with HTTP 400
    and ``{"error": ...}`` when no file was uploaded.
    """
    start_time = time.time()
    received_file = request.files.get('input_image')
    # A missing field and an empty file input both mean there is nothing to do.
    # Previously the empty case fell through and returned None, which Flask
    # reports as a 500.
    if not received_file:
        return jsonify(error='no image uploaded'), 400

    image_file_name = _safe_filename(received_file.filename)

    # Save the received image into the folder Flask serves as /static, so the
    # url_for('static', ...) below can reach the result.  server.static_folder is
    # absolute (resolved against the app root), whereas a bare relative path
    # depends on the process's working directory.
    received_dir_path = server.static_folder
    os.makedirs(received_dir_path, exist_ok=True)
    image_file_path = os.path.join(received_dir_path, image_file_name)
    received_file.save(image_file_path)
    print('接收图片文件保存到此路径:%s' % image_file_path)
    used_time = time.time() - start_time
    print('接收图片并保存,总共耗时%.2f秒' % used_time)

    # Run detection and save the annotated image next to the upload as a JPEG;
    # the context manager closes the source file handle afterwards.
    drawed_image_file_name = 'drawed_' + os.path.splitext(image_file_name)[0] + '.jpg'
    drawed_image_file_path = os.path.join(received_dir_path, drawed_image_file_name)
    with Image.open(image_file_path) as image:
        drawed_image = _get_yolo().detect_image(image)
        drawed_image.save(drawed_image_file_path)

    image_source_url = url_for('static', filename=drawed_image_file_name)
    return jsonify(src=image_source_url)

if __name__ == '__main__':
    # Start the development server on port 5000.
    # NOTE(review): an empty host string presumably falls back to Flask's default
    # bind address — confirm if the server must be reachable from other machines.
    server.run(host='', port=5000)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。