当前位置:   article > 正文

yolov5使用flask部署至前端,实现照片\视频识别_yolov5与flask集成

yolov5与flask集成
大半年前初学yolo flask时,急需此功能,Csdn、Github、B站找到很多教程,效果并不是很满意。
近期做项目碰到类似需求,再度尝试,实现简单功能,分享下相关代码,仅学习使用,如有纰漏,望多包涵。

实现功能:

  1. 可更换权重文件(best.py)

  2. 上传图片并识别,可以点击图片放大查看

  3. 上传视频并识别

  4. 识别后的文件下载功能

效果图如上

文件结构如下:

project/
  static/

  空

  templates/
    index.html
    
 app.py
 

相关代码:

app.py

  1. import cv2
  2. import numpy as np
  3. import torch
  4. from flask import Flask, request, jsonify, render_template
  5. import base64
  6. import os
  7. from datetime import datetime
  8. from werkzeug.utils import secure_filename
  9. app = Flask(__name__)
  10. # 全局变量:模型和模型权重路径
  11. model = None
  12. # 提前加载模型
  13. # 提前加载模型
  14. def load_model():
  15. global model
  16. global new_filename
  17. # 拼接权重文件的完整路径
  18. model = torch.hub.load("E:\\pythonProject2\\flaskProject\\yolov5-master", "custom", path='weight/'+new_filename, source="local")
  19. # 路由处理图片检测请求
  20. @app.route("/predict_image", methods=["POST"])
  21. def predict_image():
  22. global model
  23. # 获取图像文件
  24. file = request.files["image"]
  25. # 读取图像数据并转换为RGB格式
  26. image_data = file.read()
  27. nparr = np.frombuffer(image_data, np.uint8)
  28. image = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED)
  29. results = model(image)
  30. image = results.render()[0]
  31. # 将图像转换为 base64 编码的字符串
  32. _, buffer = cv2.imencode(".png", image)
  33. image_str = base64.b64encode(buffer).decode("utf-8")
  34. # 获取当前时间,并将其格式化为字符串
  35. current_time = datetime.now().strftime("%Y%m%d%H%M%S")
  36. # 构建保存路径
  37. save_dir = "static"
  38. filename, extension = os.path.splitext(file.filename) # 获取上传文件的文件名和扩展名
  39. save_filename = f"{filename}_{current_time}{extension}"
  40. save_path = os.path.join(save_dir, save_filename)
  41. cv2.imwrite(save_path, image)
  42. return jsonify({"image": image_str})
  43. # 函数用于在视频帧上绘制检测结果
  44. def detect_objects(frame, model):
  45. results = model(frame)
  46. detections = results.pred[0] # 这里假设只有一张输入图片
  47. # 在帧上绘制检测结果
  48. for det in detections:
  49. # 获取边界框信息
  50. x1, y1, x2, y2, conf, class_id = det[:6]
  51. # 在帧上绘制边界框
  52. cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
  53. # 在帧上绘制类别和置信度
  54. label = f'{model.names[int(class_id)]} {conf:.2f}'
  55. cv2.putText(frame, label, (int(x1), int(y1) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
  56. print(f'Confidence: {conf:.2f}')
  57. return frame
  58. # 路由处理视频检测请求
  59. @app.route("/predict_video", methods=["POST"])
  60. def predict_video():
  61. global model
  62. # 从请求中获取视频文件
  63. video_file = request.files["video"]
  64. # 保存视频到临时文件
  65. temp_video_path = "temp_video.mp4"
  66. video_file.save(temp_video_path)
  67. # 逐帧读取视频
  68. video = cv2.VideoCapture(temp_video_path)
  69. # 获取视频的帧率和尺寸
  70. fps = video.get(cv2.CAP_PROP_FPS)
  71. width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
  72. height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
  73. # 视频写入对象
  74. output_video_filename = f"output_video_{datetime.now().strftime('%Y%m%d%H%M%S')}.mp4"
  75. output_video_path = os.path.join("static", output_video_filename)
  76. fourcc = cv2.VideoWriter_fourcc(*"avc1") # 使用 H.264 编码器
  77. out_video = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
  78. # 逐帧处理视频并进行目标检测
  79. while True:
  80. ret, frame = video.read()
  81. if not ret:
  82. break
  83. # 进行目标检测
  84. detection_result = detect_objects(frame, model)
  85. # 将处理后的帧写入输出视频
  86. out_video.write(detection_result)
  87. # 释放视频对象
  88. video.release()
  89. out_video.release()
  90. return jsonify({"output_video_path": output_video_filename})
  91. @app.route("/upload_weight", methods=["POST"])
  92. def upload_weight():
  93. global new_filename
  94. # 获取上传的权重文件
  95. weight_file = request.files["weight"]
  96. # 获取上传文件的原始文件名
  97. original_filename = secure_filename(weight_file.filename)
  98. # 提取文件名和扩展名
  99. filename, extension = os.path.splitext(original_filename)
  100. # 构造新的文件名,加上当前时间戳
  101. current_time = datetime.now().strftime("%Y%m%d%H%M%S")
  102. new_filename = f"best_{current_time}.pt"
  103. # 拼接权重文件的保存路径
  104. save_path = os.path.join("E:\\pythonProject2\\flaskProject\\weight\\", new_filename)
  105. # 保存权重文件
  106. weight_file.save(save_path)
  107. # 加载模型
  108. load_model()
  109. return jsonify({"message": "Weight file uploaded successfully and model loaded"})
  110. @app.route("/")
  111. def index():
  112. return render_template("index.html")
  113. if __name__ == "__main__":
  114. app.run(debug=True)

index.html

  1. <!DOCTYPE html>
  2. <html lang="en">
  3. <head>
  4. <meta charset="UTF-8">
  5. <meta name="viewport" content="width=device-width, initial-scale=1.0">
  6. <title>Object Detection</title>
  7. <style>
  8. body {
  9. font-family: Arial, sans-serif;
  10. margin: 0;
  11. padding: 0;
  12. background-color: #f3f3f3;
  13. display: flex;
  14. justify-content: center;
  15. align-items: center;
  16. height: 100vh;
  17. flex-direction: column;
  18. }
  19. #content {
  20. text-align: center;
  21. max-width: 820px;
  22. margin-top: 20px;
  23. }
  24. h1 {
  25. color: #333;
  26. }
  27. h2 {
  28. color: #666;
  29. }
  30. input[type="file"] {
  31. margin-bottom: 10px;
  32. }
  33. .media-container {
  34. display: flex;
  35. max-width: 100%;
  36. margin-bottom: 20px;
  37. }
  38. .media-container:first-child {
  39. margin-right: 20px; /* 在第一个容器的右侧添加间隔 */
  40. }
  41. .media-container img,
  42. .media-container video {
  43. max-width: 100%;
  44. height: auto;
  45. }
  46. .original {
  47. width: 400px;
  48. overflow: hidden;
  49. }
  50. .processed {
  51. flex: 2; /* 右边容器占据剩余空间 */
  52. }
  53. button {
  54. padding: 10px 20px;
  55. background-color: #007bff;
  56. color: #fff;
  57. border: none;
  58. border-radius: 5px;
  59. cursor: pointer;
  60. margin-bottom: 10px;
  61. }
  62. /* 新增样式:模态框 */
  63. .modal {
  64. display: none; /* 默认隐藏 */
  65. position: fixed;
  66. z-index: 1;
  67. left: 0;
  68. top: 0;
  69. width: 100%;
  70. height: 100%;
  71. overflow: auto;
  72. background-color: rgba(0, 0, 0, 0.9); /* 半透明黑色背景 */
  73. }
  74. .modal-content {
  75. margin: auto;
  76. display: block;
  77. width: 80%;
  78. max-width: 800px;
  79. position: absolute;
  80. left: 50%;
  81. top: 50%;
  82. transform: translate(-50%, -50%);
  83. text-align: center; /* 居中显示图片 */
  84. }
  85. .close {
  86. color: #ccc;
  87. font-size: 36px;
  88. font-weight: bold;
  89. cursor: pointer;
  90. position: absolute;
  91. top: 10px;
  92. right: 10px;
  93. }
  94. .close:hover,
  95. .close:focus {
  96. color: #fff;
  97. text-decoration: none;
  98. }
  99. #downloadButton {
  100. padding: 10px 20px;
  101. background-color: #007bff;
  102. color: #fff;
  103. border: none;
  104. border-radius: 5px;
  105. cursor: pointer;
  106. margin-bottom: 10px;
  107. }
  108. /* 新增样式:响应式图片 */
  109. .modal-content img,
  110. .modal-content video {
  111. max-width: 100%;
  112. height: auto;
  113. }
  114. </style>
  115. </head>
  116. <body>
  117. <h2>上传权重文件</h2>
  118. <!-- 新增按钮用于触发上传权重文件 -->
  119. <button onclick="document.getElementById('weightFile').click()">选择权重文件</button>
  120. <input type="file" id="weightFile" accept=".pt" onchange="displaySelectedWeightFile()" style="display: none;">
  121. <br>
  122. <!-- 新增模态框 -->
  123. <div id="myModal" class="modal" onclick="closeModal()">
  124. <div class="modal-content" id="modalContent" onclick="stopPropagation(event)">
  125. <!-- 放大后的图片或视频将在这里显示 -->
  126. <span class="close" onclick="closeModal()">&times;</span>
  127. </div>
  128. </div>
  129. <div id="content">
  130. <h1>照片/视频检测</h1>
  131. <!-- 上传图片 -->
  132. <h2>上传图片</h2>
  133. <input type="file" id="imageFile" accept="image/*" onchange="displaySelectedImage()">
  134. <button onclick="uploadImage()">上传</button>
  135. <button id="downloadImageButton" onclick="downloadProcessedImage()">下载</button>
  136. <br>
  137. <div class="media-container">
  138. <div class="original media-container" onclick="enlargeImage()">
  139. <img id="uploadedImage" src="#" alt="Uploaded Image" style="display:none;">
  140. <button id="zoomInButton" style="display:none;">Zoom In</button>
  141. </div>
  142. <div class="processed media-container" onclick="enlargeImage2()">
  143. <img id="processedImage" src="#" alt="Processed Image" style="display:none;">
  144. </div>
  145. </div>
  146. <br>
  147. <!-- 上传视频 -->
  148. <h2>上传视频</h2>
  149. <input type="file" id="videoFile" accept="video/mp4,video/x-m4v,video/*" onchange="displaySelectedVideo()">
  150. <button onclick="uploadVideo()">上传</button>
  151. <button id="downloadButton" onclick="downloadProcessedVideo()">下载</button>
  152. <br>
  153. <div class="media-container">
  154. <div class="original media-container" >
  155. <video id="uploadedVideo" src="#" controls style="display:none;"></video>
  156. </div>
  157. <div class="processed media-container">
  158. <video id="processedVideo" controls style="display:none;"></video>
  159. </div>
  160. </div>
  161. <br>
  162. </div>
  163. <script>
  164. // 显示选择的权重文件
  165. function displaySelectedWeightFile() {
  166. var fileInput = document.getElementById('weightFile');
  167. var file = fileInput.files[0];
  168. console.log('Selected weight file:', file);
  169. // 上传权重文件
  170. uploadWeight(file);
  171. }
  172. // 上传权重文件
  173. function uploadWeight(file) {
  174. var formData = new FormData();
  175. formData.append('weight', file);
  176. fetch('/upload_weight', {
  177. method: 'POST',
  178. body: formData
  179. })
  180. .then(response => response.json())
  181. .then(data => {
  182. console.log('Upload weight response:', data);
  183. // 可以根据后端返回的响应进行相应的处理
  184. })
  185. .catch(error => console.error('Error:', error));
  186. }
  187. // 显示选择的图片并添加点击放大功能
  188. function displaySelectedImage() {
  189. var fileInput = document.getElementById('imageFile');
  190. var file = fileInput.files[0];
  191. var imageElement = document.getElementById('uploadedImage');
  192. imageElement.src = URL.createObjectURL(file);
  193. imageElement.style.display = 'inline';
  194. document.getElementById('zoomInButton').style.display = 'inline';
  195. }
  196. // 显示模态框并放大图片
  197. function enlargeImage() {
  198. var modal = document.getElementById('myModal');
  199. var modalImg = document.getElementById('modalContent');
  200. var img = document.getElementById('uploadedImage');
  201. modal.style.display = 'block';
  202. modalImg.innerHTML = '<img src="' + img.src + '">';
  203. }
  204. // 显示模态框并放大图片
  205. function enlargeImage2() {
  206. var modal = document.getElementById('myModal');
  207. var modalImg = document.getElementById('modalContent');
  208. var img = document.getElementById('processedImage');
  209. modal.style.display = 'block';
  210. modalImg.innerHTML = '<img src="' + img.src + '">';
  211. }
  212. // 显示选择的视频并添加点击放大功能
  213. function displaySelectedVideo() {
  214. var fileInput = document.getElementById('videoFile');
  215. var file = fileInput.files[0];
  216. var videoElement = document.getElementById('uploadedVideo');
  217. videoElement.src = URL.createObjectURL(file);
  218. videoElement.style.display = 'block';
  219. }
  220. // 上传图片并向后端发送请求
  221. function uploadImage() {
  222. var fileInput = document.getElementById('imageFile');
  223. var file = fileInput.files[0];
  224. var formData = new FormData();
  225. formData.append('image', file);
  226. fetch('/predict_image', {
  227. method: 'POST',
  228. body: formData
  229. })
  230. .then(response => response.json())
  231. .then(data => {
  232. var imageElement = document.getElementById('processedImage');
  233. imageElement.src = 'data:image/png;base64,' + data.image;
  234. imageElement.style.display = 'inline';
  235. document.getElementById('downloadImageButton').style.display = 'inline';
  236. })
  237. .catch(error => console.error('Error:', error));
  238. }
  239. // 下载处理后的图片
  240. function downloadProcessedImage() {
  241. var imageElement = document.getElementById('processedImage');
  242. var url = imageElement.src;
  243. var a = document.createElement('a');
  244. a.href = url;
  245. a.download = 'processed_image.png';
  246. document.body.appendChild(a);
  247. a.click();
  248. document.body.removeChild(a);
  249. }
  250. // 上传视频并向后端发送请求
  251. function uploadVideo() {
  252. var fileInput = document.getElementById('videoFile');
  253. var file = fileInput.files[0];
  254. var formData = new FormData();
  255. formData.append('video', file);
  256. fetch('/predict_video', {
  257. method: 'POST',
  258. body: formData
  259. })
  260. .then(response => response.json())
  261. .then(data => {
  262. var videoElement = document.getElementById('processedVideo');
  263. // 修改路径为正确的 Flask url_for 生成的路径
  264. videoElement.src = '{{ url_for("static", filename="") }}' + data.output_video_path;
  265. videoElement.style.display = 'block';
  266. var downloadButton = document.getElementById('downloadButton');
  267. downloadButton.style.display = 'block';
  268. })
  269. .catch(error => console.error('Error:', error));
  270. }
  271. // 下载处理后的视频
  272. function downloadProcessedVideo() {
  273. var videoElement = document.getElementById('processedVideo');
  274. var url = videoElement.src;
  275. var a = document.createElement('a');
  276. a.href = url;
  277. a.download = 'processed_video.mp4';
  278. document.body.appendChild(a);
  279. a.click();
  280. document.body.removeChild(a);
  281. }
  282. // 关闭模态框
  283. function closeModal() {
  284. var modal = document.getElementById('myModal');
  285. modal.style.display = 'none';
  286. }
  287. </script>
  288. </body>
  289. </html>

使用说明:

将index.html放入templates文件夹中

运行app.py

model = torch.hub.load("E:\\pythonProject2\\flaskProject\\yolov5-master", "custom", path='weight/'+new_filename, source="local")

注意此处加载模型路径更改为自己的

如果模型读取不到,csdn上有相关解决方法

此处的best.py上传路径也要更改成自己的

save_path = os.path.join("E:\\pythonProject2\\flaskProject\\weight\\", new_filename)

使用说明:

先上传本地的模型,上传成功后等待模型加载,再上传照片/视频

如有问题,可联系作者,随时讨论。

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号