当前位置:   article > 正文

Pytorch 创建自定义数据集_pytorch自定义生成数据集

pytorch自定义生成数据集

Pytorch 创建自定义数据集 [简单,快速], 全程无废话

二分类为例子,创建一个简单的训练数据集

环境介绍

  • tensorboard 2.11.2
  • torch 1.13.1
  • torchvision 0.14.1
  • pyhon 3.7.5
  • scikit-image 0.19.3
  • nvidia-cublas-cu11 11.10.3.66
  • nvidia-cuda-nvrtc-cu11 11.7.99
  • nvidia-cuda-runtime-cu11 11.7.99
  • nvidia-cudnn-cu11 8.5.0.96

生成标签数据,txt 脚本文件

import os

def gen_txt_file(root_dir, txt_path):
    """制作训练数据集的txt,标签文件
    Args:
        root_dir: 训练集/测试集/验证集 目录
        txt_path: 生成的txt文件路径
    """
    if not os.path.isfile(txt_path):
        return
    for file in os.listdir(root_dir):
        label = 0
        img_file = os.path.join(root_dir, file)
        print(img_file)
        if "dog" in img_file:
            label = 1
        with open(txt_path, "a") as fp:
            fp.write(f"{img_file}_{label}\n")


def main():
    gen_txt_file(root_dir="/media/tx-deepocean/Data/DICOMS/demos/torch_datasets/cats", txt_path="/media/tx-deepocean/Data/DICOMS/demos/Projects/pytorch-tutorial/txd_learn_notes/test.txt")

if __name__ == "__main__":
    main()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

以下是我的生成结果:
txt标签文件

重写torch.utils.data.Dataset 类

为了满足pytorch 模型对数据集的规范,需要按照官方要求制定符合要求的数据集, 分别重写__init__(self), __ getitem__(self, index), __ len__(self)方法:

"""
torchvision 是pytorch 中专门用来处理图像的库,含有四个大类
    torchvision.datasets 加载数据集
    torchvision.models 提供一些已经训练好的模型
    torchvision.transforms 提供图像处理需要的工具, resize, crop, data_augmentation
    torchvision.utils
"""
from loguru import logger
import torchvision
import torch
import os
import skimage.io as io  

class CustomDataset(torch.utils.data.Dataset):    
    def __init__(self, root_dir, names_file, transform=None):
        # 1. Initialize file paths or a list of file names. 

        self.root_dir = root_dir
        self.names_file = names_file
        self.transform = transform
        self.size = 0
        self.names_list = []

        if not os.path.isfile(self.names_file):
            print(f'{self.names_file}  is not exists')
        file = open(self.names_file)
        print(file)
        for f in file:
            self.names_list.append(f)
            self.size += 1  
        print(self.names_list)      


    def __getitem__(self, index):
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        logger.info(f'')
        img_path = os.path.join(self.root_dir, self.names_list[index].split(" ")[0])
        logger.info(f'img_path: {img_path}')
        if not os.path.isfile(img_path):
            logger.warning(f'{img_path} not exists!')
            return None
        image = io.imread(img_path)
        label = int(self.names_list[index].split(" ")[1])
        logger.info(label)
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        return sample        

    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return self.size

# You can then use the prebuilt data loader. 
custom_dataset = CustomDataset(root_dir="", names_file="/media/tx-deepocean/Data/DICOMS/demos/Projects/pytorch-tutorial/txd_learn_notes/test.txt")
# print(custom_dataset.__getitem__(0))
train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
                                           batch_size=64, 
                                           shuffle=True)
print(train_loader)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62

到此,一个Pytorch 的自定义数据集就制作完成了.
一个值得信赖的女人.
在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/567738
推荐阅读
相关标签
  

闽ICP备14008679号