赞
踩
当我们有自己爬取的图片时,我们想要通过这些图片来训练自己的网络,该怎么办呢?(看不懂加我QQ探讨3014457121)
那就是继承Dataset类
想要实现这个类,必须要重写3个方法:init(self, 参数…)、 getitem(self, index)、len(self)
1、首先我们来介绍__init__(self, 参数…)
这个函数你可以传多个参数进来,以用来初始化你自己类中的各个变量。例如我传的是图片的母文件夹路径root
def __init__(self, root):
# 所有图片的绝对路径,root只定位到文件夹一级,os.listdir()只返回该文件下的文件名
imgs = os.listdir(root) # 这边先传入全部的数据
self.imgs = [os.path.join(root, k) for k in imgs]
self.transforms = None
self.transforms = transform
imgs:拿到了所有图片的名字并形成一个list类型的列表。
然后类中的变量self.imgs就获得了所有图片的路径并形成一个列表。
下面我们来说说transforms变量是啥。
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x), # 把通道数扩大3倍
transforms.Resize(600), # 让最短边为600
transforms.CenterCrop(600), # 从中间切出600*600的图片
transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])
transforms.ToTensor():让图片变成pytorch框架所需要的张量形式(CHW维度)
transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x):我爬取的图片里有些是1通道的,即C=1,但是pytorch必须要保持数据集的维度格式一模一样,所以我要把那些通道数是1的图片强行转为通道数为3
接下来2行看注释
transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]):对得到的张量里的数字进行概率分布,用为这对训练友好。
2、再来看看__getitem__(self, index)方法:
主要用来根据传入的索引来返回某个数据,像这样用的:
train_dataset = FlameSet('./images')
for i in range(73):
print(train_dataset[i].shape)
下面是函数体部分:
def __getitem__(self, index): # 这个函数利用索引读出单独的数据
img_path = self.imgs[index]
pil_img = Image.open(img_path)
if self.transforms: # 有transform(最好用这个)
data = self.transforms(pil_img) # C, H, W
else: # 无transform(尽量别用)
pil_img = np.asarray(pil_img)
data = torch.from_numpy(pil_img) # H,W,C
data = data.permute(2, 0, 1) # ==> C,H,W 维度转变了
return data
这个函数一般只需要传入一个index索引变量
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。