当前位置:   article > 正文

动手学深度学习——图像分类数据集(代码详解)

动手学深度学习——图像分类数据集(代码详解)

1. 图像分类数据集

这里采用Fashion-MNIST数据集

  • torchvision:torch类型的可视化包,一般计算机视觉和数据可视化需要使用
  • from torchvision import transforms:该组件经常用于图片的修改(一般数据集中的图片都是PIL格式,使用的时候需要转化为tenser,而在加入函数时常需要转化为nadarry(numpy中的ndarray为多维数组))
  • d2l.use_svg_display():使用什么模式展示图片
%matplotlib inline
import torch
import torchvision #pytorch用于计算机视觉的一个库
from torch.utils import data
from torchvision import transforms #导入对数据操作的模具
from d2l import torch as d2l

d2l.use_svg_display() #使用svg展示图片
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

1.1 读取数据集

通过框架中的内置函数将Fashion-MNIST数据集下载并读取到内存中

  • torchvision.datasets:一般用于图像数据集的下载和获取
    eg:
  • torchvision.datasets.FashionMNIST( root=, train=True, transform=, download=True)
    • train:是否为训练集
    • transform:使用什么格式转换(可以从transforms组件中选择)
    • dowload:是否下载对应数据集
    • .FashionMNIST可以更换为其他数据源
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor() #对图片进行预处理,转换为tensor格式

# 下载训练集和测试集,并保存
mnist_train = torchvision.datasets.FashionMNIST(
	root="../data", train=True, transform=trans,download=True)
mnist_train = torchvision.datasets.FashionMNIST(
	root="../data", train=False, transform=trans,download=True)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

Fashion-MNIST由10个类别的图像组成, 每个类别由训练数据集(train dataset)中的6000张图像 和测试数据集(test dataset)中的1000张图像组成。 因此,训练集和测试集分别包含60000和10000张图像。 测试数据集不会用于训练,只用于评估模型性能。

# 输出训练集和测试集的大小
len(mnist_train), len(mnist_test)
  • 1
  • 2

在这里插入图片描述
每个输入图像的高度和宽度均为28像素。 数据集由灰度图像组成,其通道数为1(彩色图像通道数为3)。

# 索引到第一张图片
mnist_train[0][0].shape # 输入图像的通道数、高度和宽度
  • 1
  • 2

在这里插入图片描述
Fashion-MNIST中包含的10个类别,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。以下函数用于在数字标签索引及其文本名称之间进行转换。

# 获取数据集的标签
def get_fashion_mnist_labels(labels): #@save
	"""返回Fashion-MNIST数据集的文本标签"""
	text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_lables[int(i)] for i in labels]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

创建一个函数来可视化这些样本。

  • plt.subplots()是一个返回包含图形和轴对象的元组的函数。因此,在使用时fig, ax = plt.subplots(),将此元组解压缩到变量fig和ax。
  • enumerate()函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中,生成可以遍历的每个元素有对应序号(0, 1, 2, 3…)的enumerate对象。
  • zip()函数用于将多个可迭代对象作为参数,依次将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的对象,里面的每个元素大概为i,(ax,img)的形式。
  • imshow()可以接收二维,三维甚至多维数组。二维默认为一通道即灰度图像,三维需要在第三个维度指定图像通道数(必须是第三维)
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save
	"""绘制图像列表"""
	figsize = (num_cols * scale, num_rows * scale)
	
	# 第1个参数是个图,一般不用;第2个axer类似于图片的索引矩阵(行,列)
	_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize) # axes:轴
	axes = axes.flatten()

	# 遍历生成形如i, (ax, img)形式的enumerate对象
	for i, (ax, img) in enumerate(zip(axes, imgs)):
		if torch.is_tensor(img):
			# 图片张量
			ax.imshow(img.numpy())
			
		else:
			# PIL图片
			ax.imshow(img)
		ax.axes.get_xaxis().set_visible(False) #x轴隐藏
		ax.axes.get_yaxis().set_visible(False) #y轴隐藏
		if titles:
			ax.set_title(title[i]) #显示标题
	return axes
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

以下是训练数据集中前几个样本的图像及其相应的标签。

  • next() 返回迭代器的下一个项目。
  • next() 函数要和生成迭代器的iter() 函数一起使用。
  • 我们可以通过iter()函数获取这些可迭代对象的迭代器。然后,我们可以对获取到的迭代器不断使⽤next()函数来获取下⼀条数据。
    注:当我们已经迭代完最后⼀个数据之后,再次调⽤next()函数会抛出 StopIteration的异常 ,来告诉我们所有数据都已迭代完成,不⽤再执⾏ next()函数了。
# 使用next()函数获取批量大小为18的训练集的图像和标签
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))

#显示18张图片,宽度为28,长度为28,总共为2行9列
# 绘制两行图片,每一行有9张图片,并获取标签
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y)); 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

在这里插入图片描述

1.2 读取小批量

为了使我们在读取训练集和测试集时更容易,我们使用内置的数据迭代器,而不是从零开始创建。 回顾一下,在每次迭代中,数据加载器每次都会读取一小批量数据,大小为batch_size。 通过内置数据迭代器,我们可以随机打乱了所有样本,从而无偏见地读取小批量。

batch_size = 256

def get_dataloader_workers(): #@save
	"""使用4个进程来读取数据"""
	return 4

# 训练集需要设置shuffle=True打乱顺序	
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
							 num_workers=get_dataloader_workers())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

我们看一下读取训练数据所需的时间。

timer = d2l.Timer() #调用Timer函数,测试速度
for X, y in train_iter:
	continue
f'{timer.stop():.2f} sec' #输出读取数据所用的秒数,精度为2位小数
  • 1
  • 2
  • 3
  • 4

在这里插入图片描述

1.3 整合所有组件

定义load_data_fashion_mnist函数,用于获取和读取Fashion-MNIST数据集。这个函数返回训练集和验证集的数据迭代器。 此外,这个函数还接受一个可选参数resize,用来将图像大小调整为另一种形状。

  • torchvision.transforms是pytorch中的图像预处理包,一般用Compose把多个步骤整合到一起。
  • insert函数是一种用于列表的内置函数。这个函数的作用是在一个列表中的指定位置,插入一个元素。
transforms中的函数功能
Resize把给定的图片resize到given size
Normalize用均值和标准差归一化张量图像
def load_data_fashion_mnist(batch_size, resize=None):  #@save
	"""下载Fashion-MNIST数据集,然后将其加载到内存中"""
	# 转换为tensor
	trans = [transforms.ToTensor()]

	
	if resize:
		trans.insert(0, transforms.Resize(resize))
	# compose整合步骤
	trans = transforms.Compose(trans)

	# 下载训练集和测试集,将小批量样本返回到train_iter中,用于之后的训练
	mnist_train = torchvision.datasets.FashionMNIST(
		root="../data", train=True, transform=trans, download=True)
	mnist_test = torchvision.datasets.FashionMNIST(
		root="../data", train=False, transform=trans, download=True)
	return (data.DataLoader(mnist_train, batch_size, shuffle=True,
							num_workers=get_dataloader_workers()),
			data.DataLoader(mnist_test, batch_size, shuffle=False,
							num_workers=get_dataloader_workers()))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

下面,我们通过指定resize参数来测试load_data_fashion_mnist函数的图像大小调整功能。

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
	print(X.shape, X.dtype, y.shape, y.dtype)
	break
  • 1
  • 2
  • 3
  • 4

在这里插入图片描述

1.4 小结

  • Fashion-MNIST是一个服装分类数据集,由10个类别的图像组成。我们将在后续章节中使用此数据集来评估各种分类算法。
  • 我们将高度h像素,宽度w像素图像的形状记为h×w或(h,w)。
  • 数据迭代器是获得更高性能的关键组件。依靠实现良好的数据迭代器,利用高性能计算来避免减慢训练过程。
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/373313?site
推荐阅读
相关标签
  

闽ICP备14008679号