Preface: this article documents my move from yolov5 to yolox training, under the assumption that the images and annotation files already exist in yolov5's txt format. Everything I found online about training yolox on your own data required reorganizing the annotations into voc or coco dataset format, down to the exact folder layout, which is a real hassle. My previous tasks were all trained with yolov5, so the images and labels were already in place and I had no desire to reshuffle them into voc or coco form; hence this article.
The yolov5 dataset directory is laid out as follows:
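Taking my road dataset as the example (this sketch is reconstructed from the paths used in the scripts below):

F:/Dataset/road
├── images
│   ├── train        # *.jpg
│   └── val
└── labels
    ├── train        # yolov5 *.txt, one per annotated (non-background) image
    └── val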
Generate xml annotation files from yolov5's txt annotations. A few things to watch out for during the conversion:
1. yolov5 annotations are normalized (c_x, c_y, w, h), so they have to be de-normalized against the image size (see the small example after this list).
2. yolov5 happily trains with unannotated background images (ones with no corresponding txt file), but yolox training does not, so the converter below still writes an xml (with no object entries) for them.
3. Image names must not contain spaces: yolov5 trains and validates fine with them, but yolox throws an error during validation.
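To make note 1 concrete, this is the de-normalization the script below applies to a single label line (a minimal sketch; the image size here is made up):

# One yolov5 label line: class c_x c_y w h, all normalized to [0, 1]
line = '2 0.506667 0.553333 0.490667 0.658667'
cls_id, cx, cy, w, h = line.split()
img_w, img_h = 1920, 1080            # example image size

x_center = float(cx) * img_w + 1     # the +1 matches VOC's 1-based pixel convention
y_center = float(cy) * img_h + 1
xmin = int(x_center - 0.5 * float(w) * img_w)
ymin = int(y_center - 0.5 * float(h) * img_h)
xmax = int(x_center + 0.5 * float(w) * img_w)
ymax = int(y_center + 0.5 * float(h) * img_h)
print(cls_id, xmin, ymin, xmax, ymax)  # 2 502 242 1444 954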
Straight to the code that generates the xml files, yolotxt2xml.py:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2021/09/14 11:14
# @Author  : lishanlu
# @File    : yolotxt2xml.py
# @Software: PyCharm
# @Discription:

from __future__ import absolute_import, print_function, division
import os
from xml.dom.minidom import Document
import cv2


class YOLO2VOCConvert:
    def __init__(self, txts_path, xmls_path, imgs_path, classes_str_list):
        self.txts_path = txts_path   # directory of the yolo-format txt label files
        self.xmls_path = xmls_path   # output directory for the converted voc-format xml labels
        self.imgs_path = imgs_path   # image directory; image size and name get written into each xml
        self.classes = classes_str_list  # list of class names

    # Collect every class index used across the txt files
    # (yolo-format labels store the class as a number: 0, 1, ...)
    def search_all_classes(self, writer=False):
        all_names = set()
        txts = os.listdir(self.txts_path)
        txts = [txt for txt in txts if txt.split('.')[-1] == 'txt']
        txts = [txt for txt in txts if not txt.split('.')[0] == "classes"]  # skip classes.txt
        print(len(txts), txts)  # e.g. 11 ['0002030.txt', '0002031.txt', ... '0002040.txt']
        for txt in txts:
            txt_file = os.path.join(self.txts_path, txt)
            with open(txt_file, 'r') as f:
                objects = f.readlines()
                for object in objects:
                    object = object.strip().split(' ')
                    print(object)  # e.g. ['2', '0.506667', '0.553333', '0.490667', '0.658667']
                    all_names.add(int(object[0]))
        print("All class labels:", all_names, "annotated txt files: %d" % len(txts))
        # Optionally dump the collected labels, e.g. to './Annotations/classes.txt':
        # if writer:
        #     with open('./Annotations/classes.txt', 'w') as f:
        #         for label in all_names:
        #             f.write(str(label) + '\n')
        return list(all_names)

    def yolo2voc(self):
        """
        Handles image and txt label counts that do not match,
        i.e. some images are pure background.
        """
        # Create the output folder for the xml label files
        if not os.path.exists(self.xmls_path):
            os.makedirs(self.xmls_path)
        for img_name in os.listdir(self.imgs_path):
            # Read the image to get its size
            print("Reading image:", img_name)
            try:
                img = cv2.imread(os.path.join(self.imgs_path, img_name))
                height_img, width_img, depth_img = img.shape  # h = rows (image height), w = columns (image width)
                print(height_img, width_img, depth_img)
            except Exception as e:
                print("%s read fail, %s" % (img_name, e))
                continue
            base_name = os.path.splitext(img_name)[0]
            txt_file = os.path.join(self.txts_path, base_name + '.txt')
            all_objects = []
            if os.path.exists(txt_file):
                with open(txt_file, 'r') as f:
                    objects = f.readlines()
                    for object in objects:
                        object = object.strip().split(' ')
                        all_objects.append(object)
                        print(object)

            # Build the xml label tree
            xmlBuilder = Document()
            annotation = xmlBuilder.createElement("annotation")  # root tag
            xmlBuilder.appendChild(annotation)

            folder = xmlBuilder.createElement("folder")  # name of the folder holding the images, e.g. JPEGImages
            folderContent = xmlBuilder.createTextNode(self.imgs_path.split('/')[-1])
            folder.appendChild(folderContent)
            annotation.appendChild(folder)

            filename = xmlBuilder.createElement("filename")  # the image name, e.g. 000250.jpg
            filenameContent = xmlBuilder.createTextNode(img_name)  # use the real image name instead of assuming .jpg
            filename.appendChild(filenameContent)
            annotation.appendChild(filename)

            # Write the image shape into the xml
            size = xmlBuilder.createElement("size")
            width = xmlBuilder.createElement("width")
            widthContent = xmlBuilder.createTextNode(str(width_img))  # xml text content must be a string
            width.appendChild(widthContent)
            size.appendChild(width)
            height = xmlBuilder.createElement("height")
            heightContent = xmlBuilder.createTextNode(str(height_img))
            height.appendChild(heightContent)
            size.appendChild(height)
            depth = xmlBuilder.createElement("depth")
            depthContent = xmlBuilder.createTextNode(str(depth_img))
            depth.appendChild(depthContent)
            size.appendChild(depth)
            annotation.appendChild(size)

            # Each object_info is one annotated target, e.g. ['2', '0.506667', '0.553333', '0.490667', '0.658667']
            for object_info in all_objects:
                object = xmlBuilder.createElement("object")

                imgName = xmlBuilder.createElement("name")  # the class name
                imgNameContent = xmlBuilder.createTextNode(self.classes[int(object_info[0])])
                imgName.appendChild(imgNameContent)
                object.appendChild(imgName)

                pose = xmlBuilder.createElement("pose")
                poseContent = xmlBuilder.createTextNode("Unspecified")
                pose.appendChild(poseContent)
                object.appendChild(pose)

                truncated = xmlBuilder.createElement("truncated")
                truncatedContent = xmlBuilder.createTextNode("0")
                truncated.appendChild(truncatedContent)
                object.appendChild(truncated)

                difficult = xmlBuilder.createElement("difficult")
                difficultContent = xmlBuilder.createTextNode("0")
                difficult.appendChild(difficultContent)
                object.appendChild(difficult)

                # Convert the coordinates first:
                # (objx_center, objy_center, obj_width, obj_height) -> (xmin, ymin, xmax, ymax)
                x_center = float(object_info[1]) * width_img + 1
                y_center = float(object_info[2]) * height_img + 1
                xminVal = int(x_center - 0.5 * float(object_info[3]) * width_img)  # object_info entries are strings
                yminVal = int(y_center - 0.5 * float(object_info[4]) * height_img)
                xmaxVal = int(x_center + 0.5 * float(object_info[3]) * width_img)
                ymaxVal = int(y_center + 0.5 * float(object_info[4]) * height_img)

                # In voc format a box is the top-left corner (xmin, ymin) and bottom-right corner (xmax, ymax)
                bndbox = xmlBuilder.createElement("bndbox")
                xmin = xmlBuilder.createElement("xmin")
                xminContent = xmlBuilder.createTextNode(str(xminVal))
                xmin.appendChild(xminContent)
                bndbox.appendChild(xmin)
                ymin = xmlBuilder.createElement("ymin")
                yminContent = xmlBuilder.createTextNode(str(yminVal))
                ymin.appendChild(yminContent)
                bndbox.appendChild(ymin)
                xmax = xmlBuilder.createElement("xmax")
                xmaxContent = xmlBuilder.createTextNode(str(xmaxVal))
                xmax.appendChild(xmaxContent)
                bndbox.appendChild(xmax)
                ymax = xmlBuilder.createElement("ymax")
                ymaxContent = xmlBuilder.createTextNode(str(ymaxVal))
                ymax.appendChild(ymaxContent)
                bndbox.appendChild(ymax)
                object.appendChild(bndbox)
                annotation.appendChild(object)

            # writexml(writer, indent='', addindent='', newl='', encoding=None)
            f = open(os.path.join(self.xmls_path, base_name + '.xml'), 'w')
            xmlBuilder.writexml(f, indent='\t', newl='\n', addindent='\t', encoding='utf-8')
            f.close()


if __name__ == '__main__':
    imgs_path1 = 'F:/Dataset/road/images/val'  # ['train', 'val']
    txts_path1 = 'F:/Dataset/road/labels/val'  # ['train', 'val']
    xmls_path1 = 'F:/Dataset/road/xmls/val'    # ['train', 'val']
    classes_str_list = ['road_crack', 'road_sag']  # class names
    yolo2voc_obj1 = YOLO2VOCConvert(txts_path1, xmls_path1, imgs_path1, classes_str_list)
    labels = yolo2voc_obj1.search_all_classes()
    print('labels: ', labels)
    yolo2voc_obj1.yolo2voc()
Once both train and val have been converted, the directory looks like this:
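Again a sketch for the road example; the xmls folders now sit beside images and labels:

F:/Dataset/road
├── images
│   ├── train
│   └── val
├── labels
│   ├── train
│   └── val
└── xmls
    ├── train        # generated *.xml, one per image (background images included)
    └── val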
For a broad overview of the whole YOLOX project and its training process, see my other article, yolox训练解析 (a walkthrough of YOLOX training).
Change into the YOLOX root directory.
The directory yolox/data/datasets/ defines how data is read in: there is a coco-style reader, a voc-style reader, and the mosaic augmentation is defined in this folder as well. To add a new reading style we add a file here; create yolo_style.py with the code below. Note that it derives each xml path from the image path by swapping the /images/ path segment for /xmls/, which is why the directory layout shown above matters.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2021/12/23 9:13
# @Author  : lishanlu
# @File    : yolo_style.py
# @Software: PyCharm
# @Discription: read xml labels laid out in the yolo-style directory tree

from __future__ import absolute_import, print_function, division
import os
import os.path
import pickle
import xml.etree.ElementTree as ET

import cv2
import numpy as np
import torch
from PIL import Image, ExifTags

from yolox.evaluators.voc_eval import voc_eval
from .datasets_wrapper import Dataset


class AnnotationTransform(object):
    """Transforms an annotation into an array of bbox coords and label index.
    Initialized with a dictionary lookup of class names to indexes.

    Arguments:
        classes_name (tuple of str): class names, mapped to indexes by position
        keep_difficult (bool, optional): keep difficult instances or not
    """

    def __init__(self, classes_name, keep_difficult=True):
        self.class_to_ind = dict(zip(classes_name, range(len(classes_name))))
        self.keep_difficult = keep_difficult

    def __call__(self, target):
        """
        Arguments:
            target (annotation): the target annotation as an ET.Element
        Returns:
            a numpy array of bounding boxes [xmin, ymin, xmax, ymax, label_ind]
            and the image info (height, width)
        """
        res = np.empty((0, 5))
        for obj in target.iter("object"):
            difficult = obj.find("difficult")
            if difficult is not None:
                difficult = int(difficult.text) == 1
            else:
                difficult = False
            if not self.keep_difficult and difficult:
                continue
            name = obj.find("name").text.strip()
            bbox = obj.find("bndbox")
            pts = ["xmin", "ymin", "xmax", "ymax"]
            bndbox = []
            for i, pt in enumerate(pts):
                cur_pt = int(bbox.find(pt).text) - 1
                # scale height or width
                # cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
                bndbox.append(cur_pt)
            label_idx = self.class_to_ind[name]
            bndbox.append(label_idx)
            res = np.vstack((res, bndbox))  # [xmin, ymin, xmax, ymax, label_ind]

        width = int(target.find("size").find("width").text)
        height = int(target.find("size").find("height").text)
        img_info = (height, width)

        return res, img_info


# --- helpers for the yolo-style dataloader ---
img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp']  # acceptable image suffixes

# Get the orientation exif tag
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break


def img2xml_paths(img_paths):
    # Define xml paths as a function of image paths
    sa, sb = os.sep + 'images' + os.sep, os.sep + 'xmls' + os.sep  # /images/, /xmls/ substrings
    return ['xml'.join(x.replace(sa, sb, 1).rsplit(x.split('.')[-1], 1)) for x in img_paths]


def get_hash(files):
    # Returns a single hash value of a list of files
    return sum(os.path.getsize(f) for f in files if os.path.isfile(f))


def exif_size(img):
    # Returns exif-corrected PIL size
    s = img.size  # (width, height)
    try:
        rotation = dict(img._getexif().items())[orientation]
        if rotation == 6:  # rotation 270
            s = (s[1], s[0])
        elif rotation == 8:  # rotation 90
            s = (s[1], s[0])
    except Exception:
        pass
    return s


def xyxy2xywh(x):
    # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y


def segments2boxes(segments):
    # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh)
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # cls, xyxy
    return xyxy2xywh(np.array(boxes))  # cls, xywh


class YOLODetection(Dataset):
    """
    YOLO Style Detection Dataset Object (reads labels from yolo-style XML);
    input is image, target is annotation.

    Args:
        data_dir (string): filepath to data folder.
        classes (tuple of string): class names.
        image_sets (list of string): image sets to use (e.g. ['train'], ['val'])
        preproc (callable, optional): transformation to perform on the input image
        dataset_name (string, optional): which dataset to load (default: 'yolo_dataset')
    """

    def __init__(
        self,
        data_dir,
        classes,
        image_sets=['train'],
        img_size=(416, 416),
        preproc=None,
        dataset_name="yolo_dataset",
        cache=False,
    ):
        super().__init__(img_size)
        self.root = data_dir
        self.image_set = image_sets
        self.img_size = img_size
        self.preproc = preproc
        self._classes = classes
        self.target_transform = AnnotationTransform(self._classes, keep_difficult=True)
        self.name = dataset_name
        self.image_files = []
        for name in image_sets:
            rootpath = self.root
            image_dir = os.path.join(rootpath, 'images', name)
            # accumulate across image sets instead of overwriting
            self.image_files.extend(os.path.join(image_dir, image_name) for image_name in os.listdir(image_dir))
            if name == 'val':
                self.val_ids = [os.path.splitext(image_name)[0] for image_name in os.listdir(image_dir)]
                with open(os.path.join(rootpath, name + '.txt'), 'w') as f:
                    for id in self.val_ids:
                        f.write(id + '\n')
        self.xml_files = img2xml_paths(self.image_files)  # list of xml file paths
        self.annotations = self._load_xml_annotations()
        self.imgs = None
        if cache:
            self._cache_images()

    def __len__(self):
        return len(self.image_files)

    def _load_xml_annotations(self):
        return [self.load_anno_from_ids(_ids) for _ids in range(len(self.xml_files))]

    def _cache_images(self):
        pass

    def load_anno_from_ids(self, index):
        xml_file = self.xml_files[index]
        target = ET.parse(xml_file).getroot()
        assert self.target_transform is not None
        res, img_info = self.target_transform(target)
        height, width = img_info
        r = min(self.img_size[0] / height, self.img_size[1] / width)
        res[:, :4] *= r
        resized_info = (int(height * r), int(width * r))
        return (res, img_info, resized_info)

    def load_anno(self, index):
        return self.annotations[index][0]

    def load_resized_img(self, index):
        img = self.load_image(index)
        r = min(self.img_size[0] / img.shape[0], self.img_size[1] / img.shape[1])
        resized_img = cv2.resize(
            img,
            (int(img.shape[1] * r), int(img.shape[0] * r)),
            interpolation=cv2.INTER_LINEAR,
        ).astype(np.uint8)
        return resized_img

    def load_image(self, index):
        img = cv2.imread(self.image_files[index], cv2.IMREAD_COLOR)
        assert img is not None
        return img

    def pull_item(self, index):
        """Returns the original image and target at an index for mixup.

        Note: not using self.__getitem__(), as any transformations passed in
        could mess up this functionality.

        Argument:
            index (int): index of img to show
        Return:
            img, target
        """
        if self.imgs is not None:
            target, img_info, resized_info = self.annotations[index]
            pad_img = self.imgs[index]
            img = pad_img[: resized_info[0], : resized_info[1], :].copy()
        else:
            img = self.load_resized_img(index)
            target, img_info, _ = self.annotations[index]
        return img, target, img_info, index

    @Dataset.mosaic_getitem
    def __getitem__(self, index):
        img, target, img_info, img_id = self.pull_item(index)  # target boxes here are (x, y, x, y, cls)

        ### Debug: show the image and labels as read
        # from PIL import Image, ImageDraw
        # from matplotlib import pyplot as plt
        # img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        # draw = ImageDraw.Draw(img)
        # for j in range(target.shape[0]):
        #     name = int(target[j][4])
        #     left = int(target[j][0])
        #     top = int(target[j][1])
        #     right = int(target[j][2])
        #     bottom = int(target[j][3])
        #     draw.text((left + 10, top + 10), f'{name}', fill='blue')
        #     draw.rectangle((left, top, right, bottom), outline='red', width=2)
        # plt.imshow(img)
        # plt.show()

        if self.preproc is not None:
            img, target = self.preproc(img, target, self.input_dim)  # target boxes here are (cls, cx, cy, w, h)

        ### Debug: show the image and labels after preproc
        # from PIL import Image, ImageDraw
        # from matplotlib import pyplot as plt
        # img = np.transpose(img.astype(np.uint8), (1, 2, 0))
        # img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        # draw = ImageDraw.Draw(img)
        # for j in range(target.shape[0]):
        #     name = int(target[j][0])
        #     left = int(target[j][1] - target[j][3] / 2)
        #     top = int(target[j][2] - target[j][4] / 2)
        #     right = int(target[j][1] + target[j][3] / 2)
        #     bottom = int(target[j][2] + target[j][4] / 2)
        #     draw.text((left + 10, top + 10), f'{name}', fill='blue')
        #     draw.rectangle((left, top, right, bottom), outline='red', width=2)
        # plt.imshow(img)
        # plt.show()

        return img, target, img_info, img_id

    def evaluate_detections(self, all_boxes, output_dir=None):
        """
        all_boxes is a list of length number-of-classes.
        Each list element is a list of length number-of-images.
        Each of those list elements is either an empty list []
        or a numpy array of detection.

        all_boxes[class][image] = [] or np.array of shape #dets x 5
        """
        self._write_voc_results_file(all_boxes)
        IouTh = np.linspace(0.5, 0.95, int(np.round((0.95 - 0.5) / 0.05)) + 1, endpoint=True)
        mAPs = []
        for iou in IouTh:
            mAP = self._do_python_eval(output_dir, iou)
            mAPs.append(mAP)

        print("--------------------------------------------------------------")
        print("map_5095:", np.mean(mAPs))
        print("map_50:", mAPs[0])
        print("--------------------------------------------------------------")
        return np.mean(mAPs), mAPs[0]

    def _get_voc_results_file_template(self):
        filename = "comp4_det_test" + "_{:s}.txt"
        filedir = os.path.join(self.root, "results")
        if not os.path.exists(filedir):
            os.makedirs(filedir)
        path = os.path.join(filedir, filename)
        return path

    def _write_voc_results_file(self, all_boxes):
        self.ids = [os.path.splitext(os.path.split(image_file)[1])[0] for image_file in self.image_files]
        for cls_ind, cls in enumerate(self._classes):
            if cls == "__background__":
                continue
            print("Writing {} VOC results file".format(cls))
            filename = self._get_voc_results_file_template().format(cls)
            with open(filename, "wt") as f:
                for im_ind, index in enumerate(self.ids):
                    dets = all_boxes[cls_ind][im_ind]
                    if len(dets) == 0:
                        continue
                    for k in range(dets.shape[0]):
                        f.write(
                            "{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n".format(
                                index,
                                dets[k, -1],
                                dets[k, 0] + 1,
                                dets[k, 1] + 1,
                                dets[k, 2] + 1,
                                dets[k, 3] + 1,
                            )
                        )

    def _do_python_eval(self, output_dir="output", iou=0.5):
        rootpath = self.root
        name = self.image_set[0]
        annopath = os.path.join(rootpath, "xmls", "val", "{:s}.xml")
        imagesetfile = os.path.join(rootpath, name + ".txt")
        cachedir = os.path.join(self.root, "annotations_cache")
        if not os.path.exists(cachedir):
            os.makedirs(cachedir)
        aps = []
        # The PASCAL VOC metric changed in 2010
        # use_07_metric = True if int(self._year) < 2010 else False
        use_07_metric = True
        print("Eval IoU : {:.2f}".format(iou))
        if output_dir is not None and not os.path.isdir(output_dir):
            os.mkdir(output_dir)
        for i, cls in enumerate(self._classes):
            if cls == "__background__":
                continue
            filename = self._get_voc_results_file_template().format(cls)
            rec, prec, ap = voc_eval(
                filename,
                annopath,
                imagesetfile,
                cls,
                cachedir,
                ovthresh=iou,
                use_07_metric=use_07_metric,
            )
            aps += [ap]
            if iou == 0.5:
                print("AP for {} = {:.4f}".format(cls, ap))
            if output_dir is not None:
                with open(os.path.join(output_dir, cls + "_pr.pkl"), "wb") as f:
                    pickle.dump({"rec": rec, "prec": prec, "ap": ap}, f)
        if iou == 0.5:
            print("Mean AP = {:.4f}".format(np.mean(aps)))
            print("~~~~~~~~")
            print("Results:")
            for ap in aps:
                print("{:.3f}".format(ap))
            print("{:.3f}".format(np.mean(aps)))
            print("~~~~~~~~")
            print("")
            print("--------------------------------------------------------------")
            print("Results computed with the **unofficial** Python eval code.")
            print("Results should be very close to the official MATLAB eval code.")
            print("Recompute with `./tools/reval.py --matlab ...` for your paper.")
            print("-- Thanks, The Management")
            print("--------------------------------------------------------------")
        return np.mean(aps)
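Before wiring it into training, the reader can be smoke-tested on its own (a sketch; the data_dir and class names are my road example, and the full module path import works even before touching __init__.py):

from yolox.data.datasets.yolo_style import YOLODetection

dataset = YOLODetection(
    data_dir='F:/Dataset/road',           # root containing images/ and xmls/
    classes=('road_crack', 'road_sag'),
    image_sets=['train'],
    img_size=(640, 640),
)
img, target, img_info, img_id = dataset.pull_item(0)
print(img.shape, target)  # resized image; boxes are (xmin, ymin, xmax, ymax, cls) scaled to img_size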
Once the file is defined, don't forget to add from .yolo_style import YOLODetection to the __init__.py in yolox/data/datasets/.
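In a stock YOLOX checkout that __init__.py already exposes the coco and voc readers, so the new import just sits next to them (an excerpt; your copy may list a few more names):

# yolox/data/datasets/__init__.py
from .coco import COCODataset
from .voc import VOCDetection
from .yolo_style import YOLODetection   # the new yolo-style reader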
Under exps/example/, create a directory for the task, say road, and inside it a new file yolox_road.py. This file defines the Exp class used for training; it inherits from the Exp class in yolox_base.py under yolox/exp/, and it mainly sets the model parameters, dataset parameters, and data-augmentation parameters, and builds the dataloaders. Example code:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2021/12/23 8:58
# @Author  : lishanlu
# @File    : yolox_road.py
# @Software: PyCharm
# @Discription:

from __future__ import absolute_import, print_function, division
import os

import torch
import torch.nn as nn
import torch.distributed as dist

from yolox.exp import Exp as MyExp


class Exp(MyExp):
    def __init__(self):
        super(Exp, self).__init__()
        # ---------------- model config ---------------- #
        self.num_classes = 2  # change to match the number of classes in your own data
        self.depth = 0.67
        self.width = 0.75

        # ---------------- dataloader config ---------------- #
        # set worker to 4 for shorter dataloader init time
        self.data_num_workers = 4
        self.input_size = (640, 640)  # (height, width)
        # Actual multiscale ranges: [640 - 5 * 32, 640 + 5 * 32].
        # To disable multiscale training, set multiscale_range to 0.
        self.multiscale_range = 5
        # You can uncomment this line to specify a multiscale range
        # self.random_size = (14, 26)
        self.data_dir = 'your data rootdir'       # root directory of the dataset
        self.classes_name = ('class1', 'class2')  # class names
        self.dataset_name = 'yolo_dataset'        # dataset name, usually no need to change

        # --------------- transform config ----------------- #
        self.mosaic_prob = 1.0
        self.mixup_prob = 1.0
        self.hsv_prob = 1.0
        self.flip_prob = 0.5
        self.degrees = 5.0
        self.translate = 0.1
        self.mosaic_scale = (0.5, 1.5)
        self.mixup_scale = (0.5, 1.5)
        self.shear = 2.0
        self.perspective = 0.0
        self.enable_mixup = False

        # --------------  training config --------------------- #
        self.warmup_epochs = 5
        self.max_epoch = 300
        self.warmup_lr = 0
        self.basic_lr_per_img = 0.01 / 64.0
        self.scheduler = "yoloxwarmcos"
        self.milestones = [70, 120, 180, 300]  # only used by the multi_step lr schedule
        self.gamma = 0.1                       # only used by the multi_step lr schedule
        self.no_aug_epochs = 300
        self.min_lr_ratio = 0.05
        self.ema = True
        self.weight_decay = 5e-4
        self.momentum = 0.9
        self.print_interval = 10
        self.eval_interval = 1
        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]

        # -----------------  testing config ------------------ #
        self.test_size = (640, 640)
        self.test_conf = 0.01
        self.nmsthre = 0.65

    def get_model(self):
        from yolox.models import YOLOX, YOLOPAFPN, YOLOXHead

        def init_yolo(M):
            for m in M.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eps = 1e-3
                    m.momentum = 0.03

        if getattr(self, "model", None) is None:
            in_channels = [256, 512, 1024]
            backbone = YOLOPAFPN(self.depth, self.width, in_channels=in_channels)
            head = YOLOXHead(self.num_classes, self.width, in_channels=in_channels)  # strides=[8, 16, 32]
            self.model = YOLOX(backbone, head)

        self.model.apply(init_yolo)
        self.model.head.initialize_biases(1e-2)
        return self.model

    def get_data_loader(self, batch_size, is_distributed, no_aug=False, cache_img=False):
        from yolox.data import (
            YOLODetection,
            TrainTransform,
            YoloBatchSampler,
            DataLoader,
            InfiniteSampler,
            MosaicDetection,
            worker_init_reset_seed,
        )
        from yolox.utils import (
            wait_for_the_master,
            get_local_rank,
        )

        local_rank = get_local_rank()

        with wait_for_the_master(local_rank):
            dataset = YOLODetection(
                data_dir=self.data_dir,
                classes=self.classes_name,
                image_sets=['train'],
                img_size=self.input_size,
                preproc=TrainTransform(
                    max_labels=50,
                    flip_prob=self.flip_prob,
                    hsv_prob=self.hsv_prob),
                dataset_name=self.dataset_name,
                cache=cache_img,
            )

        dataset = MosaicDetection(
            dataset,
            mosaic=not no_aug,
            img_size=self.input_size,
            preproc=TrainTransform(
                max_labels=120,
                flip_prob=self.flip_prob,
                hsv_prob=self.hsv_prob),
            degrees=self.degrees,
            translate=self.translate,
            mosaic_scale=self.mosaic_scale,
            mixup_scale=self.mixup_scale,
            shear=self.shear,
            perspective=self.perspective,
            enable_mixup=self.enable_mixup,
            mosaic_prob=self.mosaic_prob,
            mixup_prob=self.mixup_prob,
        )

        self.dataset = dataset

        if is_distributed:
            batch_size = batch_size // dist.get_world_size()

        sampler = InfiniteSampler(len(self.dataset), seed=self.seed if self.seed else 0)

        batch_sampler = YoloBatchSampler(
            sampler=sampler,
            batch_size=batch_size,
            drop_last=False,
            mosaic=not no_aug,
        )

        dataloader_kwargs = {"num_workers": self.data_num_workers, "pin_memory": True}
        dataloader_kwargs["batch_sampler"] = batch_sampler
        dataloader_kwargs["worker_init_fn"] = worker_init_reset_seed

        train_loader = DataLoader(self.dataset, **dataloader_kwargs)
        return train_loader

    def get_eval_loader(self, batch_size, is_distributed, testdev=False, legacy=False):
        from yolox.data import YOLODetection, ValTransform

        valdataset = YOLODetection(
            data_dir=self.data_dir,
            classes=self.classes_name,
            image_sets=['val'],
            img_size=self.test_size,
            preproc=ValTransform(legacy=legacy),
            dataset_name=self.dataset_name,
        )

        if is_distributed:
            batch_size = batch_size // dist.get_world_size()
            sampler = torch.utils.data.distributed.DistributedSampler(valdataset, shuffle=False)
        else:
            sampler = torch.utils.data.SequentialSampler(valdataset)

        dataloader_kwargs = {
            "num_workers": self.data_num_workers,
            "pin_memory": True,
            "sampler": sampler,
        }
        dataloader_kwargs["batch_size"] = batch_size
        val_loader = torch.utils.data.DataLoader(valdataset, **dataloader_kwargs)
        return val_loader

    def get_evaluator(self, batch_size, is_distributed, testdev=False, legacy=False):
        from yolox.evaluators import VOCEvaluator

        val_loader = self.get_eval_loader(batch_size, is_distributed, testdev, legacy)
        evaluator = VOCEvaluator(
            dataloader=val_loader,
            img_size=self.test_size,
            confthre=self.test_conf,
            nmsthre=self.nmsthre,
            num_classes=self.num_classes,
        )
        return evaluator

    def get_lr_scheduler(self, lr, iters_per_epoch, **kwargs):
        from yolox.utils import LRScheduler

        scheduler = LRScheduler(
            self.scheduler,
            lr,
            iters_per_epoch,
            self.max_epoch,
            warmup_epochs=self.warmup_epochs,
            warmup_lr_start=self.warmup_lr,
            no_aug_epochs=self.no_aug_epochs,
            min_lr_ratio=self.min_lr_ratio,
            **kwargs
        )
        return scheduler
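The Exp can be smoke-tested before a full run by loading it with YOLOX's get_exp helper (a sketch; remember to point data_dir and classes_name at your own data first):

from yolox.exp import get_exp

exp = get_exp('exps/example/road/yolox_road.py', None)
exp.data_dir = 'F:/Dataset/road'                   # or edit the value inside yolox_road.py
exp.classes_name = ('road_crack', 'road_sag')      # must match the names written into the xmls
exp.data_num_workers = 0                           # single-process loading is easier to debug
loader = exp.get_data_loader(batch_size=8, is_distributed=False)
imgs, targets, _, _ = next(iter(loader))
print(imgs.shape, targets.shape)  # e.g. torch.Size([8, 3, 640, 640]) and torch.Size([8, 120, 5])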
Write a shell script, train.sh:
python tools/train.py \
--experiment-name yolox_road \
--batch-size 48 \
--devices 1 \
--exp_file exps/example/road/yolox_road.py \
--fp16 \
--ckpt pre_train/yolox_m.pth
Run bash ./train.sh to start training.
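After training, the same Exp file drives evaluation through the get_evaluator defined above; something along these lines should work (the checkpoint path is YOLOX's default output location, adjust as needed):

python tools/eval.py \
    -f exps/example/road/yolox_road.py \
    -c YOLOX_outputs/yolox_road/best_ckpt.pth \
    -b 8 -d 1 --conf 0.01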