当前位置:   article > 正文

使用Flask部署YoloV3-PyTorch_flask yolov3

flask yolov3

使用Flask部署YoloV3-PyTorch

一、项目简介

这个项目是一个web对象检测的小demo,使用Yolov3(PyTorch) 和 Flask 在 Web 端进行对象检测,涉及目标检测、Flask和Html
Yolov3 来自 Ultralytics,你可以可以使用他们的项目来训练一个满足自己的模型

二. 项目整体框架与代码

项目地址:https://github.com/BonesCat/Yolov3_flask
在这里插入图片描述
主要是在Yolov3-Ultralytics的代码上进行修改,具体如下:

  • 1.将原detect.py修改为detect_for_flask.py,为Flask提供一个接
  • 2.所有上传的文件将被时间重命名并保存到“upload_files”文件夹
  • 3.检测到的图像将被保存到“输出”文件夹中

三、快速开始

  • 按照 ult-yolov3 中requirement要求配置环境,自行安装Flask,注意都需要在一个evn环境中进行安装与配置
  • 下载或训练一个模型,将“.weights/.pt”文件放到weights文件夹,配置正确的cfg,其他配置可以在opt上设置.本项目可以使用原始yolov3提供的官方权重,只需设置对应cfg即可。
  • 启动serve.py,然后在网站上输入“http://127.0.0.1:2222/upload”,上传图片,即可得到结果和检测信息。

四、 核心部分代码与简单讲解

  • Server.py
import time
import os
# 导入flask库中的Flask类与request对象
from flask import Flask, request, flash, redirect, render_template, jsonify
from datetime import timedelta

# 导入模型相关函数
from detect_for_flask import *


app = Flask(__name__)

# 设置上传文件的保存位置
UPLOAD_FOLDER = 'upload_files'
ALLOWED_EXTENSIONS = {'pdf', 'png', 'jpg', 'jpeg', 'gif'}

# 配置路径到app
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

# 设置静态文件缓存过期时间
app.config['SEND_FILE_MAX_AGE_DEFAULT'] = timedelta(seconds=5) # timedalte 是datetime中的一个对象,该对象表示两个时间的差值

print("SEND_FILE_MAX_AGE_DEFAULT:", app.config['SEND_FILE_MAX_AGE_DEFAULT'])

# 预先初始化模型
model_inited, opt = init_model()

# 处理文件名的有效性
def allow_filename(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

@app.route('/upload', methods=['GET', 'POST']) # 添加路由

def upload():
    if request.method == 'POST':
        # 如果上传的file不是在files
        if 'file' not in request.files:
            # Flask 消息闪现
            flash('not file part!')
            # 重新显示当前url页面
            return  redirect(request.url)

        '''
        Flask 框架中的 request 对象保存了一次HTTP请求的一切信息。
        files 记录了请求上传的文件
        '''
        f = request.files['file']

        # 处理空文件
        if f.filename == '':
            flash("Nothing file upload")
            return redirect(request.url)

        # 文件非空,且格式满足
        if f and allow_filename(f.filename):
            # 保存上传文件至本地
            # 按照格式获取当前时间,从命名文件
            now = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(time.time()))
            file_extension = f.filename.split('.')[-1]
            new_filename = now + '.' + file_extension
            file_path = './' + app.config['UPLOAD_FOLDER'] + '/' + new_filename
            f.save(file_path)

            # 进行预测,并显示图片
            img, obj_infos = detect(model_inited, opt, file_path)
            return render_template('upload_ok.html', det_result = obj_infos)
    return render_template('upload.html')

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=2222)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70

detect_for_flask.py

import argparse
from sys import platform

from models import *  # set ONNX_EXPORT in models.py
from utils.datasets import *
from utils.utils import *

'''
根据原始YoloV3中的detect.py,重写了检测函数,来适配flask
'''


def init_model():
    '''
    模型参数初始化
    :无输入参数
    :return: 完成初始的模型 和 opt设置
    '''
    # paraments config
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='cfg/yolov3.cfg', help='*.cfg path')
    parser.add_argument('--names', type=str, default='data/coco.names', help='*.names path')
    parser.add_argument('--weights', type=str, default='weights/yolov3.weights', help='weights path')
    parser.add_argument('--output', type=str, default='output', help='output folder')  # detect result will be saved here
    parser.add_argument('--img-size', type=int, default=416, help='inference size (pixels)')
    parser.add_argument('--conf-thres', type=float, default=0.3, help='object confidence threshold')
    parser.add_argument('--iou-thres', type=float, default=0.6, help='IOU threshold for NMS')
    parser.add_argument('--device', default='cpu', help='device id (i.e. 0 or 0,1) or cpu')
    parser.add_argument('--classes', nargs='+', type=int, help='filter by class')
    parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
    parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
    opt = parser.parse_args()
    print(opt)

    # init paraments
    out, weights, save_txt = opt.output, opt.weights, opt.save_txt

    # Initialize
    device = torch_utils.select_device(device='cpu' if ONNX_EXPORT else opt.device)
    if not os.path.exists(out):
        os.makedirs(out)  # make new output folder

    # Initialize model
    model = Darknet(opt.cfg, opt.img_size)

    # Load weights
    attempt_download(weights)
    if weights.endswith('.pt'):  # pytorch format
        model.load_state_dict(torch.load(weights, map_location=device)['model'])
    else:  # darknet format
        load_darknet_weights(model, weights)

    return model, opt

def detect(model, opt, image_path):
    '''
    :param model: 完成初始化的模型
    :param opt: opt参数
    :param image_path:传入的图片地址 
    :param save_img: 是否保存图片
    :return: 完成定位后的结果
    '''
    # Eval mode
    model.to(opt.device).eval()
    # Save img?
    save_img = True

    # Process the upload image

    # read img
    img0 = cv2.imread(image_path)  # BGR
    assert img0 is not None, 'Image Not Found ' + image_path

    # Padded resize
    img = letterbox(img0, new_shape=opt.img_size)[0]

    # Convert
    img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
    img = np.ascontiguousarray(img)

    # Get names and colors
    names = load_classes(opt.names)
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]

    # Run inference
    t0 = time.time()

    img = torch.from_numpy(img).to(opt.device)
    img = img.float()  # uint8 to fp16/32
    img /= 255.0  # 0 - 255 to 0.0 - 1.0
    if img.ndimension() == 3:
        img = img.unsqueeze(0)
    with torch.no_grad():
        # Inference
        t1 = torch_utils.time_synchronized()
        pred = model(img)[0]
        t2 = torch_utils.time_synchronized()
        # print("pred:", pred)

        # Apply NMS
        pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)

        # Process detections
        for i, det in enumerate(pred):  # detections per image
            # 这是检测出来的所有object的,检测结果是一个二维list
            # 每一行存放的是一个obj的左上,右下四个坐标,置信度,类别
            # print("det", det)

            p, s = image_path, ''

            save_path = str(Path(opt.output) / Path(p).name)
            s += '%gx%g ' % img.shape[2:]  # print string
            # 若检测出了对象,则list不为空
            if det is not None and len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img0.shape).round()

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += '%g %ss, ' % (n, names[int(c)])  # add to string
                # 设置字典,写入每个目标数据
                obj_info_list = []
                # 遍历二维det中的每行,从而对每一个obj进行处理
                # Write results
                for *xyxy, conf, cls in det:
                    if opt.save_txt:  # Write to file
                        with open(save_path + '.txt', 'a') as file:
                            file.write(('%g ' * 6 + '\n') % (*xyxy, cls, conf))

                    if save_img:  # Add bbox to image
                        label = '%s %.2f' % (names[int(cls)], conf)
                        plot_one_box(xyxy, img0, label=label, color=colors[int(cls)]) # 参数xyxy中包含着bbox的坐标
                    # 记录单个目标的坐标,类别,置信度
                    sig_obj_info =('%s %g %g %g %g %g' ) % (names[int(cls)], *xyxy, conf)
                    print("sig_obj_info:", sig_obj_info)
                    obj_info_list.append(sig_obj_info)

            # Print time (inference + NMS)
            print('%sDone. (%.3fs)' % (s, t2 - t1))


            # Save results (image with detections)
            if save_img:
                # 两次保存
                # 1.永久保存检测结果,存入output文件夹
                cv2.imwrite(save_path, img0)
                # 2.暂存文件,用于显示
                cv2.imwrite('./static/temp.jpg', img0)

    print('Done. (%.3fs)' % (time.time() - t0))
    return img0, obj_info_list


if __name__ == '__main__':
    img_path = './data/samples/timg1.jpg'
    model_inited, opt = init_model()
    result,obj_infos = detect(model = model_inited, opt = opt, image_path=img_path)
    print(obj_infos)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160

五、项目截图

在这里插入图片描述
在这里插入图片描述

六、 参考与致谢

https://github.com/ultralytics/yolov3
https://blog.csdn.net/rain2211/article/details/105965313/

注:只是简单demo,没有写检测不到时候的处理,自己处理一下报错。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/558503
推荐阅读
相关标签
  

闽ICP备14008679号