当前位置:   article > 正文

python DataSet+ Dataloader 深度学习编程细节_数据集 pytorchDataset的构建与使用_nn.dataset

nn.dataset
  • 深度学习中许多网络的设计都需数据集的预处理功能辅助,本文对DataSet + Dataloader 的使用做介绍。

DataSet构建(简单示例)

        构建数据集需要继承torch.utils.data.dataset的Dataset类重写init,getitem(self, mask),len三个方法。然后使用torch.utils.data import DataLoader来加载你创建的数据集Dataset。

import argparse
import os
import random
import shutil
import time
import warnings
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

import numpy as np
import os, imageio


from torch.utils.data.dataset import Dataset
class MyDataSet(Dataset):
    def __init__(self, data, label):#传入参数是我们的数据集(data)和标签集(label)
        self.data = data
        self.label = label
        self.length = data.shape[0]

    def __getitem__(self, mask):# 获取返回数据的方法,传入参数是一个index,也被叫做mask,就是我们对数据集的选择索引。在调用DataLoader时就会自己生成index,所以我们只需要写好方法即可。
        label = self.label[mask]
        data = self.data[mask]
        return label, data

    def __len__(self):
        # print(self.length)
        return self.length



train_set = MyDataSet(xb,yb)# xb,yb为所有的数据
# train_set = MyDataSet(data=X_train, label=Y_train)
num_epoch = 100     # number of epochs to train on
batch_size = 1024  # training batch size
train_data = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)

class MLP(nn.Module):
    def __init__(self,depth=4,mapping_size=2,hidden_size=256):
        super().__init__()
        layers = []
        layers.append(nn.Linear(mapping_size,hidden_size))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(depth-2):
            layers.append(nn.Linear(hidden_size,hidden_size))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(hidden_size,3))
        self.layers = nn.Sequential(*layers)
    def forward(self,x):
        return torch.sigmoid(self.layers(x))
model = MLP()
for epoch in range(num_epoch ):
    model.train()
    for batchsz, (label, data) in enumerate(train_data):
        # i表示第几个batch, data表示该batch对应的数据,包含data和对应的labels
        print("第 {} 个Batch size of label {} and size of data{}".format(batchsz, label.shape, data.shape))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64

图像的分割处理数据集的构建

添加链接描述
添加链接描述

构建自监督任务的数据集(用一个数据集构建正负样本)

from torch.utils.data.dataset import Dataset
class MyDataSet(Dataset):
    def __init__(self, data, label):#传入参数是我们的数据集(data)和标签集(label)
        self.data = data
        self.label = label
        self.length = data.shape[0]

    def __getitem__(self, mask):# 获取返回数据的方法,传入参数是一个index,也被叫做mask,就是我们对数据集的选择索引。在调用DataLoader时就会自己生成index,所以我们只需要写好方法即可。
        label = self.label[mask]
        data = self.data[mask]
        return label, data

    def __len__(self):
        # print(self.length)
        return self.length
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

C&G

后续(+捕获异常)

image = np_load_frame(self.videos[video_name]['frame'][frame_name+i], self._resize_height, self._resize_width)
IndexError: list index out of range
  • 1
  • 2

先加个捕获异常:

    def __getitem__(self, index):
        video_name = self.samples[index].split('/')[-2]
        frame_name = int(self.samples[index].split('/')[-1].split('.')[-2])

        batch = []
        for i in range(self._time_step+self._num_pred):
            try:
                image = np_load_frame(self.videos[video_name]['frame'][frame_name+i], self._resize_height, self._resize_width)
            except :
                print('error from --- model utils')
                print(frame_name)
                print(i)
            if self.transform is not None:
                batch.append(self.transform(image))

        return np.concatenate(batch, axis=0)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/455090
推荐阅读
相关标签
  

闽ICP备14008679号