当前位置:   article > 正文

深度学习实战模拟——softmax回归(图像识别并分类)_softmax 回归:图片分类

softmax 回归:图片分类

目录

1、数据集:

2、完整代码


1、数据集:

1.1 Fashion-MNIST是一个服装分类数据集,由10个类别的图像组成,分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。

1.2 Fashion-MNIST由10个类别的图像组成,每个类别由训练数据集(train dataset)中的6000张图像和测试数据集(test dataset)中的1000张图像组成。因此,训练集和测试集分别包含60000和10000张图像。测试数据集不会用于训练,只用于评估模型性能。

以下代码用于下载Fashion-MNIST数据集,并创建训练集和测试集对象。

  1. # 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
  2. # 并除以255使得所有像素的数值均在0~1之间
  3. trans = transforms.ToTensor()
  4. mnist_train = torchvision.datasets.FashionMNIST(
  5. root="../data", train=True, transform=trans, download=True)
  6. mnist_test = torchvision.datasets.FashionMNIST(
  7. root="../data", train=False, transform=trans, download=True)

以下函数用于在数字标签索引及其文本名称之间进行转换。

  1. def get_fashion_mnist_labels(labels): #@save
  2. """返回Fashion-MNIST数据集的文本标签"""
  3. text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
  4. 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
  5. return [text_labels[int(i)] for i in labels]

2、完整代码

  1. import torch
  2. import torchvision
  3. import pylab
  4. from torch.utils import data
  5. from torchvision import transforms
  6. import matplotlib.pyplot as plt
  7. from d2l import torch as d2l
  8. import time
  9. batch_size = 256
  10. num_inputs = 784
  11. num_outputs = 10
  12. W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
  13. b = torch.zeros(num_outputs, requires_grad=True)
  14. num_epochs = 5
  15. class Accumulator:
  16. """在n个变量上累加"""
  17. def __init__(self, n):
  18. self.data = [0.0] * n
  19. def add(self, *args):
  20. self.data = [a + float(b) for a, b in zip(self.data, args)]
  21. def reset(self):
  22. self.data = [0.0] * len(self.data)
  23. def __getitem__(self, idx):
  24. return self.data[idx]
  25. def accuracy(y_hat, y): #@save
  26. """计算预测正确的数量"""
  27. if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
  28. y_hat = y_hat.argmax(axis=1)
  29. cmp = y_hat.type(y.dtype) == y
  30. return float(cmp.type(y.dtype).sum())
  31. def cross_entropy(y_hat, y):
  32. return -torch.log(y_hat[range(len(y_hat)), y])
  33. def softmax(X):
  34. X_exp = torch.exp(X)
  35. partition = X_exp.sum(1, keepdim=True)
  36. return X_exp/partition
  37. def net(X):
  38. return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)
  39. def get_dataloader_workers():
  40. """使用一个进程来读取的数据"""
  41. return 1
  42. def get_fashion_mnist_labels(labels):
  43. """返回Fashion-MNIST数据集的文本标签"""
  44. #共10个类别
  45. text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
  46. return [text_labels[int(i)] for i in labels]
  47. def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
  48. """画一系列图片"""
  49. figsize = (num_cols * scale, num_rows * scale)
  50. _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
  51. for i, (img, label) in enumerate(zip(imgs, titles)):
  52. xloc, yloc = i//num_cols, i % num_cols
  53. if torch.is_tensor(img):
  54. # 图片张量
  55. axes[xloc, yloc].imshow(img.reshape((28, 28)).numpy())
  56. else:
  57. # PIL图片
  58. axes[xloc, yloc].imshow(img)
  59. # 设置标题并取消横纵坐标上的刻度
  60. axes[xloc, yloc].set_title(label)
  61. plt.xticks([], ())
  62. axes[xloc, yloc].set_axis_off()
  63. pylab.show()
  64. def load_data_fashion_mnist(batch_size, resize=None):
  65. """下载Fashion-MNIST数据集,然后将其加载到内存中"""
  66. trans = transforms.ToTensor()
  67. if resize:
  68. trans.insert(0, transforms.Resize(resize))
  69. mnist_train = torchvision.datasets.FashionMNIST(root='../data', train=True, transform=trans, download=True)
  70. mnist_test = torchvision.datasets.FashionMNIST(root='../data', train=False, transform=trans, download=True)
  71. return (data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers()),
  72. data.DataLoader(mnist_test, batch_size, shuffle=False, num_workers=get_dataloader_workers()))
  73. def evaluate_accuracy(net, data_iter):
  74. """计算在指定数据集上模型的精度"""
  75. if isinstance(net, torch.nn.Module):
  76. net.eval() # 将模型设置为评估模式
  77. metric = Accumulator(2) # 正确预测数、预测总数
  78. with torch.no_grad():
  79. for X, y in data_iter:
  80. metric.add(accuracy(net(X), y), y.numel())
  81. return metric[0] / metric[1]
  82. def updater(batch_size):
  83. lr = 0.1
  84. return d2l.sgd([W, b], lr, batch_size)
  85. def train_epoch_ch3(net, train_iter, loss, updater):
  86. if isinstance(net, torch.nn.Module):
  87. net.train()
  88. metric = Accumulator(3)
  89. for X, y in train_iter:
  90. y_hat = net(X)
  91. lo = loss(y_hat, y)
  92. if isinstance(updater, torch.optim.Optimizer):
  93. updater.zero_grad()
  94. lo.backward()
  95. updater.step()
  96. metric.add(float(lo)*len(y), accuracy(y_hat, y), y.size().numel())
  97. else:
  98. lo.sum().backward()
  99. updater(X.shape[0])
  100. metric.add(float(lo.sum()), accuracy(y_hat, y), y.numel())
  101. return metric[0] / metric[2], metric[1] / metric[2]
  102. class Animator: #@save
  103. """绘制数据"""
  104. def __init__(self, legend=None):
  105. self.legend = legend
  106. self.X = [[], [], []]
  107. self.Y = [[], [], []]
  108. def add(self, x, y):
  109. # 向图表中添加多个数据点
  110. if not hasattr(y, "__len__"):
  111. y = [y]
  112. n = len(y)
  113. if not hasattr(x, "__len__"):
  114. x = [x] * n
  115. for i, (a, b) in enumerate(zip(x, y)):
  116. if a is not None and b is not None:
  117. self.X[i].append(a)
  118. self.Y[i].append(b)
  119. def show(self):
  120. plt.plot(self.X[0], self.Y[0], 'r--')
  121. plt.plot(self.X[1], self.Y[1], 'g--')
  122. plt.plot(self.X[2], self.Y[2], 'b--')
  123. plt.legend(self.legend)
  124. plt.xlabel('epoch')
  125. plt.ylabel('value')
  126. plt.title('Visual')
  127. plt.show()
  128. def train_ch3(net, train_iter, test_iter, loss, num_epochs, updater): #@save
  129. """训练模型"""
  130. animator = Animator(legend=['train loss', 'train acc', 'test acc'])
  131. for epoch in range(num_epochs):
  132. train_metrics = train_epoch_ch3(net, train_iter, loss, updater)
  133. train_loss, train_acc = train_metrics
  134. test_acc = evaluate_accuracy(net, test_iter)
  135. animator.add(epoch + 1, train_metrics + (test_acc,))
  136. print(f'epoch: {epoch+1},train_loss:{train_loss:.4f}, train_acc:{train_acc:.4f}, test_acc:{test_acc:.4f}')
  137. animator.show()
  138. def predict_ch3(net, test_iter, n=12):
  139. """预测标签"""
  140. for X, y in test_iter:
  141. break
  142. trues = d2l.get_fashion_mnist_labels(y)
  143. preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))
  144. titles = [true +'\n' + pred for true, pred in zip(trues, preds)]
  145. show_images(
  146. X[0:n].reshape((n, 28, 28)), 2, int(n/2), titles=titles[0:n])
  147. if __name__ == '__main__':
  148. train_iter, test_iter = load_data_fashion_mnist(batch_size)
  149. train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, updater)
  150. predict_ch3(net, test_iter)

分类效果:

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/645113
推荐阅读
  

闽ICP备14008679号