当前位置:   article > 正文

初识图像数据集_图片数据集

图片数据集

用于深度学习的图像数据集的载入分为两个步骤:

Task 1:将图像数据集载入内存

在 Pytorch 的 torchvision 模块中内置了许多CV领域的 benchmark 数据集,如 Fashion-MINIST, CIFAR-100,Image-Net 等,都可以很方便地直接加载。内置数据集引用格式可在torchvision官方参考文档中查看:
Torchvision Built-In Datasets API – Documentation 0.13

除了内置数据集外,也可以通过一些Python模块载入不同类型的其它数据集,最后转换成Tensor类型数据。

Task 2:构建数据迭代器用于分批次读取深度学习的训练数据

在使用梯度下降进行深度学习的训练优化过程中,每一次更新参数之前都需要遍历整个数据集计算梯度,这导致训练速度非常缓慢。

对此我们通常采用小批量随机梯度下降法,将训练集拆分为数个批次 (batch)并构成一个序列,通过for…in的方式每次读取一个batch进行梯度计算和权重更新。这种能够通过for…in语句读取的序列叫做可迭代对象

可迭代的batch序列可以通过保留字yield 生成,示例如下:

# 从整个样本数据集中读取小批量
# 定义采样函数
def data_iter(batch_size,features,labels):
    #特征矩阵的长度即样本数量
    num_examples = len(features)
    #建立整个样本集的索引序号,生成列表格式数据
    indices = list(range(num_examples))
    #将索引序号随机排列
    random.shuffle(indices)
    
    #生成采样序列
    for i in range(0,num_examples,batch_size):
        batch_indices = indices[i:min(i+batch_size,num_examples)]
        # 利用采样索引读取一批特征和标签数据
        yield features[batch_indices],labels[batch_indices]
batch_size = 10
for X,y in data_iter(batch_size,features,labels):
    print(X,'\n',y)
    break
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

但一般来说,我们通过torch.utils.data模块的Dataloader类可以更方便地执行数据集的拆分和可迭代对象生成,详见后文。

1. 载入torchvision内置的Fashion-MINIST数据集,观察数据结构

数据载入
from torchvision import transforms
# 从torchvision的transform 模块中引用ToTensor类
trans = transforms.ToTensor()

'''从torchvision内置的数据库中读取FashionMINIST数据集'''
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
transform 参数设置。

我们在载入数据前通过引用torchvision.transforms模块,定义数据的变换方式。

  • Pytorch提供的图像变换包括对PIL格式图片,Tensor格式图片进行灰度、分辨率、中心对象等参数转换。同时也提供PIL格式与Tensor格式的相互转换。
  • Tensor格式指的是一张图片通过张量(C,H,W)表示,C代表色彩通道,H为图像高度,W表示图像宽度。对于一批(batch)图片,其描述张量为(B,C,H,W)其中 B 表示该批次的图片数量。张量图像数值值的期望范围由张量中的数字类型隐式定义。具有浮点类型的张量图像数值取值范围为 [ 0 , 1 ) [0,1) [0,1)。具有整数类型的张量图像数值取值范围为 [ 0 , m a x ( D t y p e ) ] [0,max(Dtype)] [0,max(Dtype)],其中 m a x ( D t y p e ) max(Dtype) max(Dtype) 是数字类型表示的最大值。
  • PIL图片指的是Python的Pillow库所支持的图片格式,包括L(灰度8位像素), RGB(真彩 3 × 8 3\times8 3×8位像素), YCbCr(色彩视频格式 3 × 8 3\times8 3×8位像素), RGBA(真彩+透明通道 4 × 8 4\times8 4×8位像素), CMYK(印刷四色模式或彩色印刷模式 4 × 8 4\times8 4×8位像素)等多种类型图像。
  • 多种图像转换功能可通过transform.compose() 串联起来。
 trans = transforms.Compose([transforms.CenterCrop(10),
 					transforms.PILToTensor(),
 					transforms.ConvertImageDtype(torch.float),])
  • 1
  • 2
  • 3

在这里,我们使用 transforms.ToTensor(),将图像数据从PIL类型变换成torch张量(C,H,W)32位浮点数格式,并除以255使得所有像素的数值均在0到1之间。

数据解析

可以看到,Fashion-MNIST由10个类别的图像组成,每个类别由训练数据集(train dataset)中的6000张图像 和测试数据集(test dataset)中的1000张图像组成。
因此,训练集和测试集分别包含60000和10000张图像。‘’’

[In]: len(mnist_train), len(mnist_test)
[Out]: (60000, 10000)
  • 1
  • 2

数据集由灰度图像组成,故其通道数为1。每张图像的高度和宽度均为28像素。

[In]: mnist_train[0][0].shape
[Out]: torch.Size([1, 28, 28])
  • 1
  • 2

转换后的数据类型为元组,由每个图像的单通道(灰度)高-宽矩阵(C x H x W = 1 x 28 x 28 )和类别标签两个元素构成。

# 预览第一张图片的张量
mnist_test[0]
  • 1
  • 2

一个batch的张量图像数据结构如下图表示:
在这里插入图片描述

张量中代表类别的标签由数字 0 − 9 0-9 09 表示,我们可以将图像集的数字标签映射到文本标签上.

[In]: def get_fashion_mnist_labels(labels):
		text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
      return [text_labels[int(i)] for i in labels]
    
[In]: get_fashion_mnist_labels(range(5))
[Out]: ['t-shirt', 'trouser', 'pullover', 'dress', 'coat']
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
使用matplotlib.pyplot模块预览张量图像数据
import matplotlib.pyplot as plt
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):  #@save
    """绘制图像列表"""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # 图片张量
            ax.imshow(img.numpy())
        else:
            # PIL图片
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes
    
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

Fashion-MNIST图片预览

3. 不同格式的外部数据集载入

利用Pandas 可以将csv,excel,sql等格式文件载入为DataFrame格式,再通过transform方法转换为Tensor类型。

另外,我们在下载网络数据集时可能还常会遇到通过序列化方式储存的二进制文件。
Pandas 读取序列化文件的方式为:

import pandas as pd
train_data = pd.read_pickle(<数据文件路径>, compression='infer', storage_options=None)
  • 1
  • 2
  • Pickle模块是Python执行序列化和反序列化的标准库.

Numpy数组,Torch张量等 Python对象都可以通过pickle.dump()函数进行序列化(或串行化),永久储存为二进制文件。

使用时则可以通过pickle.load() 函数进行反序列化,将二进制文件恢复为 Python对象。可以定义如下函数读取:

def unpickle(file):
    import pickle
    with open(file, 'rb') as <定义文件变量名>:
        img_data = pickle.load(<文件变量名>, encoding='bytes')
    return img_data

train_data = unpickle(<数据文件路径>)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

2. 使用DataLoader分批次读取

在将数据载入内存之后,我们通过torch.utils.data模块中的内置数据迭代器来分批次读取数据,以前文载入的Fashion-MNIST数据为例。

batch_size = 256
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size, num_workers=4, shuffle=True, drop_last=True)
val_loader = torch.utils.data.DataLoader(mnist_val, batch_size, num_workers=4, shuffle=False) 
  • 1
  • 2
  • 3

通过batch_size指定每个batch中图片的数量。num_workers指定读取数据的进程数,即batch的个数。shuffle = True 指明在分割批次之前将数据的顺序打乱。

4. 以函数的方式定义数据加载器

同样以Fashion-MNIST数据集为例,分别通过函数与类的方式定义一个数据读取器。
以函数的方式定义,将torchvision的内置数据集加载到内存中并返回数据迭代器。

def load_data_fashion_mnist(batch_size, resize=None):  
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))
train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
    print(X.shape, X.dtype, y.shape, y.dtype)
    break                           
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

5. 以类的方式自定义数据加载器

通过对utils.Dataset类继承的方式定义一个数据加载器。通过定义__len____getitem__方法,返回可传入DataLoader的训练与测试数据。

class FMDataset(Dataset):
    def __init__(self, df, transform=None):
        '''数据为通过pandas读取csv的DataFrame数据'''
        self.df = df
        '''自定义图像格式转换属性'''
        self.transform = transform
        '''DataFrame每一行代表一张图片,
        第一列为类别标签,
        第二列开始为每个像素的灰度值,数据类型转换为8位无符号整型numpy数组'''
        self.images = df.iloc[:,1:].values.astype(np.uint8)
        self.labels = df.iloc[:, 0].values
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        '''将每张图片灰度值的一维数组转换为1*28*28的三维数组'''
        image = self.images[idx].reshape(1,28,28)
        
        label = int(self.labels[idx])
        '''有transform属性的参数传入,调用.transform()方法'''
        if self.transform is not None:
            image = self.transform(image)
       '''无transform属性的参数传入,
       通过将灰度值除以255标准化为[0,1)浮点数类型张量'''     
        else:
            image = torch.tensor(image/255., dtype=torch.float)
        '''将类别标签也转换为张量类型'''
        label = torch.tensor(label, dtype=torch.long)
        return image, label

train_df = pd.read_csv("./FashionMNIST/fashion-mnist_train.csv")
test_df = pd.read_csv("./FashionMNIST/fashion-mnist_test.csv")
train_data = FMDataset(train_df, data_transform)
test_data = FMDataset(test_df, data_transform)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/260009
推荐阅读
相关标签
  

闽ICP备14008679号