赞
踩
构建数据集需要继承 torch.utils.data.dataset 中的 Dataset 类,并重写 `__init__`、`__getitem__(self, mask)`、`__len__` 三个方法。然后使用 torch.utils.data.DataLoader 来加载你创建的数据集 Dataset。
# Example script: wrap in-memory data in a custom Dataset, batch it with a
# DataLoader, and iterate over the batches next to a small MLP model.
import argparse
import os
import random
import shutil
import time
import warnings

import imageio
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset


class MyDataSet(Dataset):
    """Map-style dataset that indexes a sample container and a label container in parallel."""

    def __init__(self, data, label):
        """Store the samples and their labels.

        Args:
            data: Indexable collection of samples. The dataset length is
                ``data.shape[0]`` when ``data`` has a ``shape`` attribute,
                otherwise ``len(data)``.
            label: Indexable collection of labels, indexed with the same
                index as ``data``.
        """
        self.data = data
        self.label = label
        # Array-likes report the sample count as shape[0]; fall back to
        # len() so that plain sequences (lists, tuples) are accepted too.
        self.length = data.shape[0] if hasattr(data, "shape") else len(data)

    def __getitem__(self, mask):
        """Return ``(label, data)`` for the index ``mask``.

        ``mask`` is the selection index into the dataset. The DataLoader
        generates it, so callers normally never pass it by hand. Note the
        order of the returned pair: label first, then data.
        """
        return self.label[mask], self.data[mask]

    def __len__(self):
        """Return the number of samples recorded at construction time."""
        return self.length


# NOTE: xb and yb (all of the data) must already be defined before this
# point; this snippet does not create them.
train_set = MyDataSet(xb, yb)
# train_set = MyDataSet(data=X_train, label=Y_train)
num_epoch = 100  # number of epochs to train on
batch_size = 1024  # training batch size
train_data = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)


class MLP(nn.Module):
    """Fully connected network with ReLU activations and a sigmoid on the output."""

    def __init__(self, depth=4, mapping_size=2, hidden_size=256):
        """Build the layer stack.

        Args:
            depth: For ``depth >= 2``, the total number of ``nn.Linear``
                layers (one input layer, ``depth - 2`` hidden layers, one
                output layer).
            mapping_size: Number of input features.
            hidden_size: Width of every hidden layer.
        """
        super().__init__()
        layers = [nn.Linear(mapping_size, hidden_size), nn.ReLU(inplace=True)]
        for _ in range(depth - 2):
            layers.append(nn.Linear(hidden_size, hidden_size))
            layers.append(nn.ReLU(inplace=True))
        # The output layer always has 3 units.
        layers.append(nn.Linear(hidden_size, 3))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and apply a sigmoid to the result."""
        return torch.sigmoid(self.layers(x))


model = MLP()
for epoch in range(num_epoch):
    model.train()
    # Despite its name, batchsz is the index of the batch within the epoch.
    # Each batch unpacks as (label, data), the order MyDataSet.__getitem__ returns.
    for batchsz, (label, data) in enumerate(train_data):
        print("第 {} 个Batch size of label {} and size of data{}".format(batchsz, label.shape, data.shape))
from torch.utils.data.dataset import Dataset
class MyDataSet(Dataset):
    """Map-style dataset that indexes a sample container and a label container in parallel."""

    def __init__(self, data, label):
        """Store the samples and their labels.

        Args:
            data: Indexable collection of samples. The dataset length is
                ``data.shape[0]`` when ``data`` has a ``shape`` attribute,
                otherwise ``len(data)``.
            label: Indexable collection of labels, indexed with the same
                index as ``data``.
        """
        self.data = data
        self.label = label
        # Array-likes report the sample count as shape[0]; fall back to
        # len() so that plain sequences (lists, tuples) are accepted too.
        self.length = data.shape[0] if hasattr(data, "shape") else len(data)

    def __getitem__(self, mask):
        """Return ``(label, data)`` for the index ``mask``.

        ``mask`` is the selection index into the dataset. The DataLoader
        generates it, so callers normally never pass it by hand. Note the
        order of the returned pair: label first, then data.
        """
        return self.label[mask], self.data[mask]

    def __len__(self):
        """Return the number of samples recorded at construction time."""
        return self.length
image = np_load_frame(self.videos[video_name]['frame'][frame_name+i], self._resize_height, self._resize_width)
IndexError: list index out of range
先加个捕获异常:
def __getitem__(self, index):
    """Load the run of consecutive frames that starts at sample ``index``.

    The sample entry is split on ``/``: the second-to-last component is
    used as the video name, and the last component, with its extension
    removed, is parsed as an integer that indexes the video's frame list.
    ``self._time_step + self._num_pred`` frames are loaded starting at
    that position, each is passed through ``self.transform``, and the
    results are concatenated along axis 0.

    Args:
        index: Position of the starting sample in ``self.samples``.

    Returns:
        The transformed frames concatenated along axis 0.

    Raises:
        IndexError: If the window runs past the end of the video's frame
            list. The message names the video, the start frame and the
            offending offset.
    """
    sample_path = self.samples[index]
    video_name = sample_path.split('/')[-2]
    frame_name = int(sample_path.split('/')[-1].split('.')[-2])
    frames = self.videos[video_name]['frame']

    batch = []
    for i in range(self._time_step + self._num_pred):
        # Guard only the list lookup. Failing loudly here avoids going on
        # with an unbound image or silently reusing the previous frame.
        try:
            frame_path = frames[frame_name + i]
        except IndexError as err:
            raise IndexError(
                "frame index {} (start {} + offset {}) is out of range for "
                "video {!r}, which has {} frames".format(
                    frame_name + i, frame_name, i, video_name, len(frames)
                )
            ) from err
        image = np_load_frame(frame_path, self._resize_height, self._resize_width)
        # NOTE(review): with transform set to None nothing is appended, so
        # np.concatenate below receives an empty list and raises ValueError.
        if self.transform is not None:
            batch.append(self.transform(image))

    return np.concatenate(batch, axis=0)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。