赞
踩
以二分类为例,创建一个简单的训练数据集:
import os


def gen_txt_file(root_dir, txt_path):
    """Append one ``<image path> <label>`` record per entry of ``root_dir``.

    Each record is the joined path and an integer label separated by a
    single space, terminated by a newline. The label is 1 when the joined
    path contains the substring "dog" and 0 otherwise.

    Args:
        root_dir: Directory holding the train / test / validation images.
        txt_path: Path of the label file. It is created when missing and
            appended to when it already exists.

    Raises:
        OSError: If ``root_dir`` cannot be listed or ``txt_path`` cannot be
            opened for appending.
    """
    # List the directory before touching the output file so that a bad
    # root_dir does not leave an empty label file behind.
    entries = os.listdir(root_dir)

    # Open the label file once instead of once per image.
    with open(txt_path, "a") as fp:
        for file in entries:
            img_file = os.path.join(root_dir, file)
            print(img_file)
            # NOTE: the whole joined path is matched, so a "dog" in a parent
            # directory name also yields label 1.
            label = 1 if "dog" in img_file else 0
            fp.write(f"{img_file} {label}\n")


def main():
    """Generate the label file for the demo cats directory."""
    gen_txt_file(
        root_dir="/media/tx-deepocean/Data/DICOMS/demos/torch_datasets/cats",
        txt_path="/media/tx-deepocean/Data/DICOMS/demos/Projects/pytorch-tutorial/txd_learn_notes/test.txt",
    )


if __name__ == "__main__":
    main()
以下是我的生成结果:
为了满足 PyTorch 对数据集的规范,需要按照官方要求自定义数据集类:继承 torch.utils.data.Dataset,并分别重写 __init__(self)、__getitem__(self, index)、__len__(self) 方法:
""" torchvision 是pytorch 中专门用来处理图像的库,含有四个大类 torchvision.datasets 加载数据集 torchvision.models 提供一些已经训练好的模型 torchvision.transforms 提供图像处理需要的工具, resize, crop, data_augmentation torchvision.utils """ from loguru import logger import torchvision import torch import os import skimage.io as io class CustomDataset(torch.utils.data.Dataset): def __init__(self, root_dir, names_file, transform=None): # 1. Initialize file paths or a list of file names. self.root_dir = root_dir self.names_file = names_file self.transform = transform self.size = 0 self.names_list = [] if not os.path.isfile(self.names_file): print(f'{self.names_file} is not exists') file = open(self.names_file) print(file) for f in file: self.names_list.append(f) self.size += 1 print(self.names_list) def __getitem__(self, index): # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open). # 2. Preprocess the data (e.g. torchvision.Transform). # 3. Return a data pair (e.g. image and label). logger.info(f'') img_path = os.path.join(self.root_dir, self.names_list[index].split(" ")[0]) logger.info(f'img_path: {img_path}') if not os.path.isfile(img_path): logger.warning(f'{img_path} not exists!') return None image = io.imread(img_path) label = int(self.names_list[index].split(" ")[1]) logger.info(label) sample = {'image': image, 'label': label} if self.transform: sample = self.transform(sample) return sample def __len__(self): # You should change 0 to the total size of your dataset. return self.size # You can then use the prebuilt data loader. custom_dataset = CustomDataset(root_dir="", names_file="/media/tx-deepocean/Data/DICOMS/demos/Projects/pytorch-tutorial/txd_learn_notes/test.txt") # print(custom_dataset.__getitem__(0)) train_loader = torch.utils.data.DataLoader(dataset=custom_dataset, batch_size=64, shuffle=True) print(train_loader)
到此,一个 PyTorch 的自定义数据集就制作完成了。
一个值得信赖的女人.
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。