
Training YOLOv7 on Your Own Dataset


Contents

1. Creating a YOLO-format dataset

1.1 The dataset

1.2 How to convert to the format YOLOv7 expects

1.3 How to batch-generate YOLO-format txt annotations

1.4 How to split YOLO train, val, and test sets

2. Training your own model with YOLOv7

2.1 Testing the pretrained yolov7.pt

(1) Testing on images

(2) Testing a local webcam

(3) Testing video streams

2.2 Training a YOLOv7 model on your own data

2.3 Testing your trained model

2.4 Testing keypoint detection


YOLOv7 download: YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors


1. Creating a YOLO-format dataset

1.1 The dataset

This article uses the EDS dataset: 14,219 images from 3 different X-ray machines, covering 10 item categories and 31,655 object instances in total, all labeled by professional annotators.

Each machine corresponds to one subset (domain1, domain2, and domain3), and the class distribution within each subset is relatively balanced.

Code to display some of the images:

import matplotlib.pyplot as plt
import glob
import cv2

def show_multi_img(imgpath, num):
    """
    :param imgpath: image directory
    :param num: grid size, e.g. num=6 shows a 6x6 grid of 36 images
    :return:
    """
    img_path = glob.glob(imgpath + "/*")
    plt.figure()
    for i in range(1, num * num + 1):
        # note: cv2 reads BGR, so colors look swapped in matplotlib unless converted
        img = cv2.imread(img_path[i])
        title = img_path[i].split("\\")[1]
        plt.subplot(num, num, i)
        plt.imshow(img)
        plt.title(title, fontsize=6)
        plt.xticks([])
        plt.yticks([])
        plt.axis("on")
    plt.savefig("final.png")
    plt.show()

if __name__ == "__main__":
    image_dir = "./domain2/image"
    show_multi_img(image_dir, 6)

Each domain consists of an image folder and a txt folder.

1.2 How to convert to the format YOLOv7 expects

First, look at the YOLO annotation format:

The EDS dataset format:

Assume the image width and height are img_width and img_height, the top-left corner of a bbox is (xmin, ymin), and the bottom-right corner is (xmax, ymax). The center point (x_center, y_center) and box size (W, H) are then:

x_center = xmin + (xmax - xmin)/2

y_center = ymin + (ymax - ymin)/2

W = xmax - xmin

H = ymax - ymin

A YOLO label line has the form label, x_, y_, w_, h_, with the correspondence:

x_ = x_center / img_width

y_ = y_center / img_height

w_ = W / img_width

h_ = H / img_height

Here label is an integer, so the EDS class names must first be mapped to numeric ids. img_width and img_height are the original width and height of the image; read the image with cv2.imread() and take them from its shape:

img = cv2.imread("./domain/image/00001.jpg")
img_height, img_width, _ = img.shape
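
The four formulas above map directly onto a small helper function. Below is a minimal sketch (voc_to_yolo is a name chosen here for illustration, not part of the dataset tooling):

def voc_to_yolo(xmin, ymin, xmax, ymax, img_width, img_height):
    # center of the box, normalized by the image size
    x_center = (xmin + (xmax - xmin) / 2) / img_width
    y_center = (ymin + (ymax - ymin) / 2) / img_height
    # box width/height, normalized by the image size
    w = (xmax - xmin) / img_width
    h = (ymax - ymin) / img_height
    return x_center, y_center, w, h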

Display one image and draw its bboxes on the original:

import cv2

f = open("./domain1/txt/00004.txt", encoding="utf-8")
img = cv2.imread('./domain1/image/00004.jpg')
img_height, img_width, _ = img.shape
for line in f.readlines():
    text = str(line.split(" ")[1])
    xmin = float(line.split(" ")[2])
    ymin = float(line.split(" ")[3])
    xmax = float(line.split(" ")[4])
    ymax = float(line.split(" ")[5])
    print("xmin:{},xmax:{},ymin:{},ymax:{}".format(xmin, xmax, ymin, ymax))
    x_center = xmin + (xmax - xmin) / 2
    y_center = ymin + (ymax - ymin) / 2
    w = xmax - xmin
    h = ymax - ymin
    # keep 6 decimal places
    x_center = round(x_center / img_width, 6)
    y_center = round(y_center / img_height, 6)
    w = round(w / img_width, 6)
    h = round(h / img_height, 6)
    # print(x_center, y_center, w, h)
    # convert the YOLO format back to corner coordinates to verify
    x1 = int((float(x_center) - float(w) / 2) * img_width)
    y1 = int((float(y_center) - float(h) / 2) * img_height)
    x2 = int((float(x_center) + float(w) / 2) * img_width)
    y2 = int((float(y_center) + float(h) / 2) * img_height)
    print(x1, y1, x2, y2)
    cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 255), 3)
    cv2.putText(img, text, (int(xmin), int(ymin) - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)
cv2.imshow("show", img)
cv2.waitKey(0)
cv2.imwrite("bbox.png", img)

Before: xmin:84.0, ymin:369.0, xmax:342.0, ymax:554.0
After:  xmin:83, ymin:368, xmax:341, ymax:553
Before: xmin:210.0, ymin:409.0, xmax:591.0, ymax:691.0
After:  xmin:210, ymin:409, xmax:591, ymax:691
Before: xmin:182.0, ymin:457.0, xmax:364.0, ymax:550.0
After:  xmin:181, ymin:456, xmax:364, ymax:549

-------------------------------------------------------------------------------------

There is still some round-trip error here, coming from rounding the normalized values to six decimal places and truncating back to int, but the impact is minor.

Note: if you have no annotation data to work with, you can download LabelImg and annotate images in YOLO mode; it writes YOLO-format txt files directly.
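
For reference, LabelImg is on PyPI, so installing and launching it is usually just the following (assuming a working pip environment; remember to switch the save format to YOLO in its left toolbar before annotating):

pip install labelImg
labelImg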

1.3 How to batch-generate YOLO-format txt annotations

import glob
import os
import cv2

txt_file = r".\domain1\txt"
name = glob.glob(os.path.join(txt_file, "*.txt"))
list_1 = []
for i in name:
    f = open(i, encoding="utf-8")
    byt = f.readlines()
    for line in byt:
        list_1.append(line.split(" ")[1])
# collect every class name that appears in the txt files and deduplicate;
# sort so the class-id mapping is reproducible across runs (set order is not)
list2 = sorted(set(list_1))
# print(list2)
l = {}  # EDS class name -> numeric id
j = 0
for i in list2:
    l[i] = j
    j += 1
print(l)  # the resulting mapping as a dict
# a YOLOv7 label line is: cls_id x y w h, where (x, y) is the box center and
# all values are ratios of the image width/height, not absolute coordinates
img_path = "./domain1/image"
out_path = "./out"
name = glob.glob(os.path.join(txt_file, "*.txt"))
for i in name:
    if not os.path.exists(out_path):
        os.mkdir(out_path)
    # index 3 of split("\\") is the bare filename for the path layout above
    with open(os.path.join(out_path, i.split("\\")[3].split(".")[0] + ".txt"), "w") as f_1:
        img_name = i.split("\\")[3].split(".")[0] + ".jpg"
        img = os.path.join(img_path, img_name)
        img_ = cv2.imread(img)
        img_height, img_width, _ = img_.shape
        f = open(i, encoding="utf-8")
        byt = f.readlines()
        for line in byt:
            class_num = l[line.split(" ")[1]]
            xmin = float(line.split(" ")[2])
            ymin = float(line.split(" ")[3])
            xmax = float(line.split(" ")[4])
            ymax = float(line.split(" ")[5])
            x_center = xmin + (xmax - xmin) / 2
            y_center = ymin + (ymax - ymin) / 2
            w = xmax - xmin
            h = ymax - ymin
            x_center = round(x_center / img_width, 6)
            y_center = round(y_center / img_height, 6)
            w = round(w / img_width, 6)
            h = round(h / img_height, 6)
            info = [str(i) for i in [class_num, x_center, y_center, w, h]]
            print(info)
            f_1.write(" ".join(info) + "\n")

1.4 How to split YOLO train, val, and test sets

The dataset prepared for this article (the EDS dataset in YOLO format) is available for free download. Thanks for your support!

# split the images and labels into train / val / test sets by ratio
import shutil
import random
import os

# original paths, change as needed
image_original_path = './domain1/image/'
label_original_path = './out/'
# training set paths, change as needed
train_image_path = 'E:/yolov7/data/images/train/'
train_label_path = 'E:/yolov7/data/labels/train/'
# validation set paths, change as needed
val_image_path = 'E:/yolov7/data/images/val/'
val_label_path = 'E:/yolov7/data/labels/val/'
# test set paths, change as needed
test_image_path = 'E:/yolov7/data/images/test/'
test_label_path = 'E:/yolov7/data/labels/test/'
# split ratios: 70% train, 15% val; the test set takes the remainder
# (test_percent is not actually used below), change as needed
train_percent = 0.7
val_percent = 0.15
test_percent = 0.15

# create the folders if they do not exist
def mkdir():
    if not os.path.exists(train_image_path):
        os.makedirs(train_image_path)
    if not os.path.exists(train_label_path):
        os.makedirs(train_label_path)
    if not os.path.exists(val_image_path):
        os.makedirs(val_image_path)
    if not os.path.exists(val_label_path):
        os.makedirs(val_label_path)
    if not os.path.exists(test_image_path):
        os.makedirs(test_image_path)
    if not os.path.exists(test_label_path):
        os.makedirs(test_label_path)

def main():
    mkdir()
    total_txt = os.listdir(label_original_path)
    num_txt = len(total_txt)
    list_all_txt = range(num_txt)  # range(0, num_txt)
    num_train = int(num_txt * train_percent)
    num_val = int(num_txt * val_percent)
    num_test = num_txt - num_train - num_val
    # draw num_train indices for the training set
    train = random.sample(list_all_txt, num_train)
    # the remaining indices form the val + test pool
    val_test = [i for i in list_all_txt if not i in train]
    # draw num_val indices from val_test; whatever is left over is the test set
    val = random.sample(val_test, num_val)
    print("train: {}, val: {}, test: {}".format(len(train), len(val), len(val_test) - len(val)))
    for i in list_all_txt:
        name = total_txt[i][:-4]
        srcImage = image_original_path + name + '.jpg'
        srcLabel = label_original_path + name + '.txt'
        if i in train:
            dst_train_Image = train_image_path + name + '.jpg'
            dst_train_Label = train_label_path + name + '.txt'
            shutil.copyfile(srcImage, dst_train_Image)
            shutil.copyfile(srcLabel, dst_train_Label)
        elif i in val:
            dst_val_Image = val_image_path + name + '.jpg'
            dst_val_Label = val_label_path + name + '.txt'
            shutil.copyfile(srcImage, dst_val_Image)
            shutil.copyfile(srcLabel, dst_val_Label)
        else:
            dst_test_Image = test_image_path + name + '.jpg'
            dst_test_Label = test_label_path + name + '.txt'
            shutil.copyfile(srcImage, dst_test_Image)
            shutil.copyfile(srcLabel, dst_test_Label)

if __name__ == '__main__':
    main()
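
As a quick sanity check after the split, you can count the copied files; a minimal sketch, assuming the destination paths used above:

import glob

for split in ('train', 'val', 'test'):
    imgs = glob.glob('E:/yolov7/data/images/' + split + '/*.jpg')
    lbls = glob.glob('E:/yolov7/data/labels/' + split + '/*.txt')
    # images and labels should come in matched pairs
    print(split, len(imgs), 'images,', len(lbls), 'labels')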

2. Training your own model with YOLOv7

Official repository: https://github.com/wongkinyiu/yolov7

Clone it with git:

git clone https://github.com/wongkinyiu/yolov7

2.1 Testing the pretrained yolov7.pt

The official repo provides download links; alternatively you can download from CSDN: YOLOv7 pretrained weights.

Once the pretrained weights are downloaded, open detect.py.

Run it directly, leaving all the other parameters at their defaults.

(1) Testing on images

Or change --source to your own image path; you can likewise set --weights=your_weight_path to test a model you trained yourself.
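
For example, a typical invocation looks like this (taken from the usage shown in the official README; inference/images is the sample folder that ships with the repo):

python detect.py --weights yolov7.pt --conf 0.25 --img-size 640 --source inference/images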

Detection results with yolov7.pt:

Let's look at the official image-loading code in utils.datasets:

class LoadImages:  # for inference
    def __init__(self, path, img_size=640, stride=32):
        """
        path: image path
        img_size: final test image size
        stride: used mainly to pad small images up to the actual test size
        yields:
            path: path of the image
            img: resized image
            img0: original image
            self.cap
        """
        # walk the input test path; files holds the addresses to test
        p = str(Path(path).absolute())  # os-agnostic absolute path
        if '*' in p:
            files = sorted(glob.glob(p, recursive=True))  # glob
        elif os.path.isdir(p):
            files = sorted(glob.glob(os.path.join(p, '*.*')))  # dir
        elif os.path.isfile(p):
            files = [p]  # files
        else:
            raise Exception(f'ERROR: {p} does not exist')
        # decide image vs. video by file extension, keeping each in its own list
        images = [x for x in files if x.split('.')[-1].lower() in img_formats]
        videos = [x for x in files if x.split('.')[-1].lower() in vid_formats]
        # how many images and how many videos were found
        ni, nv = len(images), len(videos)
        self.img_size = img_size
        self.stride = stride
        self.files = images + videos  # a single list
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv  # marks which entries are videos
        self.mode = 'image'
        if any(videos):  # are there any videos?
            self.new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. ' \
                            f'Supported formats are:\nimages: {img_formats}\nvideos: {vid_formats}'

    # __iter__: standard iterator protocol
    def __iter__(self):
        self.count = 0
        return self

    def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]
        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, img0 = self.cap.read()
            if not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                else:
                    path = self.files[self.count]
                    self.new_video(path)
                    ret_val, img0 = self.cap.read()
            self.frame += 1
            print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.nframes}) {path}: ', end='')
        else:
            # Read image
            self.count += 1
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, 'Image Not Found ' + path
            # print(f'image {self.count}/{self.nf} {path}: ', end='')
        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride)[0]
        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)
        return path, img, img0, self.cap

    def new_video(self, path):
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.nframes = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        return self.nf  # number of files
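
For a feel of how detect.py consumes this class, here is a minimal sketch of iterating it directly (the source folder is a placeholder; the letterbox resize happens inside __next__):

from utils.datasets import LoadImages

dataset = LoadImages('./inference/images', img_size=640, stride=32)
for path, img, img0, cap in dataset:
    # img is the letterboxed CHW array fed to the model, img0 the original HWC frame
    print(path, img.shape, img0.shape)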

(2) Testing a local webcam

A simple snippet for grabbing frames from a local camera:

import cv2

def access_camera(url, output_path):
    # cv2.VideoCapture accepts both an integer device index and a stream URL
    cap = cv2.VideoCapture(url)
    while cap.isOpened():
        # Capture frame-by-frame
        ret, frame = cap.read()
        if not ret:
            break
        # Display the resulting frame
        cv2.imshow('frame', frame)
        cv2.imwrite(output_path, frame)
        print("frame saved!")
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    # When everything is done, release the capture
    cap.release()
    cv2.destroyAllWindows()

if __name__ == "__main__":
    url = 'http://admin:admin@192.168.1.3:8081/video'  # IP camera
    output_path = "./runs/detect/img.png"
    # url = 0  # built-in laptop camera
    access_camera(url, output_path)
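
To run YOLOv7 itself against the local camera rather than this standalone snippet, detect.py accepts a numeric source (device index 0 assumes the default webcam):

python detect.py --weights yolov7.pt --conf 0.25 --source 0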

The code YOLOv7 ships follows the same idea:

class LoadWebcam:  # for inference
    def __init__(self, pipe='0', img_size=640, stride=32):
        """
        pipe: '0' means the local camera
        img_size: image size
        stride: model stride for letterbox padding
        """
        self.img_size = img_size
        self.stride = stride
        if pipe.isnumeric():
            pipe = eval(pipe)  # local camera
        # pipe = 'rtsp://192.168.1.64/1'  # IP camera
        # pipe = 'rtsp://username:password@192.168.1.64/1'  # IP camera with login
        # pipe = 'http://wmccpinetop.axiscam.net/mjpg/video.mjpg'  # IP golf camera
        self.pipe = pipe
        self.cap = cv2.VideoCapture(pipe)  # video capture object
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3)  # set buffer size

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if cv2.waitKey(1) == ord('q'):  # q to quit
            self.cap.release()
            cv2.destroyAllWindows()
            raise StopIteration
        # Read frame
        if self.pipe == 0:  # local camera
            ret_val, img0 = self.cap.read()
            img0 = cv2.flip(img0, 1)  # flip left-right
        else:  # IP camera
            n = 0
            while True:
                n += 1
                self.cap.grab()
                if n % 30 == 0:  # skip frames
                    ret_val, img0 = self.cap.retrieve()
                    if ret_val:
                        break
        # Print
        assert ret_val, f'Camera Error {self.pipe}'
        img_path = 'webcam.jpg'
        print(f'webcam {self.count}: ', end='')
        # Padded resize
        img = letterbox(img0, self.img_size, stride=self.stride)[0]
        # Convert
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416
        img = np.ascontiguousarray(img)
        return img_path, img, img0, None

    def __len__(self):
        return 0

(3) Testing video streams

class LoadStreams:  # multiple IP or RTSP cameras
    def __init__(self, sources='streams.txt', img_size=640, stride=32):
        self.mode = 'stream'
        self.img_size = img_size
        self.stride = stride
        if os.path.isfile(sources):
            with open(sources, 'r') as f:
                sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
        else:
            sources = [sources]
        n = len(sources)
        self.imgs = [None] * n
        self.sources = [clean_str(x) for x in sources]  # clean source names for later
        for i, s in enumerate(sources):
            # Start the thread to read frames from the video stream
            print(f'{i + 1}/{n}: {s}... ', end='')
            url = eval(s) if s.isnumeric() else s
            if 'youtube.com/' in str(url) or 'youtu.be/' in str(url):  # if source is YouTube video
                check_requirements(('pafy', 'youtube_dl'))
                import pafy
                url = pafy.new(url).getbest(preftype="mp4").url
            cap = cv2.VideoCapture(url)
            assert cap.isOpened(), f'Failed to open {s}'
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            self.fps = cap.get(cv2.CAP_PROP_FPS) % 100
            _, self.imgs[i] = cap.read()  # guarantee first frame
            thread = Thread(target=self.update, args=([i, cap]), daemon=True)
            print(f' success ({w}x{h} at {self.fps:.2f} FPS).')
            thread.start()
        print('')  # newline
        # check for common shapes
        s = np.stack([letterbox(x, self.img_size, stride=self.stride)[0].shape for x in self.imgs], 0)  # shapes
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        if not self.rect:
            print('WARNING: Different stream shapes detected. For optimal performance supply similarly-shaped streams.')

    def update(self, index, cap):
        # Read next stream frame in a daemon thread
        n = 0
        while cap.isOpened():
            n += 1
            # _, self.imgs[index] = cap.read()
            cap.grab()
            if n == 4:  # read every 4th frame
                success, im = cap.retrieve()
                self.imgs[index] = im if success else self.imgs[index] * 0
                n = 0
            time.sleep(1 / self.fps)  # wait time

    def __iter__(self):
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        img0 = self.imgs.copy()
        if cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration
        # Letterbox
        img = [letterbox(x, self.img_size, auto=self.rect, stride=self.stride)[0] for x in img0]
        # Stack
        img = np.stack(img, 0)
        # Convert
        img = img[:, :, :, ::-1].transpose(0, 3, 1, 2)  # BGR to RGB, to bsx3x416x416
        img = np.ascontiguousarray(img)
        return self.sources, img, img0, None

    def __len__(self):
        return 0  # 1E12 frames = 32 streams at 30 FPS for 30 years
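
To try this out, streams.txt simply lists one source per line; a hypothetical example (the addresses are placeholders):

0
rtsp://admin:admin@192.168.1.3:8554/stream1
http://admin:admin@192.168.1.3:8081/video

Running detect.py with --source streams.txt then constructs a LoadStreams dataset automatically, since any .txt source is treated as a stream list.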

Once all the images or video streams have been grabbed, the frames are fed into the model. Here is the official detect.py code:

def detect(save_img=False):
    source, weights, view_img, save_txt, imgsz, trace = opt.source, opt.weights, opt.view_img, opt.save_txt, opt.img_size, not opt.no_trace
    save_img = not opt.nosave and not source.endswith('.txt')  # save inference images
    webcam = source.isnumeric() or source.endswith('.txt') or source.lower().startswith(
        ('rtsp://', 'rtmp://', 'http://', 'https://'))
    # Directories
    save_dir = Path(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))  # increment run
    (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir
    # Initialize
    set_logging()
    device = select_device(opt.device)
    half = device.type != 'cpu'  # half precision only supported on CUDA
    # Load model
    model = attempt_load(weights, map_location=device)  # load FP32 model
    stride = int(model.stride.max())  # model stride
    imgsz = check_img_size(imgsz, s=stride)  # check img_size
    if trace:
        model = TracedModel(model, device, opt.img_size)
    if half:
        model.half()  # to FP16
    # Second-stage classifier
    classify = False
    if classify:
        modelc = load_classifier(name='resnet101', n=2)  # initialize
        modelc.load_state_dict(torch.load('weights/resnet101.pt', map_location=device)['model']).to(device).eval()
    # Set Dataloader
    vid_path, vid_writer = None, None
    if webcam:
        view_img = check_imshow()
        cudnn.benchmark = True  # set True to speed up constant image size inference
        dataset = LoadStreams(source, img_size=imgsz, stride=stride)
    else:
        dataset = LoadImages(source, img_size=imgsz, stride=stride)
    # Get names and colors
    names = model.module.names if hasattr(model, 'module') else model.names
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in names]
    # Run inference
    if device.type != 'cpu':
        model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
    t0 = time.time()
    for path, img, im0s, vid_cap in dataset:
        img = torch.from_numpy(img).to(device)
        img = img.half() if half else img.float()  # uint8 to fp16/32
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)
        # Inference
        t1 = time_synchronized()
        pred = model(img, augment=opt.augment)[0]
        # Apply NMS
        pred = non_max_suppression(pred, opt.conf_thres, opt.iou_thres, classes=opt.classes, agnostic=opt.agnostic_nms)
        t2 = time_synchronized()
        # Apply Classifier
        if classify:
            pred = apply_classifier(pred, modelc, img, im0s)
        # Process detections
        for i, det in enumerate(pred):  # detections per image
            if webcam:  # batch_size >= 1
                p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count
            else:
                p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0)
            p = Path(p)  # to Path
            save_path = str(save_dir / p.name)  # img.jpg
            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # img.txt
            s += '%gx%g ' % img.shape[2:]  # print string
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            if len(det):
                # Rescale boxes from img_size to im0 size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
                # Print results
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()  # detections per class
                    s += f"{n} {names[int(c)]}{'s' * (n > 1)}, "  # add to string
                # Write results
                for *xyxy, conf, cls in reversed(det):
                    if save_txt:  # Write to file
                        xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                        line = (cls, *xywh, conf) if opt.save_conf else (cls, *xywh)  # label format
                        with open(txt_path + '.txt', 'a') as f:
                            f.write(('%g ' * len(line)).rstrip() % line + '\n')
                    if save_img or view_img:  # Add bbox to image
                        label = f'{names[int(cls)]} {conf:.2f}'
                        plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
            # Print time (inference + NMS)
            # print(f'{s}Done. ({t2 - t1:.3f}s)')
            # Stream results
            if view_img:
                cv2.imshow(str(p), im0)
                cv2.waitKey(1)  # 1 millisecond
            # Save results (image with detections)
            if save_img:
                if dataset.mode == 'image':
                    cv2.imwrite(save_path, im0)
                    print(f" The image with the result is saved in: {save_path}")
                else:  # 'video' or 'stream'
                    if vid_path != save_path:  # new video
                        vid_path = save_path
                        if isinstance(vid_writer, cv2.VideoWriter):
                            vid_writer.release()  # release previous video writer
                        if vid_cap:  # video
                            fps = vid_cap.get(cv2.CAP_PROP_FPS)
                            w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                            h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        else:  # stream
                            fps, w, h = 30, im0.shape[1], im0.shape[0]
                        save_path += '.mp4'
                        vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
                    vid_writer.write(im0)
    if save_txt or save_img:
        s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
        # print(f"Results saved to {save_dir}{s}")
    print(f'Done. ({time.time() - t0:.3f}s)')

2.2 Training a YOLOv7 model on your own data

Build your dataset with the steps above and place it under the yolov7/data directory.

Create a dataset.yaml file in yolov7/data; the official coco.yaml is a good reference.

My own dataset.yaml configuration:

train: E:/yolov7/data/images/train  # train images
val: E:/yolov7/data/images/val  # val images
test: E:/yolov7/data/images/test  # test images (optional)

# Classes
nc: 10  # number of classes
names: ['laptop','pressure','device','plasticbottle','scissor','knife','lighter','powerbank','glassbottle','umbrella']  # class names

Start training (a reference launch command follows below)... then comes the long wait. All training artifacts end up under yolov7/runs/train/exp.
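
For reference, a single-GPU launch adapted from the p5 training command in the official README (the batch size, run name, and the choice of yolov7.pt as initial weights are assumptions for this dataset; tune them to your hardware):

python train.py --workers 8 --device 0 --batch-size 16 --data data/dataset.yaml --img 640 640 --cfg cfg/training/yolov7.yaml --weights yolov7.pt --name yolov7-eds --hyp data/hyp.scratch.p5.yaml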

2.3 Testing your trained model

Change the weights path in detect.py to your trained model (see the example invocation below). I trained this model for only 10 epochs, and the results are already decent.
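
A hedged example of such an invocation (best.pt under runs/train/exp/weights is where train.py saves the best checkpoint by default; the test-image path is a placeholder):

python detect.py --weights runs/train/exp/weights/best.pt --conf 0.25 --img-size 640 --source E:/yolov7/data/images/test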

2.4 Testing keypoint detection

First download the official pretrained pose model, yolov7-w6-pose.pt.

import matplotlib
"""
Without matplotlib.use('TkAgg') you get:
UserWarning: Matplotlib is currently using agg, which is a non-GUI backend
"""
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
print(matplotlib.get_backend())
import torch
import cv2
from torchvision import transforms
import numpy as np
from utils.datasets import letterbox
# check the GUI backend after each import, to make sure every imported module
# ends up using the same backend
print(matplotlib.get_backend())
from utils.general import non_max_suppression_kpt
print(matplotlib.get_backend())
from utils.plots import output_to_keypoint, plot_skeleton_kpts
# utils.plots may set a different matplotlib backend; change it there if needed
print(matplotlib.get_backend())

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
weigths = torch.load('../yolov7-w6-pose.pt')
model = weigths['model']
model = model.half().to(device)
_ = model.eval()

image = cv2.imread('../person.jpeg')
image = letterbox(image, 960, stride=64, auto=True)[0]
image_ = image.copy()
image = transforms.ToTensor()(image)
image = torch.tensor(np.array([image.numpy()]))
image = image.to(device)
image = image.half()
output, _ = model(image)
output = non_max_suppression_kpt(output, 0.25, 0.65, nc=model.yaml['nc'], nkpt=model.yaml['nkpt'], kpt_label=True)
output = output_to_keypoint(output)
nimg = image[0].permute(1, 2, 0) * 255
nimg = nimg.cpu().numpy().astype(np.uint8)
nimg = cv2.cvtColor(nimg, cv2.COLOR_RGB2BGR)
for idx in range(output.shape[0]):
    plot_skeleton_kpts(nimg, output[idx, 7:].T, 3)
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.imshow(nimg)
plt.savefig("person_detection.png")
plt.show()

The error reported here is PyTorch complaining about calling numpy() on a tensor that still requires grad.

Adding detach() at lines 442-443 of utils/plots.py, so the tensors are no longer tracked for backpropagation, fixes it.
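
The fix follows the pattern sketched below (the variable name is illustrative; match it to whatever tensor utils/plots.py calls .cpu().numpy() on at those lines):

# before: numpy() on a tensor that still requires grad raises a RuntimeError
# kpts = kpts.cpu().numpy()
# after: detach() cuts the tensor off from the autograd graph first
kpts = kpts.detach().cpu().numpy()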

To be continued...
