赞
踩
使用卷积神经网络解决手写数字识别问题
视频链接:《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili
先讲一下卷积神经网络的基本概念(推荐去看《深度学习入门》这本书)
卷积神经网络与全连接层的神经网络的优势在于保持了图像的空间信息,没有压缩图像的维度,这样训练的时候可以会训练到全连接层所没有关注到的图像信息,提高泛化能力。
卷积神经网络有卷积层和池化层:
下面来看我们具体的模型实现:
这两张图非常重要,一定要理解哦,特别是输入数据与卷积层和输出数据之间的各维数之间的关系。
先看一个小例子:
- import torch
-
- in_channels, out_channels = 5, 10
- width, height = 100, 100
- kernel_size = 3
- batch_size = 1
-
- input = torch.randn(batch_size,
- in_channels,
- width,
- height)
-
- conv_layer = torch.nn.Conv2d(in_channels,
- out_channels,
- kernel_size=kernel_size)
- output = conv_layer(input)
- print(input.shape) # 输入形状
- print(output.shape) # 输出形状
- print(conv_layer.weight.shape) # 滤波器的形状
""" torch.Size([1, 5, 100, 100]) # [batch_size,in_channles,width,height] torch.Size([1, 10, 98, 98]) # [batch_size,out_channles,width,height] torch.Size([10, 5, 3, 3]) # [out_channels,in_channels,kernel_size,kernel_size] """
其实也就是说,输入通道数=卷积核的通道个数,输出通道数=卷积核的个数
好了,下面看总的实现代码把
- from torchvision.datasets import MNIST
- from torchvision import transforms
- from torch.utils.data import DataLoader
- import torch
- import torch.nn.functional as F
-
- # 利用卷积神经网络解决MNIST手写数字识别
- # 1、准备数据集
- # 处理数据
- transform = transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize((0.1307,), (0.3081,))
- ])
- batch_size = 64
- # 训练集
- mnist_train = MNIST(root='../dataset/mnist', train=True, transform=transform, download=True)
- train_loader = DataLoader(dataset=mnist_train, shuffle=True, batch_size=batch_size)
- # 测试集
- mnist_test = MNIST(root='../dataset/mnist', train=False, transform=transform, download=True)
- test_loader = DataLoader(dataset=mnist_test, shuffle=True, batch_size=batch_size)
-
-
- # 2.设计模型类
- class Net(torch.nn.Module):
- def __init__(self):
- super(Net, self).__init__()
- self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
- self.pooling = torch.nn.MaxPool2d(2)
- self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
- self.fc = torch.nn.Linear(320, 10) # 最后使用全连接层,分类的类别为10个
-
- def forward(self, x):
- batch_size = x.size(0)
- x = self.pooling(F.relu(self.conv1(x))) # 先卷积,再激活,再池化
- x = self.pooling(F.relu(self.conv2(x)))
- # 全连接层,将x[batch_size,20,4,4]->x[batch,20*4*4] 全连接层只能接受一维的数据
- x = x.view(-1, 320) # 或者写成 x = x.view(batch_size,-1)
- x = self.fc(x)
- return x
-
-
- model = Net()
- # 3、构造损失函数和优化器
- criterion = torch.nn.CrossEntropyLoss()
- optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
-
-
- # 4、训练和测试
- # 定义训练方法,一个训练周期
- def train(epoch):
- running_loss = 0.0
- for idx, (inputs, target) in enumerate(train_loader, 0):
- # 这里的代码与之前没有区别
- # 正向
- y_pred = model(inputs)
- loss = criterion(y_pred, target)
- # 反向
- optimizer.zero_grad()
- loss.backward()
- # 更新
- optimizer.step()
-
- running_loss += loss.item()
- if idx % 300 == 299: # 每300次打印一次平均损失,因为idx是从0开始的,所以%299,而不是300
- print(f'epoch={epoch + 1},batch_idx={idx + 1},loss={running_loss / 300}')
- running_loss = 0.0
-
-
- # 定义测试方法,一个测试周期
- def test():
- # 所有预测正确的样本数
- correct_num = 0
- # 所有样本的数量
- total = 0
- # 测试时,我们不需要计算梯度,因此可以加上这一句,不需要梯度追踪
- with torch.no_grad():
- for images, labels in test_loader:
- # 获得预测值
- outputs = model(images)
- # 获取dim=1的最大值的位置,该位置就代表所预测的标签值
- _, predicted = torch.max(outputs.data, dim=1)
- # 累加每批次的样本数,以获得一个测试周期所有的样本数
- total += labels.size(0)
- # 累加每批次的预测正确的样本数,以获得一个测试周期的所有预测正确的样本数
- correct_num += (predicted == labels).sum().item()
- print(f'Accuracy on test set:{100 * correct_num / total}%') # 打印一个测试周期的正确率
-
-
- if __name__ == '__main__':
- # 训练周期为10次,每次训练所有的训练集样本数,并测试
- for epoch in range(10):
- train(epoch)
- test()
结果如下:
epoch=1,batch_idx=300,loss=0.6427036832769711
epoch=1,batch_idx=600,loss=0.20884355887770653
epoch=1,batch_idx=900,loss=0.150776317777733
Accuracy on test set:96.73%
epoch=2,batch_idx=300,loss=0.12021432491019368
epoch=2,batch_idx=600,loss=0.10932956301607191
epoch=2,batch_idx=900,loss=0.09082858871979017
Accuracy on test set:97.77%
........
epoch=9,batch_idx=300,loss=0.037825182913220484
epoch=9,batch_idx=600,loss=0.04158504081889987
epoch=9,batch_idx=900,loss=0.03867125011437262
Accuracy on test set:98.72%
epoch=10,batch_idx=300,loss=0.03946270007213267
epoch=10,batch_idx=600,loss=0.036561023951120056
epoch=10,batch_idx=900,loss=0.034661959860629095
Accuracy on test set:98.64%
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。