赞
踩
作为一名刚实习一个月的研二小白,好不容易改进模型,训练模型,成功预测了结果,老板说:“你去把这个模型部署一下|”,我问:“怎么部署?”,老板说:“我要知道还问你???”
我简直心里就是和这张图一样
然后打开了网页搜索栏,甚至连搜什么都不知道,弱弱的在搜索栏打入:AI模型部署。
结果出来的五花八门的,各种各样的,看的我眼花缭乱,甚至怀疑人生,我到底在干什么?
好啦,吐槽结束,我们进入正题哈哈哈哈。
很多同学一定会有疑问,模型要部署,怎么部署?部署在哪?应该从何做起?
这篇文章从yolov7和resnet152的角度来说明,如何将训练好的模型布置到远程linux服务器上,让用户可以直接访问网址,就可以使用我们的模型进行计算,并返回给用户预测结果。
以我自己的模型举例子,我的模型主要是使用yolov7去对茅台酒图片进行目标检测,检测到的结果进行裁剪,然后将裁剪后的图片送入resnet152进行分类,分别有三个模型去检测茅台酒的大商标、商标里面的侍女头部、以及瓶子后面的海洋标志,如下图所示:
第一张图是茅台酒的大商标:
第二张图是侍女的头部:
第三张是海洋标志:
关于模型训练部分和模型预测的部分可以参考我之前的文章:
然后将文件夹整理为这种形式:
强烈建议将文件夹放入一个文件夹内,并整理成这种形式:
这样我们后面部署的时候会比较方便,可以直接调用文件夹中我们定义的函数。
这里的_int_.py文件是空的,里面只有一个#字符,为了后面使用相对路径去调用包而设置的,这里的相对路径的意思是,你如果运行一个脚本,那么当前脚本属于_main_属性,将来调用的时候只能使用相对路径去调用那些不属于这个文件夹的包,就比如在同一个文件夹nets中,有很多函数将来被调用:
这里的yolo.py文件里都是相对路径调用的:
但如果在nets这个文件夹中你运行了一个脚本,这里就会出错。
如果你将这个main.py脚本放在外面去调用这个nets里面的函数,就可以正常运行。详情可以参考这篇文章:
python - Relative imports for the billionth time - Stack Overflow
这篇文章将相对路径import讲的非常清楚,如果想深究的同学可以参考一下。
然后我们来看main.py函数,代码如下:
- from flask import Flask, request, send_file, jsonify
- from flymaotai_yolov7_resnet152.predict_flymaotai import predict_flymaotai
- from flymaotai_head_yolov7_resnet152.predict_maotaihead import predict_maotaihead
- from Ocean_yolov7_resnet152.predict_Ocean import predict_Ocean
-
- app = Flask(__name__)
-
- '''
- 这里的网址:
- http://主机ip:5000/predict_maotai/
- http://主机ip:5000/predict_maotaihead/
- http://主机ip:5000/predict_ocean/
- 分别对应了茅台大图标,茅台侍女头部,海洋标志的检测地址,这个网址后期对应着前端接口
- '''
- @app.route('/predict_maotai/', methods=['POST'])
- def predict_maotai():
- if 'image' not in request.files:
- return 'No image provided', 400
-
- image = request.files['image']
- result = predict_flymaotai(image)
- return jsonify(result=result)
- # send_file(result, mimetype='image/jpeg')
-
- @app.route('/predict_maotaihead/', methods=['POST'])
- def predict_maotai_head():
- if 'image' not in request.files:
- return 'No image provided', 400
-
- image = request.files['image']
- result = predict_maotaihead(image)
- return jsonify(result=result)
- # send_file(result, mimetype='image/jpeg')
-
- @app.route('/predict_ocean/', methods=['POST'])
- def predict_Ocean_():
- if 'image' not in request.files:
- return 'No image provided', 400
-
- image = request.files['image']
- result = predict_Ocean(image)
- return jsonify(result=result)
- # send_file(result, mimetype='image/jpeg')
-
- if __name__ == '__main__':
- app.debug = True
- app.run(host='0.0.0.0', port=5000)
- '''
- 这里添加一些我对前端和网页的理解(不必理会):
- 这里的@app.route()实际上是在http://0.0.0.0:5000/predict_maotai/这个网页上进行的,将来前端可以做接口对应的就是这里的这个网址
- 比如说这个网址将来前端做一个按键,点进去就到了发送图片来使用算法检测的网页了
- 这种@app.route可以有多个,对应了不同网址后缀名称的作用
- '''
这里的@app.route('/你自己定义的名字/', methods=['POST']),相当于你自己创建了一个服务器网址,到时候用户会根据这个网址来给服务器post一个图片,服务器会调用本地的模型去计算,返回计算结果给用户,这个网址你可以根据自己的需求创建多个。
这个main.py文件在linux上运行时,会产生以下结果:
这个意思就是服务器已经成功运行了。
接下来,我们来看这段代码import的三个预测函数是什么样子的:
这是预测茅台大商标的预测函数:
- from .resnet_for_flymaotai.classification import Classification
- def predict_flymaotai(image):
- mode = "predict_onnx"
- crop = True
- count = False
- pic = 0
- dir_origin_path = "img/"
- dir_save_path = "img_out/"
- yolo = YOLO_ONNX()
- classfication = Classification()
- try:
- image = Image.open(image)
- except:
- print('Open Error! Try again!')
- else:
- image, crop_image = yolo.detect_image(image, crop=crop, count=count, pic=pic)
-
- class_name = classfication.detect_image(crop_image)
- print(class_name)
-
- return class_name
这个函数将输入的图像先送入yolov7目标检测,返回两个值image和crop_image(这里要对yolo里面的函数稍加修改)
修改部分如下:
- def detect_image(self, image, pic, crop = False, count = False):
- image_shape = np.array(np.shape(image)[0:2])
- #---------------------------------------------------------#
- # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
- # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
- #---------------------------------------------------------#
- image = cvtColor(image)
-
- image_data = self.resize_image(image, self.input_shape, True)
- #---------------------------------------------------------#
- # 添加上batch_size维度
- # h, w, 3 => 3, h, w => 1, 3, h, w
- #---------------------------------------------------------#
- image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
-
- input_feed = self.get_input_feed(image_data)
- outputs = self.onnx_session.run(output_names=self.output_name, input_feed=input_feed)
-
- feature_map_shape = [[int(j / (2 ** (i + 3))) for j in self.input_shape] for i in range(len(self.anchors_mask))][::-1]
- for i in range(len(self.anchors_mask)):
- outputs[i] = np.reshape(outputs[i], (1, len(self.anchors_mask[i]) * (5 + self.num_classes), feature_map_shape[i][0], feature_map_shape[i][1]))
-
- outputs = self.bbox_util.decode_box(outputs)
- #---------------------------------------------------------#
- # 将预测框进行堆叠,然后进行非极大抑制
- #---------------------------------------------------------#
- results = self.bbox_util.non_max_suppression(np.concatenate(outputs, 1), self.num_classes, self.input_shape,
- image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)
-
- if results[0] is None:
- return image
-
- top_label = np.array(results[0][:, 6], dtype = 'int32')
- top_conf = results[0][:, 4] * results[0][:, 5]
- top_boxes = results[0][:, :4]
-
- #---------------------------------------------------------#
- # 设置字体与边框厚度
- #---------------------------------------------------------#
- font = ImageFont.truetype(font='/home/ubuntu/workfile/flymaotai_yolov7_resnet152/yolov7_flymaotai_CA/model_data/simhei.ttf', size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32'))
- thickness = int(max((image.size[0] + image.size[1]) // np.mean(self.input_shape), 1))
-
- # ---------------------------------------------------------#
- # 计数
- # ---------------------------------------------------------#
- if count:
- print("top_label:", top_label)
- classes_nums = np.zeros([self.num_classes])
- for i in range(self.num_classes):
- num = np.sum(top_label == i)
- if num > 0:
- print(self.class_names[i], " : ", num)
- classes_nums[i] = num
- print("classes_nums:", classes_nums)
- # ---------------------------------------------------------#
- # 是否进行目标的裁剪
- # ---------------------------------------------------------#
- # 这里的top_boxs是一个图检测到的目标框
- if crop:
- for i, c in list(enumerate(top_boxes)):
- top, left, bottom, right = top_boxes[i]
- top = max(0, np.floor(top).astype('int32'))
- left = max(0, np.floor(left).astype('int32'))
- bottom = min(image.size[1], np.floor(bottom).astype('int32'))
- right = min(image.size[0], np.floor(right).astype('int32'))
-
- dir_save_path = "img_crop"
- if not os.path.exists(dir_save_path):
- os.makedirs(dir_save_path)
- crop_image = image.crop([left, top, right, bottom])
- crop_image.save(os.path.join(dir_save_path, "crop_pic" + str(pic) + '_box' + str(i) + ".png"),
- quality=95, subsampling=0)
- print("save crop_pic" + str(pic) + "_box" + str(i) + ".png to " + dir_save_path)
-
-
- #---------------------------------------------------------#
- # 图像绘制
- #---------------------------------------------------------#
- for i, c in list(enumerate(top_label)):
- predicted_class = self.class_names[int(c)]
- box = top_boxes[i]
- score = top_conf[i]
-
- top, left, bottom, right = box
-
- top = max(0, np.floor(top).astype('int32'))
- left = max(0, np.floor(left).astype('int32'))
- bottom = min(image.size[1], np.floor(bottom).astype('int32'))
- right = min(image.size[0], np.floor(right).astype('int32'))
-
- label = '{} {:.2f}'.format(predicted_class, score)
- draw = ImageDraw.Draw(image)
- label_size = draw.textsize(label, font)
- label = label.encode('utf-8')
- print(label, top, left, bottom, right)
-
- if top - label_size[1] >= 0:
- text_origin = np.array([left, top - label_size[1]])
- else:
- text_origin = np.array([left, top + 1])
-
- for i in range(thickness):
- draw.rectangle([left + i, top + i, right - i, bottom - i], outline=self.colors[c])
- draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=self.colors[c])
- draw.text(text_origin, str(label,'UTF-8'), fill=(0, 0, 0), font=font)
- del draw
-
- return image, crop_image
这里的:detect_image(self, image, pic, crop = False, count = False)
我加入了pic变量,将来裁剪出来的图片不会覆盖,并且return了image和crop_image两个结果
然后让我们返回到上一段预测函数的代码,可以看到将crop_image送入了resnet中计算,返回了一个class_name(最终的标签)。
然后是预测侍女头部的预测函数代码如下:
- '''
- 注意,这个脚本是我自己写的代码,主要是调用了yolo——onnx模型去做前向推理,将其写成了一个predict函数,将输出修改了一下,改成了cropimages
- 以便于后面进行resnet152的图片分类。
- '''
-
- import time
- import cv2
- import numpy as np
- from PIL import Image
- from .yolov7_flymaotai_head_CA.yolo import YOLO, YOLO_ONNX
- from .resnet152_for_head.classification import Classification
- def predict_maotaihead(image):
- mode = "predict_onnx"
- crop = True
- count = False
- pic = 0
- dir_origin_path = "img/"
- dir_save_path = "img_out/"
- yolo = YOLO_ONNX()
- classfication = Classification()
- try:
- image = Image.open(image)
- except:
- print('Open Error! Try again!')
- else:
- image, crop_image = yolo.detect_image(image, crop=crop, count=count, pic=pic)
-
- class_name = classfication.detect_image(crop_image)
- print(class_name)
-
- return class_name
海洋标志的预测函数代码如下:
- from PIL import Image
- from .yolov7_Ocean_CA.yolo import YOLO, YOLO_ONNX
- from .resnet152_for_Ocean.classification import Classification
- def predict_Ocean(image):
- mode = "predict_onnx"
- crop = True
- count = False
- pic = 0
- dir_origin_path = "img/"
- dir_save_path = "img_out/"
- yolo = YOLO_ONNX()
- classfication = Classification()
- try:
- image = Image.open(image)
- except:
- print('Open Error! Try again!')
- else:
- image, crop_image = yolo.detect_image(image, crop=crop, count=count, pic=pic)
-
- class_name = classfication.detect_image(crop_image)
- print(class_name)
-
- return class_name
然后这里我建议大家在部署的时候,如果你使用了加入_int_.py的方法想去使用相对路径调用文件,那么在模型里面要修改为绝对路径,防止到时候因为路径找不到而引发报错(作者因为这件事情焦头烂额,忙碌了两天)。
当服务器成功启动完毕之后,我们就可以试着给服务器发送post请求了,首先下载postman这个软件:
下载好之后在这个页面:
输入我们main.py里面所创建的地址以及端口号,这里的地址由于我是在远程linux服务器上开启的服务,所以我的IP就是服务器的IP,不是你在linux中ifconfig查到的地址!不是你在linux中ifconfig查到的地址!不是你在linux中ifconfig查到的地址!重要的话说三遍!
给这个网址发送一个post请求,我这里是将本地的一张图片发送到服务器了,它接收到之后我们可以去查看linux服务器上已经开始计算了:
可以看到最终通过两个模型配合计算,返回了一个标签:normal(真),然后我们查看postman上的结果:
这里我们发送图片post请求后,得到的结果是“normal”,到此一个完整的flask部署在服务器上的过程就结束了。
网上的教程五花八门,又讲不清重点,(落泪.jpg),希望很多刚入门的同学看了我这篇文章可以给予你们帮助,感谢阅读!给个赞吧!(感谢.jpg)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。