赞
踩
utils/dataloaders.py
是用于加载数据并创建数据加载器的工具类
LoadImages
、LoadStreams
、LoadImagesAndLabels
、HUBDatasetStats
等。这些类和函数用于从文件系统加载图像和标签数据,并将其转换为模型可用的数据格式。InfiniteDataLoader
和ClassificationDataset
,以便将数据加载到模型中进行训练和验证。InfiniteDataLoader类是PyTorch框架中的DataLoader类的子类。DataLoader类是用于加载数据集的主要工具,它可以自动地把数据集划分成小批量(batch),并且可以在加载数据时使用多线程并行处理。
InfiniteDataLoader类重写了父类DataLoader的几个方法,实现了一个可以无限循环加载数据的数据加载器。具体来说,这个类的主要功能和工作方式如下:
__init__
方法:这是类的初始化方法。在这个方法里,首先调用了父类DataLoader的初始化方法,然后用一个名为_RepeatSampler的对象替换了DataLoader的batch_sampler属性,这样可以使数据加载器在取尽所有样本后可以再从头开始取样本。最后,创建了一个迭代器self.iterator,用于在后面的方法中生成数据。
__len__
方法:这个方法返回了数据集中的样本数量。
__iter__
方法:这个方法定义了数据生成的方式。它利用前面创建的迭代器self.iterator,在每次循环时生成一个新的数据批次。由于self.iterator使用了_RepeatSampler作为批次采样器,所以当所有的数据都被取尽时,它会自动地从头开始,从而实现数据的无限循环。
这类定义了一个用于加载图像和标签数据的数据集类LoadImagesAndLabels
,用于训练和验证YOLOv5模型。该类继承自torch.utils.data.Dataset
,并实现了__init__
、__len__
、__getitem__
等方法。
方法:
__len__(self)
:返回数据集的图像数量。__getitem__(self, index)
:获取指定索引的图像和标签。load_image(self, i)
:加载指定索引的图像。cache_images_to_disk(self, i)
:将图像缓存到磁盘中。load_mosaic(self, index)
:加载mosaic数据增强的图像和标签。load_mosaic9(self, index)
:加载9-mosaic数据增强的图像和标签。collate_fn(self, batch)
:将批次中的图像和标签进行组合。collate_fn4(self, batch)
:将4-mosaic批次中的图像和标签进行组合。create_dataloader
创建一个数据加载器(DataLoader),用于在训练深度学习模型时加载数据集。它接受一个数据集的路径和各种参数,然后返回一个可以用于批量加载和处理该数据集的数据加载器。
以下是其处理步骤:
torch_distributed_zero_first(rank)
确保数据集的 *.cache 文件只被初始化一次。LoadImagesAndLabels
对象,该对象负责加载和处理图像及其相关的标签。这个过程中,会根据提供的各种参数(如是否进行数据增强,是否使用矩形批处理,是否缓存图像等)来进行相应的处理。image_weights
为 True
,则使用 DataLoader
,否则使用 InfiniteDataLoader
。数据加载器负责按照设定的批处理大小和采样器,批量加载和处理数据。同时还使用了一些全局变量,如 LOGGER
(用于记录日志), RANK
和 PIN_MEMORY
(用于处理分布式数据并行)。
path
:数据集路径,可以是包含图像文件的文件夹路径,也可以是包含图像文件路径的列表。imgsz
:图像尺寸,用于将图像调整为统一尺寸。batch_size
:批大小,每个批次包含的图像数量。stride
:图像步长,用于计算特征图大小。single_cls
:是否进行单类别训练。hyp
:超参数字典。augment
:是否进行数据增强。cache
:是否将图像缓存到内存或磁盘中加快训练速度。pad
:图像填充比例。rect
:是否使用矩形训练。rank
:当前进程的排名。workers
:数据加载器的工作线程数。image_weights
:是否使用图像权重。quad
:是否使用四通道图像。loader
:数据加载器对象。dataset
:数据集对象。loader, dataset = create_dataloader(path='data/images', imgsz=640, batch_size=16, stride=32)
class InfiniteDataLoader(dataloader.DataLoader): # 定义一个名为InfiniteDataLoader的类,该类继承自dataloader.DataLoader
def __init__(self, *args, **kwargs): # 定义构造函数,接受任意数量的位置参数和关键字参数
super().__init__(*args, **kwargs) # 调用父类的构造函数,传入同样的参数
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)) # 设置self.batch_sampler属性为_RepeatSampler对象,传入self.batch_sampler作为参数
self.iterator = super().__iter__() # 设置self.iterator为父类的迭代器
def __len__(self): # 定义长度函数,返回batch_sampler.sampler的长度
return len(self.batch_sampler.sampler)
def __iter__(self): # 定义迭代器函数
for _ in range(len(self)): # 迭代self的长度次数
yield next(self.iterator) # 每次迭代返回self.iterator的下一个元素
# 定义一个名为create_dataloader的函数,功能是创建一个数据加载器,该函数接受多个参数,包括数据路径、图像大小、批处理大小、步幅等
def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0, rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix='', shuffle=False, seed=0):
if rect and shuffle: # 如果rect为True且shuffle为True
LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False') # 输出警告信息
shuffle = False # 将shuffle设置为False
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = LoadImagesAndLabels( # 创建LoadImagesAndLabels对象,传入参数
path, # 数据集路径
imgsz, # 图像尺寸
batch_size, # batch大小
augment=augment, # 是否进行数据增强
hyp=hyp, # 超参数
rect=rect, # 是否使用矩形batch
cache_images=cache, # 是否缓存图像
single_cls=single_cls, # 是否进行单类别训练
stride=int(stride), # 步长
pad=pad, # 填充
image_weights=image_weights, # 是否使用图像权重
prefix=prefix) # 前缀
# 确定批处理大小,不能超过数据集长度
batch_size = min(batch_size, len(dataset))
# 获取CUDA设备数量
nd = torch.cuda.device_count()
# 计算工作进程数量,不能超过CPU核数,批处理大小,以及设置的工作进程数
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
# 如果是分布式环境,则创建一个分布式采样器,否则设置为None
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
# 如果使用image_weights,则使用DataLoader加载数据,否则使用InfiniteDataLoader加载数据,这两者的区别在于DataLoader允许对属性进行更新
loader = DataLoader if image_weights else InfiniteDataLoader
generator = torch.Generator() # 随机数生成器
generator.manual_seed(6148914691236517205 + seed + RANK) # 设置随机数种子
return loader(dataset, # 返回DataLoader或InfiniteDataLoader对象
batch_size=batch_size, # batch大小
shuffle=shuffle and sampler is None, # 是否进行shuffle
num_workers=nw, # worker数量
sampler=sampler, # 采样器
pin_memory=PIN_MEMORY, # 是否将数据存储在固定内存中
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn, # 数据集合并函数
worker_init_fn=seed_worker, # worker初始化函数
generator=generator), dataset # 返回DataLoader或InfiniteDataLoader对象和数据集对象
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。