赞
踩
- import torch
- import torch.utils.data as Data
- import numpy as np
-
- data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
- label = np.asarray([[0], [1], [0], [2]])
-
- #创建子类
- class subDataset(Data.Dataset):
- #初始化,定义数据内容和标签
- def __init__(self, Data, Label):
- super(subDataset, self).__init__()
- self.Data = Data
- self.Label = Label
- #返回数据集大小
- def __len__(self):
- return len(self.Data)
- #得到数据内容和标签
- def __getitem__(self, index):
- data = torch.Tensor(self.Data[index])
- label = torch.Tensor(self.Label[index])
- return data, label
-
- if __name__ == '__main__':
- dataset = subDataset(data, label)
- print(dataset)
- print('dataset大小为 :', dataset.__len__())
- print(dataset.__getitem__(0))
- print(dataset[0])
- print()
-
-
- print('----------创建DataLoader迭代器-----------')
- dataloader = Data.DataLoader(dataset, batch_size= 2, shuffle = False, num_workers= 4)
- for i, item in enumerate(dataloader):
- print('batch_i:', i)
- data, label = item
- print('data:', data)
- print('label:', label)
以上代码如果在jupyter notebook中运行会报错,因为一般 jupyter notebook 是单核运行的。
解决方案:
使用正常py文后缀的文件类型运行代码,既可多线程运行
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。