当前位置:   article > 正文

PyTorch:数据读取1 - Datasets和TensorDataset

tensordataset

-柚子皮-

Datasets

在输入流水线中,准备数据的代码是这么写的

data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)

datasets.CIFAR10就是一个Datasets子类,data是这个类的一个实例。

为什么要定义Datasets?

PyTorch提供了一个工具函数torch.utils.data.DataLoader。通过这个类,我们可以让数据变成mini-batch,且在准备mini-batch的时候可以多线程并行处理,这样可以加快准备数据的速度。

Datasets就是构建这个类的实例的参数之一。

DataLoader的使用参考[PyTorch:数据读取2 - Dataloader]。

数据集划分

1 建议使用sklearn.preprocessing.model_selection

ds_train, ds_eval = model_selection.train_test_split(dataset, test_size=0.2, shuffle=args.if_shuffle_data)

2 train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

Note: dataloader应该是不能进行划分的。

[Pytorch划分数据集的方法]

自定义Datasets

框架

dataset必须继承自torch.utils.data.Dataset。

内部要实现两个函数:

        一个是__lent__用来获取整个数据集的大小;

        一个是__getitem__用来从数据集中得到一个数据片段item

import torch.utils.data as data
class CustomDataset(data.Dataset):
    """Custom data.Dataset compatible with data.DataLoader."""

    def __init__(self, filename, data_info, oth_params):
        """Reads source and target sequences from txt files."""
        # # # 从文件中读取数据
        self.file = open(filename, 'r')
        ...
        # # # 或者从外部数据结构data_info中读取数据
        self.all_texts = data_info['all_texts']
        self.all_labels = data_info['all_labels']

        # # # 构建字典,映射token和id
        self.vocab = data_info['vocab']

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        # # # 从文件读取
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform或者word2id什么的).
        # 3. Return a data pair(source and target) (e.g. image and label).

        # # # 或者直接读取
        item_info = {
            "text": self.all_texts[index],
            "label": self.all_labels[index]
        }
        return item_info

    def __len__(self):
        # return the total size of your dataset.
        return len(self.all_texts)

示例

从文件中读取数据写入Dataset

class Dataset(torch.utils.data.Dataset):
    def __init__(self, filepath=None,dataLen=None):
        self.file = filepath
        self.dataLen = dataLen
        
    def __getitem__(self, index):
        A,B,path,hop= linecache.getline(self.file, index+1).split('\t')
        return A,B,path.split(' '),int(hop)

    def __len__(self):
        return self.dataLen

随机mock一个分类数据

class Dataset(data.Dataset):
    """Custom data.Dataset compatible with data.DataLoader."""

    def __init__(self, df, lang: Lang):
        inputs_dim = vars(Config)['inputs_dim']
        self.x = torch.randint(0, 5, (5, inputs_dim), dtype=torch.float)

        self.label = torch.tensor([0, 0, 1, 1, 0, 1, 0, 1, 0, 1], dtype=torch.float)

        self.src_word2id = lang.word2id
        self.trg_word2id = lang.word2id
        # self.mem_word2id = mem_word2id

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        x = self.x[index]
        label = self.label[index]

        item_info = {
            "x": x,
            "label": label
        }
        return item_info

官方MNIST的例子

(代码被缩减,只留下了重要的部分):

  1. class MNIST(data.Dataset):
  2. def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
  3. self.root = root
  4. self.transform = transform
  5. self.target_transform = target_transform
  6. self.train = train # training set or test set
  7. if download:
  8. self.download()
  9. if not self._check_exists():
  10. raise RuntimeError('Dataset not found.' +
  11. ' You can use download=True to download it')
  12. if self.train:
  13. self.train_data, self.train_labels = torch.load(
  14. os.path.join(root, self.processed_folder, self.training_file))
  15. else:
  16. self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))
  17. def __getitem__(self, index):
  18. if self.train:
  19. img, target = self.train_data[index], self.train_labels[index]
  20. else:
  21. img, target = self.test_data[index], self.test_labels[index]
  22. # doing this so that it is consistent with all other datasets
  23. # to return a PIL Image
  24. img = Image.fromarray(img.numpy(), mode='L')
  25. if self.transform is not None:
  26. img = self.transform(img)
  27. if self.target_transform is not None:
  28. target = self.target_transform(target)
  29. return img, target
  30. def __len__(self):
  31. if self.train:
  32. return 60000
  33. else:
  34. return 10000

-柚子皮-

TensorDataset

TensorDataset本质上与python zip方法类似,对数据进行打包整合。
官方文档[torch.utils.data — PyTorch 2.0 documentation]

源码说明:r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args: *tensors (Tensor): tensors that have the same size of the first dimension.
    """

该类通过每一个 tensor 的第一个维度进行索引。因此,该类中的 tensor 第一维度必须相等。

  1. import torch
  2. from torch.utils.data import TensorDataset
  3. # a的形状为[4, 3], b的形状为[4], b的第一维与a相同
  4. a = torch.tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]])
  5. b = torch.tensor([1, 2, 3, 4])
  6. train_data = TensorDataset(a, b)
  7. print(train_data[0])
  8. # (tensor([1, 1, 1]), tensor(1))
  9. print(train_data[0:2])
  10. # (tensor([[1, 1, 1],
  11. # [2, 2, 2]]), tensor([1, 2]))

取数据的时候,如上就是取每个tensor的下标对应数据后再组合成类似tuple的对象。 

from: -柚子皮-

ref: [pytorch学习笔记(六):自定义Datasets]

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/455112
推荐阅读
相关标签
  

闽ICP备14008679号