赞
踩
在PyTorch中,Dataset
和DataLoader
是两个重要的工具,用于构建输入数据的管道。
(1)Dataset
是一个抽象类,表示数据集,需要实现__len__
和__getitem__
方法。
(2)DataLoader
是一个可迭代的数据加载器,它封装了数据集的加载、批处理、打乱和并行加载等功能。
Dataset
和DataLoader
Dataset
需要返回图像和对应的标签- from torch.utils.data import Dataset
- from PIL import Image
- import os
- import torch
-
- class ClassificationDataset(Dataset):
- def __init__(self, root_dir, transform=None):
- self.transform = transform
- self.images = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]
- self.labels = [...] # 这里应该是与图像对应的标签列表
-
- def __len__(self):
- return len(self.images)
-
- def __getitem__(self, idx):
- img_path = self.images[idx]
- image = Image.open(img_path).convert('RGB')
- label = self.labels[idx]
-
- if self.transform:
- image = self.transform(image)
-
- return image, label
DataLoader
加载数据- from torch.utils.data import DataLoader
-
- transform = ... # 这里定义你的数据预处理流程
- dataset = ClassificationDataset(root_dir='path_to_your_data', transform=transform)
- dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
Dataset
和DataLoader
Dataset
需要返回图像和对应的边界框信息- class DetectionDataset(Dataset):
- def __init__(self, root_dir, transform=None):
- self.transform = transform
- self.images = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]
- self.annotations = [...] # 这里应该是与图像对应的边界框信息列表
-
- def __len__(self):
- return len(self.images)
-
- def __getitem__(self, idx):
- img_path = self.images[idx]
- image = Image.open(img_path).convert('RGB')
- boxes = self.annotations[idx] # 这些是边界框信息
-
- if self.transform:
- image, boxes = self.transform(image, boxes)
-
- return image, boxes
DataLoader
加载数据dataloader = DataLoader(DetectionDataset(root_dir='path_to_your_data', transform=transform), batch_size=2, shuffle=True)
Dataset
和DataLoader
(1)Dataset
需要返回图像和对应的分割掩码- class SegmentationDataset(Dataset):
- def __init__(self, root_dir, transform=None):
- self.transform = transform
- self.images = [os.path.join(root_dir, img) for img in os.listdir(root_dir) if img.endswith('.jpg')]
- self.masks = [...] # 这里应该是与图像对应的分割掩码列表
-
- def __len__(self):
- return len(self.images)
-
- def __getitem__(self, idx):
- img_path = self.images[idx]
- mask_path = self.masks[idx]
- image = Image.open(img_path).convert('RGB')
- mask = Image.open(mask_path).convert('L') # 假设掩码是灰度图
-
- if self.transform:
- image, mask = self.transform(image, mask)
-
- return image, mask
(2)DataLoader
加载数据
dataloader = DataLoader(SegmentationDataset(root_dir='path_to_your_data', transform=transform), batch_size=4, shuffle=True)
在PyTorch的
Dataset
和DataLoader
框架中,idx
(或称为索引)是通过迭代DataLoader
时自动生成的。当你创建一个DataLoader
实例,并在训练循环中迭代它时,DataLoader
会内部调用Dataset
的__getitem__
方法,并自动为你提供索引idx
。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。