赞
踩
torchvision.datasets的三个基础类
torchvision.datasets
torch.utils.data.Dataset
Pillow(PIL Fork) Image模块
Dataset是数据集在pytorch中的化身,需要重写__ getitem__ 和 __ len__。__ __ getitem__ 通过传入的索引加载指定路径的数据,路径常常是一个列表,如很多张图片组成的数据集,需要在初始化时定义函数得到路径列表,或者在外部定义,总之要得到一个路径List。也需要在其中定义或调用具体读取的代码,如PIL库的Image.open()来读取图片,或Image.fromarray()来创建图片,也就是需要知道数据在哪里和怎么读取。
└─Dataset
└─VisionDataset
└─DatasetFolder
└─ImageFolder
Dataset是torch.utils.data中的类,是数据集的基础类
VisionDataset是torchvision.datasets.vision中的类,是torchvision类数据集的基础类,相比于原始的Dataset类,提供了transform,transforms,target_transform数据变换的接口
DatasetFolder,ImageFolder都来自torchvision.datasets.folder ,既然叫做folder,实际上已经有了完整的数据集功能,可以按照默认的目录结构读取数据。DatasetFolder还需要定义loader以读取特定类型的数据,和is_valid_file或者extensions,is_valid_file和extensions不能同时定义,但必须有一个定义,如果定义了有效后缀名,会自动通过后缀来判断文件有效性。而ImageFolder更进一步,默认使用读取图像数据的loader读取,还默认定义了图像后缀名。从Dataset到ImageFolder构成了不同层次的封装,完成度越高,灵活性越低,可以根据自己的需要选择。
除了在__ getitem__ 中通过得到的路径列表来读取数据,对于不同格式的数据也有不同的做法,如torchvision中内置cifar数据集,会直接从原始数据中以矩阵的形式读取, 因此 __ getitem__ 会从矩阵中创建Image对象。总而言之,一般来讲对于图片数据集来说,__ get __返回的都是PIL Image对象,不管是从路径列表中读取,还是整个以矩阵形式读取,如果不定义transform,最后在Dataset阶段都是PIL对象。
默认的排列结构如下,每一个文件夹表示一类,下面是这一类的样本
directory/ ├── class_x │ ├── xxx.ext │ ├── xxy.ext │ └── ... │ └── xxz.ext └── class_y ├── 123.ext ├── nsdf3.ext └── ... └── asd932_.ext
用文件夹来区分不同的类别。比较重要的有两类操作,find_class函数得到类别名和类别序号。make_dataset得到路径列表。
默认的findclass函数
文件夹名是类名。
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
"""Finds the class folders in a dataset.
See :class:`DatasetFolder` for details.
"""
classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
if not classes:
raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
默认的make_dataset
得到instance列表,表示文件的路径列表。
基本上很大一部分是在定义有效性判断相关,主要部分是一个双层for循环,因为类名定义为文件夹名,所以会遍历各个类的文件夹,会将遍历到的有效文件的路径加入instance,遍历过的非空类添加到available_classe。
def make_dataset( directory: Union[str, Path], class_to_idx: Optional[Dict[str, int]] = None, extensions: Optional[Union[str, Tuple[str, ...]]] = None, is_valid_file: Optional[Callable[[str], bool]] = None, allow_empty: bool = False, ) -> List[Tuple[str, int]]: directory = os.path.expanduser(directory) if class_to_idx is None: _, class_to_idx = find_classes(directory) elif not class_to_idx: raise ValueError("'class_to_index' must have at least one entry to collect any samples.") both_none = extensions is None and is_valid_file is None both_something = extensions is not None and is_valid_file is not None if both_none or both_something: raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") if extensions is not None: def is_valid_file(x: str) -> bool: return has_file_allowed_extension(x, extensions) # type: ignore[arg-type] is_valid_file = cast(Callable[[str], bool], is_valid_file) instances = [] available_classes = set() for target_class in sorted(class_to_idx.keys()): class_index = class_to_idx[target_class] target_dir = os.path.join(directory, target_class) if not os.path.isdir(target_dir): continue for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): for fname in sorted(fnames): path = os.path.join(root, fname) if is_valid_file(path): item = path, class_index instances.append(item) if target_class not in available_classes: available_classes.add(target_class) empty_classes = set(class_to_idx.keys()) - available_classes if empty_classes and not allow_empty: msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. " if extensions is not None: msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}" raise FileNotFoundError(msg) return instances
root/dog/xxy.png
root/dog/[...]/xxz.png
root/cat/123.png
root/cat/nsdf3.png
ImageFolder如名字所示,如果数据集是这种文件夹排列,而且是图像文件,又没有需要特殊定义的部分 ,可以直接实例化一个ImageFolder,而不需要重写任何部分 ,实例化一个数据集只需要传入数据集路径和tansform变换。
train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(
os.path.join(data_dir, 'train_valid_test', folder),
transform=transform_train) for folder in ['train', 'train_valid']]
一般会在数据集实例化时,从外部传入,通常自定义的Transforms序列包含ToTensor,可以将上一阶段的PIL Image转换为Tensor,而下一次变化要到训练时的to(device),这样数据最终输入完成,也可以在Dataset类中写入默认的transform。
通过torchvision.get_image_backend得到torchvision现在的后端默认为PILtorchvision.set_image_backend(backend)指定用来读取图片的包,可选accimage
Loader将数据读取为PIL对象,一般数据集定义不在数据集内部定义默认的transform图像变换,而是在外部定义一个transform序列,通常倒数第二个是torchvision.transforms.ToTensor()操作,会将一个PIL Image或者一个ndarray转换为tensor并缩放到[0.0, 1.0]。因此接下来会通过transforms.Normalize进行归一化。
PILToTensor会把PIL Image转化为tensor,但是不会进行缩放, ( H × W × C ) → ( C × H × W ) (H\times W\times C)\rightarrow (C\times H \times W) (H×W×C)→(C×H×W)
ToTensor会把PIL Image或者ndarray转换成tensor而且会进行缩放。 ( H × W × C ) → ( C × H × W ) (H\times W\times C)\rightarrow (C\times H \times W) (H×W×C)→(C×H×W) 在规定的模式如RGBA,RGB,YCbCr或者dtype = np.uint8情况下,别的情况下不缩放。
Normalize只支持tensor,其他大部分操作也支持PIL,所以在ToTensor之后最后进行Normalize
data_transform = {
"train": transforms.Compose([transforms.RandomResizedCrop(img_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
"val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
transforms.CenterCrop(img_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
Pytorch Lightning是Pytorch中的kersas,简称pl
Pytorch Lightning继承LightningDataModule定义数据集,pl中的Dataset和Dataloader是高度耦合的。
import lightning.pytorch as L import torch.utils.data as data from pytorch_lightning.demos.boring_classes import RandomDataset class MyDataModule(L.LightningDataModule): def prepare_data(self): # download, IO, etc. Useful with shared filesystems # only called on 1 GPU/TPU in distributed ... def setup(self, stage): # make assignments here (val/train/test split) # called on every process in DDP dataset = RandomDataset(1, 100) self.train, self.val, self.test = data.random_split( dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42) ) def train_dataloader(self): return data.DataLoader(self.train) def val_dataloader(self): return data.DataLoader(self.val) def test_dataloader(self): return data.DataLoader(self.test) def teardown(self): # clean up state after the trainer stops, delete files... # called on every process in DDP ...
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。