当前位置:   article > 正文

将YOLOv5变成接口使用_yolov5改写api

yolov5改写api

将YOLOv5改造成接口使用

本文仅当做给自己的备忘录,照抄博主@uestcwhc
原文地址:
https://blog.csdn.net/HC_wood/article/details/119736231

实现了:
t=detectapi(weights)
results,names=t.detect(source)
其中参数 weights是权重文件的路径。参数source是一个列表,列表的每个元素是由cv2的读取的图片。返回值results是一个列表。列表的元素个数为source的元素个数,每个元素为每张图片的处理结果。每张图片的处理结果有两个,一个是一张在原图片中画框标识物品的cv2图片。另一个是一个列表,这个列表的元素个数等于本图片探测到的物品数量。元素为这个物品的信息:(物品在names中的引索,[物品的位置x1,y1,x2,y2],置信度)。返回值names为物品字典。
应用如下:打开摄像头,实时探测目标物品

import cv2
import detect
cap=cv2.VideoCapture(0)
a=detect.detectapi(weights='weights/yolov5s.pt')
while True:

    rec,img = cap.read()

    result,names =a.detect([img])
    img=result[0][0] #第一张图片的处理结果图片
    '''
    for cls,(x1,y1,x2,y2),conf in result[0][1]: #第一张图片的处理结果标签。
        print(cls,x1,y1,x2,y2,conf)
        cv2.rectangle(img,(x1,y1),(x2,y2),(0,255,0))
        cv2.putText(img,names[cls],(x1,y1-20),cv2.FONT_HERSHEY_DUPLEX,1.5,(255,0,0))
    '''
    cv2.imshow("vedio",img)

    if cv2.waitKey(1)==ord('q'):
        break
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

下面将detect.py做如下新增,原来的代码不要删

# 增加运行参数,原来的参数是通过命令行解析对象提供的,这里改为由调用者在代码中提供。需要一个
# 大体上完成一样功能的参数对象。
# 我想要的功能是传一组由cv2读取的图片,交给api,然后得到一组打上标签的图片,以及每张图片对应的标签类别引索,位置信息,置信度的信息,还有类别名称字典
# 要实现这个功能,需要权重文件,输入文件两个参数,其他参数与原代码命令行默认参数保持一致就行。
class simulation_opt:# 参数对象。

    def __init__(self,weights,img_size=640,conf_thres=0.25,iou_thres=0.45,device='',view_img=False,
                 classes=None,agnostic_nms=False,augment=False,update=False,exist_ok=False):
        self.weights=weights
        self.source=None
        self.img_size=img_size
        self.conf_thres=conf_thres
        self.iou_thres=iou_thres
        self.device=device
        self.view_img=view_img
        self.classes=classes
        self.agnostic_nms=agnostic_nms
        self.augment=augment
        self.update=update
        self.exist_ok=exist_ok

#增加一个新类,这个新类是在原来detect函数上进行删减。可以先复制原来的detect函数代码,再着手修改
class detectapi:
    def __init__(self,weights,img_size=640):
        # 构造函数中先做好必要的准备,如初始化参数,加载模型
        ''' 删掉
        source, weights, view_img, save_txt, imgsz = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size
        webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
        ('rtsp://', 'rtmp://', 'http://'))
        ''' #改为
        self.opt=simulation_opt(weights=weights,img_size=img_size)
        weights, imgsz= self.opt.weights, self.opt.img_size

        ''' 删掉
        # Directories
        #save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))  # increment run
        #(save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir
        '''
    # Initialize
        set_logging()
        self.device = select_device(self.opt.device)
        self.half = self.device.type != 'cpu'  # half precision only supported on CUDA

    # Load model
        self.model = attempt_load(weights, map_location=self.device)  # load FP32 model
        self.stride = int(self.model.stride.max())  # model stride
        self.imgsz = check_img_size(imgsz, s=self.stride)  # check img_size
        if self.half:
            self.model.half()  # to FP16

    # Second-stage classifier
        self.classify = False
        if self.classify:
            self.modelc = load_classifier(name='resnet101', n=2)  # initialize
            self.modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=self.device)['model']).to(self.device).eval()
        '''
        self.names,和self.colors是由后面的代码拉到这里来的。names是类别名称字典,colors是画框时用到的颜色。
        '''
    # read names and colors
        self.names = self.model.module.names if hasattr(self.model, 'module') else self.model.names
        self.colors = [[random.randint(0, 255) for _ in range(3)] for _ in self.names]


    def detect(self,source): # 使用时,调用这个函数
        if type(source)!=list:
                raise TypeError('source must be a list which contain  pictures read by cv2')


        '''删掉
        if webcam:
        view_img = check_imshow()
        cudnn.benchmark = True  # set True to speed up constant image size inference
        dataset = LoadStreams(source, img_size=imgsz, stride=stride)
        else:
        save_img = True
        dataset = LoadImages(source, img_size=imgsz, stride=stride)

        
        '''# 改为
        # Set Dataloader
        dataset = MyLoadImages(source, img_size=self.imgsz, stride=self.stride)
        # 原来是通过路径加载数据集的,现在source里面就是加载好的图片,所以数据集对象的实现要
        # 重写。修改代码后附。在utils.dataset.py上修改。

        '''移动到构造方法末尾。names是类别名称字典,colors是画框时用到的颜色。
        names = model.module.names if hasattr(model, 'module') else model.names
        colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
        '''
    # Run inference
        if self.device.type != 'cpu':
            self.model(torch.zeros(1, 3, self.imgsz, self.imgsz).to(self.device).type_as(next(self.model.parameters())))  # run once
        result=[]
        ''' 删掉
        for path, img, im0s, vid_cap in dataset: 因为不用保存,所以path可以不要,因为不处理视频,所以vid_cap不要。
        ''' #改为
        for img, im0s in dataset:
            img = torch.from_numpy(img).to(self.device)
            img = img.half() if self.half else img.float()  # uint8 to fp16/32
            img /= 255.0  # 0 - 255 to 0.0 - 1.0
            if img.ndimension() == 3:
                img = img.unsqueeze(0)

            # Inference
            # t1 = time_synchronized() #计算预测用时的,可以不要
            pred = self.model(img, augment=self.opt.augment)[0]

            # Apply NMS
            pred = non_max_suppression(pred, self.opt.conf_thres, self.opt.iou_thres, classes=self.opt.classes, agnostic=self.opt.agnostic_nms)
            # t2 = time_synchronized() #计算预测用时的,可以不要

            # Apply Classifier
            if self.classify:
                pred = apply_classifier(pred, self.modelc, img, im0s)


            '''删掉
             for i, det in enumerate(pred):  # detections per image
            if webcam:  # batch_size >= 1
                p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
            else:
                p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)

            p = Path(p)  # to Path
            save_path = str(save_dir / p.name)  # img.jpg
            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # img.txt
            s += '%gx%g ' % img.shape[2:]  # print string
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            if len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string

                # Write results
                for *xyxy, conf, cls in reversed(det):
                    if save_txt:  # Write to file
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                        line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh)  # label format
                        with open(txt_path + '.txt', 'a') as f:
                            f.write(('%g ' * len(line)).rstrip() % line + '\n')

                    if save_img or view_img:  # Add bbox to image
                        label = f'{names[int(cls)]} {conf:.2f}'
                        plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
            ''' # 改为
            # Process detections
            det=pred[0] #原来的情况是要保持图片,因此多了很多关于保持路径上的处理。另外,pred
            # 其实是个列表。元素个数为batch_size。由于对于我这个api,每次只处理一个图片,
            # 所以pred中只有一个元素,直接取出来就行,不用for循环。
            im0 = im0s.copy() # 这是原图片,与被传进来的图片是同地址的,需要copy一个副本,否则,原来的图片会受到影响
            # s += '%gx%g ' % img.shape[2:]  # print string
            # gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            result_txt = []
            # 对于一张图片,可能有多个可被检测的目标。所以结果标签也可能有多个。
            # 每被检测出一个物体,result_txt的长度就加一。result_txt中的每个元素是个列表,记录着
            # 被检测物的类别引索,在图片上的位置,以及置信度
            if len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                # Print results
                '''
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, "  # add to string
                '''
                # Write results

                for *xyxy, conf, cls in reversed(det):

                    #xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                    line = (int(cls.item()), [int(_.item()) for _ in xyxy], conf.item())  # label format
                    result_txt.append(line)
                    label = f'{self.names[int(cls)]} {conf:.2f}'
                    plot_one_box(xyxy, im0, label=label, color=self.colors[int(cls)], line_thickness=3)
            result.append((im0,result_txt)) # 对于每张图片,返回画完框的图片,以及该图片的标签列表。
        return result, self.names

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181

下面对 yolov5/utils/dataset.py 修改,直接把下面代码增加到dataset.py即可,其他代码不用动。

class MyLoadImages:  # for inference
    def __init__(self, path, img_size=640, stride=32):
        for img in path:
            if type(img)!=np.ndarray or len(img.shape)!=3:
                raise TypeError('there is a object which is not a picture read by cv2 in source')

        '''
        p = str(Path(path).absolute())  # os-agnostic absolute path
        if '*' in p:
            files = sorted(glob.glob(p, recursive=True))  # glob
        elif os.path.isdir(p):
            files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
        elif os.path.isfile(p):
            files = [p]  # files
        else:
            raise Exception(f'ERROR: {p} does not exist')

        images = [x for x in files if x.split('.')[-1].lower() in img_formats]
        videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]
        ni, nv = len(images), len(videos)
        '''
        self.img_size = img_size
        self.stride = stride
        self.files = path
        self.nf = len(path)
        #self.video_flag = [False] * ni + [True] * nv
        self.mode = 'image'
        #if any(videos):
            #self.new_video(videos[0])  # new video
        #else:
            #self.cap = None
        #assert self.nf > 0, f'No images or videos found in {p}. ' \
                            #f'Supported formats are:\nimages: {img_formats}\nvideos: {vid_formats}'

    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]
        '''
        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            if not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()

            self.frame += 1
            print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.nframes}) {path}: ', end='')
        '''
        # Read image
        self.count += 1
        #img0 = cv2.imread(path)  # BGR
        #assert img0 is not None, 'Image Not Found ' + path
        #print(f'image {self.count}/{self.nf} {path}: ', end='')

        # Padded resize
        img = letterbox(path, self.img_size, stride=self.stride)[0]

        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)

        return img, path
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74

效果如下:
@uestcwhc 博主的图片
@uestcwhc博主的图片

在实际的使用中,该博主并没有用摄像头,而是通过rtmp拉取视频流的方式读取的数据,会有10秒推理延迟,为解决这个问题,新建了缓存队列,将读取到的数据存入队列中,保证入队和出队速度相同(相当于一个滑动窗口),yolov5每隔0.01秒在队头取一帧用来推理。从而防止出现动态延迟。
代码如下:
新建demo.py,把刚刚加了代码的detect.py导入进来


import cv2
import multiprocessing as mp
import detect
import time

def image_put(q, ip, port, name):
    cap = cv2.VideoCapture("rtmp://localhost:1935/live/movie")
    if cap.isOpened():
        print(name)

    while True:
        q.put(cap.read()[1])
        q.get() if q.qsize() > 1 else time.sleep(0.01)
        #print("555" * 25) if cap.read()[0] == False else print(" ")

def get_frames():

    camera_ip, camera_port, camera_name = "192.168.2.119", "554", "stream0"

    mp.set_start_method(method='spawn')  # init
    queue = mp.Queue(maxsize=2)
    processes = mp.Process(target=image_put, args=(queue, camera_ip, camera_port, camera_name)),
    [process.start() for process in processes]
    while True:
        yield queue.get()
        
        
def main():
    a=detect.detectapi(weights='runs/train/exp24/weights/best.pt')
    frames=get_frames()
    for frame in frames:
        result,names =a.detect([frame])
        img=result[0][0] #第一张图片的处理结果图片
        '''
        for cls,(x1,y1,x2,y2),conf in result[0][1]: #第一张图片的处理结果标签。
            print(cls,x1,y1,x2,y2,conf)
            cv2.rectangle(img,(x1,y1),(x2,y2),(0,255,0))
            cv2.putText(img,names[cls],(x1,y1-20),cv2.FONT_HERSHEY_DUPLEX,1.5,(255,0,0))
        '''
        cv2.namedWindow("video",cv2.WINDOW_NORMAL)
        cv2.imshow("video",img)
        cv2.waitKey(1)
if __name__ == '__main__':
      main()

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46

具体实现:(图片还是原博主的)
在这里插入图片描述

参考了yolov5 rtmp实时推理
一、下个OBS,捕获特定桌面窗口,功耗极低.

二、安装golang,只要命令行敲个命令运行go就行了。装完golang之后,git拉取livego,这东西是本地服务器,可以用OBS推流到服务器上,再从OBS上拉取rtmp视频流。
git地址:https://github.com/gwuhaolin/livego.git

livego使用步骤:
	1、转到 livego 目录并执行go build或make build
	2、双击exe文件运行livego
	3、获取串流密钥 http://localhost:8090/control/get?room=movie
	4、推流地址 rtmp://localhost:1935/live
	5、拉取播放地址 rtmp://localhost:1935/live/movie


  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

三、验证一下是否获取到窗口rtmp视频流,OBS自定义推流到livego的推流地址rtmp://localhost:1935/live,随便用个播放器找到网络播放输入livego的播放地址rtmp://localhost:1935/live/movie,就能看到你的窗口rtmp视频流了。

四、yolov5推理指令–source后输入livego的播放地址rtmp://localhost:1935/live/movie,后面在跟一个–view-img,就能实时推理某一特定窗口了。(虽然最终没有用官方的detect了。)

声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号