当前位置:   article > 正文

(Pytorch) YOLOV4 : 训练自己的数据集【左侧有码】_pytorch版yolov4训练自己的数据集

pytorch版yolov4训练自己的数据集

 

项目地址:https://github.com/argusswift/YOLOv4-pytorch

这份代码实现的逻辑非常清楚,主要一些数据集处理的代码需要相应的改动:

这里的数据集label格式:

train_annotation:

  1. image_name1 x1,y1,x2,y2,class_id x1,y1,x2,y2,class_id x1,y1,x2,y2,class_id
  2. image_name2 x1,y1,x2,y2,class_id x1,y1,x2,y2,class_id x1,y1,x2,y2,class_id x1,y1,x2,y2,class_id

1,yolov4_config.py

改动内容:

(1)路径

  1. DATA_PATH = ""
  2. PROJECT_PATH = ""
  3. DETECTION_PATH=""
  4. MODEL_TYPE = {
  5. "TYPE": "YOLOv4"
  6. }

(2)训练参数

  1. TRAIN = {
  2. "DATA_TYPE": "Customer", # DATA_TYPE: VOC ,COCO or Customer
  3. "TRAIN_IMG_SIZE": 608,
  4. "AUGMENT": True,
  5. "BATCH_SIZE": 16,
  6. "MULTI_SCALE_TRAIN": False,
  7. "IOU_THRESHOLD_LOSS": 0.5,
  8. "YOLO_EPOCHS": 50,
  9. "Mobilenet_YOLO_EPOCHS": 120,
  10. "NUMBER_WORKERS": 0,
  11. "MOMENTUM": 0.9,
  12. "WEIGHT_DECAY": 0.0005,
  13. "LR_INIT": 1e-4,
  14. "LR_END": 1e-6,
  15. "WARMUP_EPOCHS": 2, # or None
  16. }

(3)VAL 参数

  1. VAL = {
  2. "TEST_IMG_SIZE": 608, #同train
  3. "BATCH_SIZE": 1,
  4. "NUMBER_WORKERS": 0,
  5. "CONF_THRESH": 0.005,
  6. "NMS_THRESH": 0.45,
  7. "MULTI_SCALE_VAL": False, #-----
  8. "FLIP_VAL": False, #------ 因为数据集里有信号灯,所以关闭翻转
  9. "Visual": True,
  10. }

(4)Customer_DATA 目标列表

  1. Customer_DATA = {
  2. "NUM": **, # your dataset number
  3. "CLASSES": [** ], # your dataset class
  4. }

2,/utils/datasets.py

(1)设 img_path

(2)如果训练中报resize的错,注意检查训练数据集,可能是由于:

  • 空行
  • 只有图片名字没有box
  • x2-x1 or y2-y1 的box宽/高为负值。

如果是公开数据集可能不会出这种错,如果是自己做的数据集,标注过程可能会出现这种。

(3)index 越界

这个问题是因为index越界,是datasets.py中

这个位置出现xind,或者yind的越界,比如网络输入是608大小,到第一级anchor层的stride是8,这一层的特征图大小就是608/8=76。

所以xind或者yind的取值范围应该在[0-75],报错是因为这里xind或yind取到了76,做一些越界判断处理即可,例如设定xind/yind上限为75。

 

3, /eval/evaluator.py + voc_eval.py

由于自己的label格式和文件夹并不是按照voc的格式,所以这里的evaluator.py和voc_eval.py都需要进行相应的修改。

evaluator.py: 都按照自己实际的修改

(1)self.val_data_path

(2)img_inds_file

(3)img_path

(4)annopath 和 imagesetfile (存放的是图像名字列表)配置val数据集的label路径和图像路径

voc_eval.py

(1)parse_gt 函数,因为我没有用voc的xml格式,所以在解析label的时候自己重写了这个函数

相应的,下面调用的时候:

 

(2)下面都是针对数据集label不是voc的xml格式带来的不同需要修改的地方

 

这样修改完,基本就可以训练跑起来了:

训练命令:

CUDA_VISIBLE_DEVICES=0 nohup python -u train.py  --weight_path weight/yolov4.weights --gpu_id 0

 

 

CUDA_VISIBLE_DEVICES=0 python3 video_test.py --weight_path ./weight/best.pt --gpu_id 0 --video_path video.mp4 --output_dir .

 

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/木道寻08/article/detail/874987
推荐阅读
相关标签
  

闽ICP备14008679号