当前位置:   article > 正文

目标检测-YOLOv7代码及训练

yolov7代码

目录

论文及代码下载

代码结构

代码学习

1 train.py

1.1 参数

1.2 Resume 训练中断后继续训练

1.3 Train 训练

2 yolov7.yaml

2.1 parameters and anchors

2.2 网络结构

3 detect.py

3.1 参数

3.2 绘制检测框 plot_one_box()

训练流程

1 数据集准备

1.0 数据集处理前的数据集结构

1.1 VOC格式转YOLO格式

1.2 划分训练集、验证集、测试集

1.3 生成yolov7需要的文件夹格式

1.4 创建最终训练需要的train.txt和val.txt文件

1.5 最终数据集结构

2 需要修改的参数

2.1 --weights

2.2 --cfg

2.3 --data

2.4 --hyp

2.5 --epochs

2.6 --batch-size

2.7 --device

2.8 --project

2.9 --workers


论文及代码下载

论文:https://arxiv.org/abs/2207.02696

代码:https://github.com/WongKinYiu/yolov7

在实际使用过程中,我用的YOLOAir库,和YOLOv7源码使用方法是一样的

yoloair2:https://github.com/iscyy/yoloair2

代码结构

  1. .
  2. ├── cfg(存放yaml文件)
  3. │ ├── baseline
  4. │ ├── deploy
  5. │ ├── training
  6. │ └── improved
  7. ├── configs(也存放了yaml文件,但经过了改进)
  8. │ ├── attention
  9. │ ├── attention_v7
  10. │ ├── backbone
  11. │ ├── head-Improved
  12. │ └── ......
  13. ├── data(数据集信息和模型训练超参数)
  14. │ ├── coco.yaml
  15. │ ├── coco128.yaml
  16. │ ├── hyp.scratch.custom.yaml
  17. │ ├── hyp.scratch.p5.yaml
  18. │ ├── hyp.scratch.p6.yaml
  19. │ └── hyp.scratch.tiny.yaml
  20. ├── deploy
  21. │ └── triton-inference-server
  22. │ ├── data
  23. │ │ ├── dog.jpg
  24. │ │ └── dog_result.jpg
  25. │ ├── boundingbox.py
  26. │ ├── client.py
  27. │ ├── labels.py
  28. │ ├── processing.py
  29. │ ├── render.py
  30. │ └── README.md
  31. ├── figure(一些图片)
  32. ├── inference(存放detect用的图片)
  33. │ └── images
  34. ├── models(存放网络结构)
  35. │ ├── Models
  36. │ │ ├── Attention
  37. │ │ └── **.py
  38. │ ├── Detect
  39. │ ├── __init__.py
  40. │ ├── common.py
  41. │ ├── commonv5.py
  42. │ ├── experimental.py
  43. │ ├── module.py
  44. │ └── yolo.py
  45. ├── runs(模型训练时的输出)
  46. ├── scripts(获得COCO 数据集)
  47. │ └── get_coco.sh
  48. ├── utils
  49. │ ├── aws
  50. │ │ ├── __init__.py
  51. │ │ ├── mime.sh
  52. │ │ ├── resume.py
  53. │ │ └── userdata.sh
  54. │ ├── google_app_engine
  55. │ │ ├── additional_requirements.txt
  56. │ │ ├── app.yaml
  57. │ │ └── Dockerfile
  58. │ ├── wandb_logging
  59. │ │ ├── __init__.py
  60. │ │ ├── log_dataset.py
  61. │ │ └── wandb_utils.py
  62. │ ├── __init__.py
  63. │ ├── activations.py(激活函数)
  64. │ ├── add_nms.py
  65. │ ├── autoanchor.py
  66. │ ├── datasets.py(数据的读取和加载)
  67. │ ├── general.py
  68. │ ├── google_utils.py
  69. │ ├── loss.py(损失函数)
  70. │ ├── metrics.py(衡量指标)
  71. │ ├── plots.py(画图)
  72. │ └── torch_utils.py
  73. ├── train_aux.py(使用较大的预训练权重)
  74. ├── train.py
  75. ├── test.py(测试)
  76. ├── detect.py(检测)
  77. ├── export.py
  78. ├── hubconf.py
  79. ├── LICENSE.md
  80. ├── README.md
  81. ├── README_EN.md
  82. └── requirements.txt

代码学习

(仅自己训练过程中去了解过的内容)

1 train.py

1.1 参数

  1. if __name__ == '__main__':
  2. parser = argparse.ArgumentParser()
  3. parser.add_argument('--weights', type=str, default='weight\yolov7.pt', help='initial weights path') # 初始化权重文件,如果有预训练模型,可以直接在此加载
  4. parser.add_argument('--cfg', type=str, default='cfg/training\yolov7.yaml', help='model.yaml path') # 网络结构配置文件
  5. parser.add_argument('--data', type=str, default='data\WindTurbineBlades.yaml', help='data.yaml path') # 训练数据集配置文件
  6. parser.add_argument('--hyp', type=str, default='data/hyp.scratch.p5.yaml', help='hyperparameters path') # 超参数配置文件
  7. parser.add_argument('--epochs', type=int, default=100) # 训练迭代次数
  8. parser.add_argument('--batch-size', type=int, default=32, help='total batch size for all GPUs') # 训练批次大小
  9. parser.add_argument('--img-size', nargs='+', type=int, default=[640, 640], help='[train, test] image sizes') # 训练图片大小
  10. parser.add_argument('--rect', action='store_true', help='rectangular training') # 是否采用矩形训练,默认False
  11. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training') # 是否继续进行训练,如果设置成True,那么会自动寻找最近训练权重文件
  12. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint') # 不保存权重文件,默认False
  13. parser.add_argument('--notest', action='store_true', help='only test final epoch') # 不进行test,默认False
  14. parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check') # 不自动调整anchor,默认False
  15. parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters') # 是否进行超参数优化,默认是False,开启该选项,会加大训练时间,一般不需要
  16. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket') # 谷歌云盘bucket,一般不会用到
  17. parser.add_argument('--cache-images', action='store_true', help='cache images for faster training') # 是否提前将训练数据进行缓存,默认是False
  18. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training') # 训练的时候是否选择图片权重进行训练
  19. parser.add_argument('--device', default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') # 训练所使用的设备
  20. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%') # 是否进行多尺度训练,默认False
  21. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class') # 训练数据集是否只有一类
  22. parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer') # 是否使用adam优化器
  23. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode') # 是否使用跨卡同步BN,在DDP模式使用
  24. parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify') # DDP参数
  25. parser.add_argument('--workers', type=int, default=0, help='maximum number of dataloader workers') # dataloader的最大worker数量
  26. parser.add_argument('--project', default='runs/train', help='save to project/name') # 训练结果保存路径
  27. parser.add_argument('--entity', default=None, help='W&B entity') # wandb库对应的东西,一般不用管
  28. parser.add_argument('--name', default='exp', help='save to project/name') # 训练结果保存文件夹名称
  29. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') # 判断下训练结果保存路径是否存在,如果存在的话,就不用重新创建
  30. parser.add_argument('--quad', action='store_true', help='quad dataloader') # 作用是兼顾速度和精度,选择折中的方案
  31. parser.add_argument('--linear-lr', action='store_true', help='linear LR') # 用于对学习速率进行调整,默认为 false,含义是通过余弦函数来降低学习率。
  32. parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon') # 是否做标签平滑,防止出现过拟合
  33. parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table') # wandb库对应的东西
  34. parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B') # wandb库对应的东西
  35. parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch') # 用于记录训练日志信息
  36. parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used') # 这一行参数表达的是想实现但还未实现的一个内容,忽略即可
  37. parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone of yolov7=50, first3=0 1 2') # 冻结训练,默认不冻结
  38. opt = parser.parse_args()

1.2 Resume 训练中断后继续训练

  1. # Resume
  2. wandb_run = check_wandb_resume(opt)
  3. if opt.resume and not wandb_run: # resume an interrupted run
  4. ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
  5. assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
  6. apriori = opt.global_rank, opt.local_rank
  7. with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
  8. opt = argparse.Namespace(**yaml.load(f, Loader=yaml.SafeLoader)) # replace
  9. opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = '', ckpt, True, opt.total_batch_size, *apriori # reinstate
  10. logger.info('Resuming training from %s' % ckpt)
  11. else:
  12. # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
  13. opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp) # check files
  14. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  15. opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size))) # extend to 2 sizes (train, test)
  16. opt.name = 'evolve' if opt.evolve else opt.name
  17. opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve) # increment run

若训练过程中训练中断,将Resume改为True,程序会自动寻找最后一轮的权重,然后继续训练。

check_wandb_resume()检查训练是否恢复,wandb是一个用于深度学习实验跟踪、可视化和协作的工具和平台。

1.3 Train 训练

  1. # Train
  2. logger.info(opt)
  3. if not opt.evolve:
  4. tb_writer = None # init loggers
  5. if opt.global_rank in [-1, 0]:
  6. prefix = colorstr('tensorboard: ')
  7. logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
  8. tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
  9. train(hyp, opt, device, tb_writer)
  10. # Evolve hyperparameters (optional) # tc
  11. else:
  12. ......

判断是否进行超参数优化,一般都是False;else后面的是超参数优化的参数,这之后的代码全是关于超参数优化的

if opt.global_rank in [-1, 0]检查是否使用分布式训练,-1表示没有,global_rank默认为-1。该条件语句内容是对Tensorboard进行设置。

train()函数,开始深度学习模型的训练

2 yolov7.yaml

2.1 parameters and anchors

  1. # parameters
  2. nc: 80 # number of classes
  3. depth_multiple: 1.0 # model depth multiple
  4. width_multiple: 1.0 # layer channel multiple
  5. # anchors
  6. anchors:
  7. - [12,16, 19,36, 40,28] # P3/8
  8. - [36,75, 76,55, 72,146] # P4/16
  9. - [142,110, 192,243, 459,401] # P5/32

depth_multiple:控制模型深度,与number相乘(红色框位置),表示该层重复的次数

width_multiple:控制输出通道数,与args第一个参数相乘(黄色框位置)

2.2 网络结构

from:表示其输入来自哪一层。

-1表示该层输入是上一层的输出,-2表示该层输入是上上层的输出,以此类推。若为正数则为正数第几层。[-1, -3, -5, -6]为倒数第一层、倒数第三层、倒数第五层、倒数第六层的输出。

number:表示该层的层数

module:表示使用的模块

args:表示模块的参数

从左到右为:输出通道数,卷积核大小,卷积核步长

  1. # yolov7 backbone
  2. backbone:
  3. # [from, number, module, args]
  4. [[-1, 1, Conv, [32, 3, 1]], # 0
  5. [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
  6. [-1, 1, Conv, [64, 3, 1]],
  7. [-1, 1, Conv, [128, 3, 2]], # 3-P2/4
  8. [-1, 1, Conv, [64, 1, 1]], # 下面表示的就是ELAN模块
  9. [-2, 1, Conv, [64, 1, 1]],
  10. [-1, 1, Conv, [64, 3, 1]],
  11. [-1, 1, Conv, [64, 3, 1]],
  12. [-1, 1, Conv, [64, 3, 1]],
  13. [-1, 1, Conv, [64, 3, 1]],
  14. [[-1, -3, -5, -6], 1, Concat, [1]],
  15. [-1, 1, Conv, [256, 1, 1]], # 11

3 detect.py

3.1 参数

  1. if __name__ == '__main__':
  2. parser = argparse.ArgumentParser()
  3. parser.add_argument('--weights', nargs='+', type=str, default='runs\train\exp\weights\best.pt', help='model.pt path(s)') # 训练好的权重
  4. parser.add_argument('--source', type=str, default='inference', help='source') # file/folder, 0 for webcam 测试的图片/图片文件夹/摄像头接口
  5. parser.add_argument('--img-size', type=int, default=640, help='inference size (pixels)') # 图片大小
  6. parser.add_argument('--conf-thres', type=float, default=0.25, help='object confidence threshold') # 置信度的阈值
  7. parser.add_argument('--iou-thres', type=float, default=0.45, help='IOU threshold for NMS') # iou阈值
  8. parser.add_argument('--device', default='0', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  9. parser.add_argument('--view-img', action='store_true', help='display results') # 是否展示测试结果
  10. parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') # 是否保存测试的标签文件
  11. parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') # 是否保存测试时目标的置信度,要和save-txt一起使用
  12. parser.add_argument('--nosave', action='store_true', help='do not save images/videos') # 是否保存测试图像
  13. parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3') # 是否只检测特定的某一类或几类,如classes 0就只检测数据集中yolo标签为0的
  14. parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS') # 增强版的nms
  15. parser.add_argument('--augment', action='store_true', help='augmented inference')
  16. parser.add_argument('--update', action='store_true', help='update all models')
  17. parser.add_argument('--project', default='runs\detect', help='save results to project/name') # 结果保存路径
  18. parser.add_argument('--name', default='exp', help='save results to project/name') # 保存的文件夹名字
  19. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  20. parser.add_argument('--no-trace', action='store_true', help='don`t trace model')

3.2 绘制检测框 plot_one_box()

该函数在utils/plots.py中

  1. def plot_one_box(x, img, color=None, label=None, line_thickness=3):
  2. # Plots one bounding box on image img
  3. # 计算了绘制线条的粗细
  4. tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness
  5. color = color or [random.randint(0, 255) for _ in range(3)]
  6. # 计算了边界框的两个对角点的坐标c1和c2,将x中的坐标信息四舍五入为整数,用于在图像上绘制边界框
  7. c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
  8. # 使用OpenCV的cv2.rectangle函数在图像上绘制一个矩形边界框
  9. cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
  10. if label: # 是否提供了标签文本
  11. # 计算了标签文本的字体线条粗细
  12. tf = max(tl - 1, 1) # font thickness
  13. # cv2.getTextSize()函数能得到文字绘制出来将有多大
  14. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  15. c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
  16. # 绘制一个填充矩形,用于容纳标签文本。这样可以在标签文本的位置上创建一个背景,使文本更容易阅读
  17. cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA) # filled
  18. # 将标签文本添加到图像上,参数为:图像,标签文本、文本框的左下角、字体、字体大小、颜色、字体线条粗细、线型
  19. cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)

x:边界框坐标的列表,(x_min, y_min, x_max, y_max),表示左上角和右下角的坐标

img:要绘制边界框的图像

color:边界框和标签的颜色,默认是随机的

label:表示要添加到边界框上的标签文本

line_thickness:线条粗细,默认为3。设为None线条粗细就跟随图像大小变化

训练流程

1 数据集准备

本人标注得到的数据集是VOC格式的

1.0 数据集处理前的数据集结构

  1. VOC
  2. ├── Annotations 标注(XML文件)
  3. └── JPEGImages 图像

1.1 VOC格式转YOLO格式

(!!!要将‘.../VOC’改为自己的路径)

  1. '''第一步:将xml转化成txt'''
  2. import shutil
  3. import os.path
  4. import xml.etree.ElementTree as ET
  5. # 1. 将这个地方改成自己类别的列表
  6. class_names = ['**', '**', '**']
  7. # 2. 将路径修改
  8. xmldir='.../VOC/Annotations/'
  9. imagedir=".../VOC/JPEGImages/"
  10. txtdir='.../VOC/xml2txt/'
  11. xmlpath=os.path.join(xmldir)# 原xml路径
  12. imagepath=os.path.join(imagedir)# 原img路径
  13. txtpath=os.path.join(txtdir)# 转换后txt文件存放路径
  14. files = []
  15. if not os.path.exists(txtpath):
  16. os.makedirs(txtpath)
  17. else:
  18. shutil.rmtree(txtpath)
  19. os.makedirs(txtpath)
  20. image_list=os.listdir(imagepath)
  21. postfixes = set(['.' + i.split('.')[1] for i in image_list]) #考虑到可能有不同后缀的图片,提取所有后缀
  22. for root, dirs, files in os.walk(xmlpath): #用来
  23. None
  24. number = len(files)
  25. i = 0
  26. delete=[]
  27. while i < number:
  28. filename = files[i][0:-4]
  29. xml_name = filename + ".xml"
  30. txt_name = filename + ".txt"
  31. xml_file_name = os.path.join(xmlpath,xml_name)
  32. txt_file_name = os.path.join(txtpath,txt_name)
  33. try:
  34. xml_file = open(xml_file_name, 'r', encoding='utf-8') #注意加个utf8编码就是,不然会报gbk的错
  35. tree = ET.parse(xml_file)
  36. root = tree.getroot()
  37. # filename = root.find('name').text
  38. # image_name = root.find('filename').text
  39. w = int(root.find('size').find('width').text)
  40. h = int(root.find('size').find('height').text)
  41. f_txt = open(txt_file_name, 'w+')
  42. content = ""
  43. first = True
  44. for obj in root.iter('object'):
  45. name = obj.find('name').text
  46. class_num = class_names.index(name)
  47. # class_num = 0
  48. xmlbox = obj.find('bndbox')
  49. x1 = int(xmlbox.find('xmin').text)
  50. x2 = int(xmlbox.find('xmax').text)
  51. y1 = int(xmlbox.find('ymin').text)
  52. y2 = int(xmlbox.find('ymax').text)
  53. if first:
  54. content += str(class_num) + " " + \
  55. str((x1 + x2) / 2 / w) + " " + str((y1 + y2) / 2 / h) + " " + \
  56. str((x2 - x1) / w) + " " + str((y2 - y1) / h)
  57. first = False
  58. else:
  59. content += "\n" + \
  60. str(class_num) + " " + \
  61. str((x1 + x2) / 2 / w) + " " + str((y1 + y2) / 2 / h) + " " + \
  62. str((x2 - x1) / w) + " " + str((y2 - y1) / h)
  63. # print(str(i / (number - 1) * 100) + "%\n")
  64. print(content)
  65. f_txt.write(content)
  66. f_txt.close()
  67. xml_file.close()
  68. i += 1
  69. except ZeroDivisionError as zeroE:
  70. print(xml_name+'转化失败!')
  71. i += 1
  72. # #利用os库把xml_name文件复制到test文件夹下
  73. # shutil.move(xmldir + xml_name, 'test/' + xml_name)
  74. # shutil.move(imagedir + filename+'.jpg', 'test/' + filename+'.jpg')
  75. # 删除对应的label和images
  76. os.remove(os.path.join(xmlpath, xml_name))
  77. print(xml_name + '删除成功!')
  78. for postfix in postfixes:
  79. if (filename + postfix) in image_list:
  80. img_name = filename + postfix
  81. os.remove(os.path.join(imagepath, img_name))
  82. print(img_name + '删除成功!')
  83. delete.append((img_name, xml_name))
  84. i += 1
  85. print('总图片数量:',number)
  86. print(f'成功转化{number-len(delete)}张图片')
  87. print(f'共删除{len(delete)}张图片和标签:')
  88. for each in delete:
  89. print(each)

1.2 划分训练集、验证集、测试集

  1. '''第二步:划分train.txt,val.txt,test.txt
  2. 注意trainval.txt是作为辅助'''
  3. import os
  4. import random
  5. import shutil
  6. random.seed(0)
  7. # 1. 将路径修改为自己的
  8. xmlfilepath = '.../VOC/Annotations/'
  9. saveBasePath = '.../VOC/ImageSets/Main/'
  10. if not os.path.exists(saveBasePath):
  11. os.makedirs(saveBasePath)
  12. else:
  13. shutil.rmtree(saveBasePath)
  14. os.makedirs(saveBasePath)
  15. # ----------------------------------------------------------------------#
  16. # 想要增加测试集修改trainval_percent
  17. # train_percent不需要修改
  18. # ----------------------------------------------------------------------#
  19. trainval_percent = 0.9
  20. train_percent = 0.8
  21. temp_xml = os.listdir(xmlfilepath)
  22. total_xml = []
  23. for xml in temp_xml:
  24. if xml.endswith(".xml"):
  25. total_xml.append(xml)
  26. num = len(total_xml)
  27. list = range(num)
  28. tv = int(num * trainval_percent)
  29. tr = int(tv * train_percent)
  30. trainval = random.sample(list, tv)
  31. train = random.sample(trainval, tr)
  32. print("train and val size", tv)
  33. print("traub suze", tr)
  34. if not os.path.exists(saveBasePath):
  35. os.mkdir(saveBasePath)
  36. ftrainval = open(os.path.join(saveBasePath, 'trainval.txt'), 'w')
  37. ftest = open(os.path.join(saveBasePath, 'test.txt'), 'w')
  38. ftrain = open(os.path.join(saveBasePath, 'train.txt'), 'w')
  39. fval = open(os.path.join(saveBasePath, 'val.txt'), 'w')
  40. for i in list:
  41. name = total_xml[i][:-4] + '\n'
  42. if i in trainval:
  43. ftrainval.write(name)
  44. if i in train:
  45. ftrain.write(name)
  46. else:
  47. fval.write(name)
  48. else:
  49. ftest.write(name)
  50. ftrainval.close()
  51. ftrain.close()
  52. fval.close()
  53. ftest.close()

1.3 生成yolov7需要的文件夹格式

  1. '''第三步:复制出图片和标签,生成yolov7需要的文件夹格式
  2. ----images
  3. --train
  4. --val
  5. --test(可选)
  6. ----labels
  7. --train
  8. --val
  9. --test(可选)'''
  10. import os
  11. import shutil
  12. from tqdm import tqdm
  13. SPLIT_PATH = '.../VOC/ImageSets/Main/'
  14. IMGS_PATH = ".../VOC/JPEGImages/"
  15. TXTS_PATH = ".../VOC/xml2txt/"
  16. TO_IMGS_PATH = '.../VOC/images'
  17. TO_TXTS_PATH = '.../VOC/labels'
  18. data_split = ['train.txt', 'val.txt', 'test.txt']
  19. to_split = ['train', 'val', 'test']
  20. image_list=os.listdir(IMGS_PATH)
  21. postfixes = set(['.' + i.split('.')[1] for i in image_list]) #考虑到可能有不同后缀的图片,提取所有后缀
  22. for index, split in enumerate(data_split):
  23. split_path = os.path.join(SPLIT_PATH, split)
  24. to_imgs_path = os.path.join(TO_IMGS_PATH, to_split[index])
  25. if not os.path.exists(to_imgs_path):
  26. os.makedirs(to_imgs_path)
  27. else:
  28. shutil.rmtree(to_imgs_path)
  29. os.makedirs(to_imgs_path)
  30. to_txts_path = os.path.join(TO_TXTS_PATH, to_split[index])
  31. if not os.path.exists(to_txts_path):
  32. os.makedirs(to_txts_path)
  33. else:
  34. shutil.rmtree(to_txts_path)
  35. os.makedirs(to_txts_path)
  36. f = open(split_path, 'r')
  37. count = 1
  38. for line in tqdm(f.readlines(), desc="{} is copying".format(to_split[index])):
  39. # 复制图片
  40. for postfix in postfixes:
  41. if (line.strip()+ postfix) in image_list:
  42. img_name=line.strip()+ postfix
  43. src_img_path = os.path.join(IMGS_PATH, img_name)
  44. dst_img_path = os.path.join(to_imgs_path,img_name)
  45. if os.path.exists(src_img_path):
  46. shutil.copyfile(src_img_path, dst_img_path)
  47. else:
  48. print("error file: {}".format(src_img_path))
  49. # 复制txt标注文件
  50. src_txt_path = os.path.join(TXTS_PATH, line.strip() + '.txt')
  51. dst_txt_path = os.path.join(to_txts_path, line.strip() + '.txt')
  52. if os.path.exists(src_txt_path):
  53. shutil.copyfile(src_txt_path, dst_txt_path)
  54. else:
  55. print("error file: {}".format(src_txt_path))

1.4 创建最终训练需要的train.txt和val.txt文件

  1. '''第四步:创建出最终训练需要的train.txt和val.txt文件 '''
  2. import os
  3. def listdir(path, list_name): # 传入存储的list
  4. for file in os.listdir(path):
  5. file_path = os.path.join(path, file)
  6. if os.path.isdir(file_path):
  7. listdir(file_path, list_name)
  8. else:
  9. list_name.append(file_path)
  10. list_name = []
  11. train_path = '.../VOC/images/train' # 文件夹路径,把images下的train/val/test(可选)文件夹处理就行 。注意这里开头的/不能少!!!
  12. listdir(train_path, list_name)
  13. print(list_name)
  14. with open('.../VOC/train.txt', 'w') as f: # 要存入的txt
  15. write = ''
  16. for i in list_name:
  17. write = write + str(i) + '\n'
  18. f.write(write)
  19. list_name = []
  20. val_path = '.../VOC/images/val' # 文件夹路径,把images下的train/val/test(可选)文件夹处理就行 。注意这里开头的/不能少!!!
  21. listdir(val_path, list_name)
  22. print(list_name)
  23. with open('.../VOC/val.txt', 'w') as f2: # 要存入的txt
  24. write = ''
  25. for i in list_name:
  26. write = write + str(i) + '\n'
  27. f2.write(write)

1.5 最终数据集结构

  1. VOC
  2. ├── Annotations 所有的图像标注信息(XML文件)
  3. ├── JPEGImages 所有图像文件
  4. ├── ImageSets
  5. │ └── Main
  6. │ ├── train.txt 训练集
  7. │ ├── val.txt 验证集
  8. │ ├── trainval.txt 训练集+验证集
  9. │ └── test.txt 测试集
  10. ├── images
  11. │ ├── train
  12. │ ├── val
  13. │ └── test
  14. ├── labels
  15. │ ├── train
  16. │ ├── val
  17. │ └── test
  18. ├── xml2txt YOLO格式标注
  19. ├── train.txt 用于训练的train.txt
  20. └── val.txt 用于训练的val.txt

2 需要修改的参数

2.1 --weights

根据自己所选的网络选择预训练权重,也可以为空(即不使用预训练模型)

2.2 --cfg

选择网络结构,同时要修改对应yaml文件中的内容

  1. # parameters
  2. nc: 80 # number of classes 这里要修改为自己的类别数
  3. depth_multiple: 1.0 # model depth multiple
  4. width_multiple: 1.0 # layer channel multiple
  5. # anchors 根据自己需要看要不要修改
  6. anchors:
  7. - [12,16, 19,36, 40,28] # P3/8
  8. - [36,75, 76,55, 72,146] # P4/16
  9. - [142,110, 192,243, 459,401] # P5/32

2.3 --data

数据集配置文件,将路径改为自己需要的,同时修改yaml文件中的内容

  1. # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
  2. # path: coco128 # dataset root dir
  3. train: .../VOC/train.txt # 数据集准备1.4中生成的
  4. val: .../VOC/val.txt # 数据集准备1.4中生成的
  5. # test: # test images (optional)
  6. # Classes
  7. nc: 80 # number of classes 修改为自己的类别数
  8. names: ['**', '**', '**'] # class names 修改为自己的类别,要和数据集准备1.1部分类别顺序一致

2.4 --hyp

网络超参数,路径改为自己需要的对应yaml文件

2.5 --epochs

2.6 --batch-size

2.7 --device

训练时用的设备,没有gpu就写cpu,有就写编号

2.8 --project

训练结果保存路径,一般不用改,要想放到别的文件夹下就改成自己的路径

2.9 --workers

是否多线程读取数据,越大cpu读取速度越快。

单张图像太大的时候可以填大一点,一般为0、2、4、8、16

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/295669
推荐阅读
相关标签
  

闽ICP备14008679号