赞
踩
项目地址:https://github.com/argusswift/YOLOv4-pytorch
这份代码实现的逻辑非常清楚,主要一些数据集处理的代码需要相应的改动:
这里的数据集label格式:
train_annotation:
- image_name1 x1,y1,x2,y2,class_id x1,y1,x2,y2,class_id x1,y1,x2,y2,class_id
- image_name2 x1,y1,x2,y2,class_id x1,y1,x2,y2,class_id x1,y1,x2,y2,class_id x1,y1,x2,y2,class_id
改动内容:
(1)路径
- DATA_PATH = ""
- PROJECT_PATH = ""
- DETECTION_PATH=""
-
- MODEL_TYPE = {
- "TYPE": "YOLOv4"
- }
(2)训练参数
- TRAIN = {
- "DATA_TYPE": "Customer", # DATA_TYPE: VOC ,COCO or Customer
- "TRAIN_IMG_SIZE": 608,
- "AUGMENT": True,
- "BATCH_SIZE": 16,
- "MULTI_SCALE_TRAIN": False,
- "IOU_THRESHOLD_LOSS": 0.5,
- "YOLO_EPOCHS": 50,
- "Mobilenet_YOLO_EPOCHS": 120,
- "NUMBER_WORKERS": 0,
- "MOMENTUM": 0.9,
- "WEIGHT_DECAY": 0.0005,
- "LR_INIT": 1e-4,
- "LR_END": 1e-6,
- "WARMUP_EPOCHS": 2, # or None
- }
(3)VAL 参数
- VAL = {
- "TEST_IMG_SIZE": 608, #同train
- "BATCH_SIZE": 1,
- "NUMBER_WORKERS": 0,
- "CONF_THRESH": 0.005,
- "NMS_THRESH": 0.45,
- "MULTI_SCALE_VAL": False, #-----
- "FLIP_VAL": False, #------ 因为数据集里有信号灯,所以关闭翻转
- "Visual": True,
- }
(4)Customer_DATA 目标列表
- Customer_DATA = {
- "NUM": **, # your dataset number
- "CLASSES": [** ], # your dataset class
- }
(1)设 img_path
(2)如果训练中报resize的错,注意检查训练数据集,可能是由于:
如果是公开数据集可能不会出这种错,如果是自己做的数据集,标注过程可能会出现这种。
(3)index 越界
这个问题是因为index越界,是datasets.py中
这个位置出现xind,或者yind的越界,比如网络输入是608大小,到第一级anchor层的stride是8,这一层的特征图大小就是608/8=76。
所以xind或者yind的取值范围应该在[0-75],报错是因为这里xind或yind取到了76,做一些越界判断处理即可,例如设定xind/yind上限为75。
由于自己的label格式和文件夹并不是按照voc的格式,所以这里的evaluator.py和voc_eval.py都需要进行相应的修改。
(1)self.val_data_path
(2)img_inds_file
(3)img_path
(4)annopath 和 imagesetfile (存放的是图像名字列表)配置val数据集的label路径和图像路径
(1)parse_gt 函数,因为我没有用voc的xml格式,所以在解析label的时候自己重写了这个函数
相应的,下面调用的时候:
(2)下面都是针对数据集label不是voc的xml格式带来的不同需要修改的地方
这样修改完,基本就可以训练跑起来了:
训练命令:
CUDA_VISIBLE_DEVICES=0 nohup python -u train.py --weight_path weight/yolov4.weights --gpu_id 0
CUDA_VISIBLE_DEVICES=0 python3 video_test.py --weight_path ./weight/best.pt --gpu_id 0 --video_path video.mp4 --output_dir .
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。