当前位置:   article > 正文

pytorch-分类-检测-分割的dataset和dataloader创建

pytorch-分类-检测-分割的dataset和dataloader创建

1.前言

        在PyTorch中,DatasetDataLoader是两个重要的工具,用于构建输入数据的管道。

(1)Dataset是一个抽象类,表示数据集,需要实现__len____getitem__方法。

(2)DataLoader是一个可迭代的数据加载器,它封装了数据集的加载、批处理、打乱和并行加载等功能。

2.分类任务创建DatasetDataLoader

        (1)对于分类任务,Dataset需要返回图像和对应的标签

  1. from torch.utils.data import Dataset
  2. from PIL import Image
  3. import os
  4. import torch
  5. class ClassificationDataset(Dataset):
  6. def __init__(self, root_dir, transform=None):
  7. self.transform = transform
  8. self.images = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]
  9. self.labels = [...] # 这里应该是与图像对应的标签列表
  10. def __len__(self):
  11. return len(self.images)
  12. def __getitem__(self, idx):
  13. img_path = self.images[idx]
  14. image = Image.open(img_path).convert('RGB')
  15. label = self.labels[idx]
  16. if self.transform:
  17. image = self.transform(image)
  18. return image, label

        (2)DataLoader加载数据

  1. from torch.utils.data import DataLoader
  2. transform = ... # 这里定义你的数据预处理流程
  3. dataset = ClassificationDataset(root_dir='path_to_your_data', transform=transform)
  4. dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

3.检测任务创建DatasetDataLoader

        (1)Dataset需要返回图像和对应的边界框信息

  1. class DetectionDataset(Dataset):
  2. def __init__(self, root_dir, transform=None):
  3. self.transform = transform
  4. self.images = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]
  5. self.annotations = [...] # 这里应该是与图像对应的边界框信息列表
  6. def __len__(self):
  7. return len(self.images)
  8. def __getitem__(self, idx):
  9. img_path = self.images[idx]
  10. image = Image.open(img_path).convert('RGB')
  11. boxes = self.annotations[idx] # 这些是边界框信息
  12. if self.transform:
  13. image, boxes = self.transform(image, boxes)
  14. return image, boxes

 (2)DataLoader加载数据

dataloader = DataLoader(DetectionDataset(root_dir='path_to_your_data', transform=transform), batch_size=2, shuffle=True)

4.分割任务创建DatasetDataLoader

(1)Dataset需要返回图像和对应的分割掩码

  1. class SegmentationDataset(Dataset):
  2. def __init__(self, root_dir, transform=None):
  3. self.transform = transform
  4. self.images = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]
  5. self.masks = [...] # 这里应该是与图像对应的分割掩码列表
  6. def __len__(self):
  7. return len(self.images)
  8. def __getitem__(self, idx):
  9. img_path = self.images[idx]
  10. mask_path = self.masks[idx]
  11. image = Image.open(img_path).convert('RGB')
  12. mask = Image.open(mask_path).convert('L') # 假设掩码是灰度图
  13. if self.transform:
  14. image, mask = self.transform(image, mask)
  15. return image, mask

(2)DataLoader加载数据

dataloader = DataLoader(SegmentationDataset(root_dir='path_to_your_data', transform=transform), batch_size=4, shuffle=True)

在PyTorch的DatasetDataLoader框架中,idx(或称为索引)是通过迭代DataLoader时自动生成的。当你创建一个DataLoader实例,并在训练循环中迭代它时,DataLoader会内部调用Dataset__getitem__方法,并自动为你提供索引idx

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/360471?site
推荐阅读
相关标签
  

闽ICP备14008679号