赞
踩
TensorFlow+Django实现在线目标检测系统
第一次写博文,觉得不好的地方大家多担待,其实一开始我也没想要做这个项目的demo,开始我只是做了基于官网提供的模型的tensorflow的目标识别demo,自己在本机把代码梳理实现了对输入图像的目标检测(窃喜,自我感觉良好),然后我就想把这个demo给小伙伴们看看效果,再到后来就开始我项目过程了。
直接给代码,导入项目需要的包
# -*- coding: utf-8 -*- """ Created on Tue Feb 25 19:09:48 2020 @author: sbtithzy """ import numpy as np import os import six.moves.urllib as urllib import sys import tarfile import tensorflow as tf from collections import defaultdict from io import StringIO from matplotlib import pyplot as plt from PIL import Image from media.research.object_detection.utils import label_map_util from media.research.object_detection.utils import visualization_utils as vis_util
这里基本上是和官方demo差不多的,只是最后部分的utils包是本地做项目demo的相对路径而已。
下面部分是几个关键的函数
def Model():
MODEL_NAME = 'media/research/object_detection/ssd_mobilenet_v1_coco_11_06_2017'
MODEL_FILE = MODEL_NAME + '.tar.gz'
# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
PATH_TO_LABELS = os.path.join('media/research/object_detection/data', 'mscoco_label_map.pbtxt')
return MODEL_FILE,PATH_TO_LABELS,PATH_TO_CKPT
这里是定义model用的,可以从官方网站下载,也可以自己通过coco数据集(百度网盘,提取码b3sf)训练然后保存下来。代码里的路径都是相对于我自己项目demo的路径,请大家根据自己的情况修改,这里返回的几个参数我说明一下model_file就是相对的model文件路径,PATH_TO_LABELS就是对应训练图片目标的标签,PATH_TO_CKPT这个就是demo中用到的模型了。
def load_image_into_numpy_array(image):
(im_width, im_height) = image.size
return np.array(image.getdata()).reshape(
(im_height, im_width, 3)).astype(np.uint8)
这部分就是转换图片。
def Picture_process(model_file,path_to_ckpt,test_image_path,category_index): tar_file = tarfile.open(model_file) for file in tar_file.getmembers(): file_name = os.path.basename(file.name) if 'frozen_inference_graph.pb' in file_name: tar_file.extract(file, os.getcwd()) detection_graph = tf.Graph() with detection_graph.as_default(): od_graph_def = tf.GraphDef() with tf.gfile.GFile(path_to_ckpt, 'rb') as fid: serialized_graph = fid.read() od_graph_def.ParseFromString(serialized_graph) tf.import_graph_def(od_graph_def, name='') with detection_graph.as_default(): with tf.Session(graph=detection_graph) as sess: image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0') detection_scores = detection_graph.get_tensor_by_name('detection_scores:0') detection_classes = detection_graph.get_tensor_by_name('detection_classes:0') num_detections = detection_graph.get_tensor_by_name('num_detections:0') for image_path in test_image_path: image = Image.open(image_path) # the array based representation of the image will be used later in order to prepare the # result image with boxes and labels on it. image_np = load_image_into_numpy_array(image) # Expand dimensions since the model expects images to have shape: [1, None, None, 3] image_np_expanded = np.expand_dims(image_np, axis=0) # Actual detection. (boxes, scores, classes, num) = sess.run([detection_boxes, detection_scores, detection_classes, num_detections],feed_dict={image_tensor: image_np_expanded}) # Visualization of the results of a detection. vis_util.visualize_boxes_and_labels_on_image_array(image_np,np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores),category_index, use_normalized_coordinates=True,line_thickness=4) #plt.figure(figsize=image_size) #plt.imshow(image_np) return image_np,classes,category_index
这部分的代码有点长,分解一下,传入的参数(根据官网demo改造的),主要就是前面model函数返回的几个必要的参数,category_index可以理解为一个嵌套的字典,返回的最后识别的目标分类信息。tarfile打开model_file这里是一个压缩包,然后遍历文件下的文件,找到并保留frozen_inference_graph.pb这玩意。下面部分就是加载model到内存中,下面的代码我就不过多解读了百度也能找到很多解读博客(害怕误导大家)。
def main(image_name):
TEST_IMAGE_PATHS = [ os.path.join('media/research/object_detection/test_images',image_name)]
MODEL_FILE,PATH_TO_LABELS,PATH_TO_CKPT = Model()
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
image_np,label,category_index = Picture_process(MODEL_FILE,PATH_TO_CKPT,TEST_IMAGE_PATHS,category_index)
image_np = Image.fromarray(np.uint8(image_np))
ext = os.path.splitext(image_name)[1]
name_result = os.path.splitext(image_name)[0]
image_np.save('media/research/object_detection/test_images/'+name_result+'_result'+ext)
这里定义一个main函数,是为了把整个识别的流程串起来。我要解释一下是首先main函数我传了一个参数,大家肯定觉得很奇怪,这里传的是image_name,是为了我在前端通过django实现图像上传,然后调用识别目标主函数,所以为了方便只是传了一个图像名称,路径直接写死了,当然这里大家也可以根据自己的实际情况来改,在保存的时候加了一个result以便于区分返回的结果图像。
第一部分的代码就讲完了,主要是在官网demo的基础上做了集成和修改。
其实我也是一个django的新手[捂脸],这里需要感谢这位同学,如果对于django不是很熟悉的可以参考一下。下面直接上代码,大家需要关注的几个点要明确一下,首先项目需要在前端实现图像的上传,就涉及到图片存储的问题,然后要调目标识别的函数,最后再生成一个结果图片。项目中我做了2个页面,一个是上传页面,还有一个展示页面。
这里我主要说一下我踩坑的地方view.py文件
# -*- coding: utf-8 -*- from django.shortcuts import render,render_to_response from django import forms from django.http import HttpResponse from .models import user,IMG from django.views.decorators.csrf import csrf_exempt from django.http import HttpResponseRedirect from media.research.object_detection import object_detection_spyder_test as ods # Create your views here. class UserForm(forms.Form): username = forms.CharField(required=False) headImg = forms.ImageField(required=False) @csrf_exempt def index(request): if request.method == "POST": uf = UserForm(request.POST,request.FILES) if uf.is_valid(): #print(request.FILES) #获取表单信息request.FILES是个字典 User = user(headImg=request.FILES['file']) #保存在服务器 User.save() global name name = str(User.headImg).split('/')[-1] ods.main(name) return HttpResponse('识别成功,请查看识别结果!') return render(request, 'blog/index.html') def show_picture(request): import os filename='../media/research/object_detection/test_images/' ext = os.path.splitext(name)[1] name_result = os.path.splitext(name)[0] context={'name':filename+name_result+'_result'+ext} return render(request,'blog/Welcome.html',context)
先说一下index函数,有几个点,global name作用是为了把上传的图像名称传给目标识别程序ods.main(name),需要把前面保存的object_detection_spyder_test.py文件导进来。还有一个作用是为了实现在show_picture函数中传递图像名称,此处你可以品,细品。说道传递图像名称就不得不提到models.py文件了
from django.db import models
from django.contrib.auth.models import User
# # Create your models here.
from system.storage import ImageStorage
class user(models.Model):
headImg = models.ImageField(upload_to='research/object_detection/test_images/',storage=ImageStorage())
username = models.CharField(max_length=100)
def __str__(self):
return self.headImg
因为要修改上传图像名称,这里我导入from system.storage import ImageStorage参考了网上一些大佬的做法,直接上代码
# -*- coding: UTF-8 -*- from django.core.files.storage import FileSystemStorage from django.http import HttpResponse class ImageStorage(FileSystemStorage): from django.conf import settings def __init__(self, location=settings.MEDIA_ROOT, base_url=settings.MEDIA_URL): # 初始化 super(ImageStorage, self).__init__(location, base_url) # 重写 _save方法 def _save(self, name, content): #name为上传文件名称 import os, time, random # 文件扩展名 ext = os.path.splitext(name)[1] # 文件目录 d = os.path.dirname(name) # 定义文件名,年月日时分秒随机数 fn = time.strftime('%Y%m%d%H%M%S') fn = fn + '_%d' % random.randint(0,100) # 重写合成文件名 name = os.path.join(d, fn + ext) # 只保留一张图片 if self.exists(name): self.delete(name) # 调用父类方法 return super(ImageStorage, self)._save(name, content)
我做了一些修改,判读文件是否重名,大家肯定觉得很奇怪,你的文件名称生成都是有随机数在里面,怎么可能重名!我也觉得不会重名!其实可以注释掉,原因是我最开始的时候想法是服务器只保留一张图片,好吧,怪我当初太幼稚。
继续回到刚才models.py文件,为了把文件名称返回给main(),所以在定义models的时候需要返回一个参数headImg,这个参数可以获取到文件名信息。show_picture这个函数为了返回图像保存信息给Welcome.html,没什么需要过多解释的。
下面就是把对应的函数设置路由信息了urls.py
# -*- coding:utf-8 -*- from django.urls import path from django.views.static import serve from . import views from django.contrib import admin from django.conf.urls import url from django.conf.urls.static import static from django.conf import settings # 将url传入view.index模块中, index类别名name #urlpatterns = [path(r'', views.index, name='index'), ] from django.views.generic.base import RedirectView urlpatterns = [path(r'', views.index, name='index'), path(r'mysite/blog/templates/blog/', views.show_picture, name='show_picture'),url(r'^favicon.ico$',RedirectView.as_view(url=r'/static/favicon.ico',permanent=True)), ]+static(settings.MEDIA_URL,document_root=settings.MEDIA_ROOT)
这部分的代码也没有过多需要解读的,说一下name的作用,是为了在html代码中关联对应的功能。
剩下就是最后一部分了html,一部分是展示图片部分
<!DOCTYPE HTML>
{% load staticfiles %}
<!-- index.html -->
<html>
<head><head>
<body>
<h1>The Result Image</h1>
<img src="{% static name%}">
</body>
</html>
src后面接着模糊查询name的信息。
还有就是图像上传的代码
<div class="header clearfix">
<div class="left">
<a target="_blank" href="{% url 'show_picture' %}">查看识别结果</a>
</div>
</div>
当时这个坑我是踩了好几天,说明一下href说白了就是去找你要的链接,url就是我们关联的路由,然后关联上show_picture这个函数,简单吧,简单吧,原谅我是个新手。其他的代码可以参考我上面说到的那位同学,再次感谢这位同学。
项目里面涉及的踩坑点也基本上说的差不多,下面附上对应的Github地址
再说一下项目思路:前端页面提供图片上传,上传以后调用目标识别算法,算法给出预测输入,最后将结果呈现在前端页面上 在完成的过程中踩过很多坑,详细可以加我QQ咨询:965865965,觉得该项目对你有用的可以在github上点个star
下面是效果图,整个识别的过程大概耗时13s,主要时间花在了识别的过程中,可能自己电脑的性能有关系。整个项目可以提高的地方也有很多,比如说django部分,目标识别效率,识别精准度方面等等,欢迎大家留言交流。
PS:用我的AI大师码 0660 在滴滴云上购买GPU/vGPU/机器学习产品可享受9折优惠。点击前往滴滴云官网
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。