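Before reading, make sure a suitable version of PyTorch is installed.
For general learning purposes, the CPU version is sufficient (see the reference tutorial).
Import the PyTorch package with the following statement: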
import torch  # Note: although the project is called PyTorch, the package is imported as torch
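A quick way to confirm the installation works is to print the version and run a trivial tensor operation. This is a minimal sketch; on a CPU-only build `torch.cuda.is_available()` simply returns `False`.

```python
import torch

# Print the installed version and whether a usable CUDA GPU is present
# (a CPU-only build prints False here)
print(torch.__version__)
print(torch.cuda.is_available())

# Sanity check that basic tensor operations work
x = torch.ones(2, 3)
print(x.sum())  # tensor(6.)
```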
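Loading data for a neural network mainly involves two classes, Dataset and DataLoader:
Dataset fetches each data sample together with its ground-truth label.
DataLoader wraps a Dataset and delivers its samples to the network in the form the training loop needs, for example as (optionally shuffled) mini-batches.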
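To make the division of labor concrete, here is a small sketch using the built-in `TensorDataset` as the Dataset (the toy features and labels are made up for illustration). The Dataset answers "what is sample i?", while the DataLoader groups samples into batches.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data: 10 samples with 3 features each, plus an integer label per sample
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10) % 2

dataset = TensorDataset(features, labels)  # dataset[i] -> (features[i], labels[i])
print(dataset[0])  # (tensor([0., 1., 2.]), tensor(0))

# DataLoader groups samples into mini-batches and reshuffles them every epoch
loader = DataLoader(dataset, batch_size=4, shuffle=True)
sizes = []
for batch_x, batch_y in loader:
    print(batch_x.shape, batch_y.shape)
    sizes.append(batch_x.shape[0])
# 10 samples with batch_size=4 -> batches of 4, 4 and 2
```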
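The torch.utils.data package provides the Dataset class. Its description can be found on the official PyTorch website; the source code is as follows: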
class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`. Subclasses could also
    optionally implement :meth:`__getitems__`, for speedup batched samples
    loading. This method accepts list of indices of samples of batch and returns
    list of samples.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError("Subclasses of Dataset should implement __getitem__.")

    # def __getitems__(self, indices: List) -> List[T_co]:
    # Not implemented to prevent false-positives in fetcher check in
    # torch.utils.data._utils.fetch._MapDatasetFetcher

    def __add__(self, other: "Dataset[T_co]") -> "ConcatDataset[T_co]":
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py
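The `__add__` method in the source means two datasets can be joined with `+`, which yields a `ConcatDataset`. A small sketch, assuming two toy `TensorDataset`s standing in for, say, the "ants" and "bees" folders:

```python
import torch
from torch.utils.data import ConcatDataset, TensorDataset

# Two small map-style datasets with labels 0 and 1 respectively
ds_a = TensorDataset(torch.zeros(3, 2), torch.zeros(3, dtype=torch.long))
ds_b = TensorDataset(torch.ones(5, 2), torch.ones(5, dtype=torch.long))

# Dataset.__add__ returns ConcatDataset([self, other])
combined = ds_a + ds_b
print(type(combined).__name__)  # ConcatDataset
print(len(combined))            # 8

# Indices 0-2 come from ds_a, indices 3-7 from ds_b
print(combined[2][1].item(), combined[3][1].item())  # 0 1
```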
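From the description above, Dataset is an abstract class that represents a dataset. To define your own dataset, subclass it and implement the following methods (strictly speaking only __getitem__ is required, but __len__ is expected by many samplers and by DataLoader's default options):
__len__(self): returns the size of the dataset, i.e. the number of samples it contains.
__getitem__(self, idx): returns the sample at index idx, typically together with its label.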
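A minimal subclass implementing both methods might look like this. `SquaresDataset` is a hypothetical toy example: sample i is the number i and its label is i * i.

```python
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the number i, its label is i * i."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # number of samples in the dataset
        return self.n

    def __getitem__(self, idx):
        # return one (sample, label) pair; raising IndexError past the end also
        # lets plain Python iteration (e.g. list(ds)) stop correctly
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return idx, idx * idx


ds = SquaresDataset(5)
print(len(ds))   # 5
print(ds[3])     # (3, 9)
print(list(ds))  # [(0, 0), (1, 1), (2, 4), (3, 9), (4, 16)]
```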
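Example: use Dataset to read all images in the folder E:\Python_learning\Deep_learning\dataset\hymenoptera_data\train\ants and obtain their labels. In this dataset, each folder's name is the label, and the folder contains the corresponding training images.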
import os
from torch.utils.data import Dataset
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms


class MyDataset(Dataset):
    def __init__(self, root_path, label):
        self.root_path = root_path
        self.label = label
        self.img_path = os.path.join(root_path, label)  # join the root dir and the label subfolder
        print(f"Image directory: {self.img_path}")  # print the path for debugging
        self.img_path_list = []  # stays empty if the directory cannot be listed
        try:
            self.img_path_list = os.listdir(self.img_path)  # list the files in the folder
            print(f"Image list: {self.img_path_list}")  # print the file list for debugging
        except PermissionError as e:
            print(f"Permission error: {e}")
        except FileNotFoundError as e:
            print(f"File not found: {e}")

    def __getitem__(self, index):
        img_name = self.img_path_list[index]
        img_path = os.path.join(self.img_path, img_name)
        try:
            img = Image.open(img_path)
        except Exception as e:
            print(f"Error while reading image: {e}, image path: {img_path}")
            raise
        label = self.label  # the folder name is the label
        return img, label

    def __len__(self):
        return len(self.img_path_list)


# Instantiate the dataset
my_data = MyDataset(root_path=r'E:\Python_learning\Deep_learning\dataset\hymenoptera_data\train', label='ants')

writer = SummaryWriter('logs')
to_tensor = transforms.ToTensor()
for i in range(len(my_data)):
    img, label = my_data[i]  # fetch the images one by one
    # img is a PIL Image here; convert it to a tensor with transforms.ToTensor
    writer.add_image(tag=label, img_tensor=to_tensor(img), global_step=i)
writer.close()
print(f"All {len(my_data)} images in the current folder have been read; view them in TensorBoard")
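Then run the following command in a terminal. SummaryWriter('logs') uses a relative path, so the logs folder is created in the directory the script was run from (here E:\Python_learning\Deep_learning\note), and --logdir must point at that folder:
tensorboard --logdir="E:\Python_learning\Deep_learning\note\logs"
Open the address TensorBoard prints (http://localhost:6006/ by default) to view the images.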
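If TensorBoard shows nothing, the --logdir usually does not match where the writer actually wrote. A small sketch that prints the absolute path a relative 'logs' resolves to and the matching command; the demo writes to a throwaway temp directory (a hypothetical location) rather than the real project folder.

```python
import os
import tempfile

from torch.utils.tensorboard import SummaryWriter

# SummaryWriter('logs') resolves 'logs' against the current working directory
print(os.path.abspath("logs"))

# For this demo, write to a throwaway directory instead
log_dir = os.path.join(tempfile.mkdtemp(), "logs")
writer = SummaryWriter(log_dir)
writer.add_scalar("demo/value", 0.5, global_step=0)
writer.close()

# The command that would launch TensorBoard on exactly this run
print(f'tensorboard --logdir="{os.path.abspath(writer.get_logdir())}"')

# SummaryWriter stores its data in event files inside log_dir
event_files = [f for f in os.listdir(log_dir) if f.startswith("events.out.tfevents")]
print(len(event_files))
```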