当前位置:   article > 正文

pytorch基础教学简单实例(附代码)_pytorch测试代码

pytorch测试代码

目录

1.简介 

2.数据集

3.模型初始化

4.训练参数

5.训练&验证

6.保存&加载模型


1.简介 

这篇文章主要是针对刚入门pytorch的小伙伴,会带大家完整走一遍使用神经网络训练的流程,以及介绍一些pytorch常用的函数。如果还未安装pytorch或者安装有困难,可以参考我的上一篇文章:

Windows Anaconda精简安装cuda+pytorch+torchvision


2.数据集

这里使用的是FashionMNIST,因为可以直接调用pytorch里的代码下载比较方便。加载代码如下:

  1. import torch
  2. from torch import nn
  3. from torch.utils.data import DataLoader
  4. from torchvision import datasets
  5. from torchvision.transforms import ToTensor
  6. # ----------数据集----------
  7. # 加载MNIST数据集的训练集
  8. training_data = datasets.FashionMNIST(
  9. root="data",
  10. train=True,
  11. download=True,
  12. transform=ToTensor(),
  13. )
  14. # 加载MNIST数据集的测试集
  15. test_data = datasets.FashionMNIST(
  16. root="data",
  17. train=False,
  18. download=True,
  19. transform=ToTensor(),
  20. )
  21. # batch大小
  22. batch_size = 64
  23. # 创建dataloader
  24. train_dataloader = DataLoader(training_data, batch_size=batch_size)
  25. test_dataloader = DataLoader(test_data, batch_size=batch_size)
  26. # 遍历dataloader
  27. for X, y in test_dataloader:
  28. print("Shape of X [N, C, H, W]: ", X.shape) # 每个batch数据的形状
  29. print("Shape of y: ", y.shape) # 每个batch标签的形状
  30. break


3.模型初始化

手动搭建了两个全连接层的神经网络结构。代码如下:

  1. # ----------模型----------
  2. # 定义模型
  3. class NeuralNetwork(nn.Module):
  4. def __init__(self): # 初始化,实例化模型的时候就会调用
  5. super(NeuralNetwork, self).__init__()
  6. self.flatten = nn.Flatten() # [64, 1, 28, 28] -> [64, 1*28*28]
  7. self.linear_relu_stack = nn.Sequential(
  8. nn.Linear(28*28, 512), # [64, 1*28*28] -> [64, 512]
  9. nn.ReLU(),
  10. nn.Linear(512, 512), # [64, 512] -> [64, 512]
  11. nn.ReLU(),
  12. nn.Linear(512, 10) # [64, 512] -> [64, 10]
  13. )
  14. def forward(self, x): # 前向传播,输入数据进网络的时候才会调用
  15. x = self.flatten(x) # [64, 1*28*28]
  16. logits = self.linear_relu_stack(x) # [64, 10]
  17. return logits
  18. # 使用gpu或者cpu进行训练
  19. device = "cuda" if torch.cuda.is_available() else "cpu"
  20. # 打印使用的是gpu/cpu
  21. print("Using {} device".format(device))
  22. # 实例化模型
  23. model = NeuralNetwork().to(device)
  24. # 打印模型结构
  25. print(model)

4.训练参数

损失函数用来计算预测输出和真实输出的差值,优化器用来更新网络的权重参数,scheduler用来调整训练过程中的学习率。代码如下:

  1. # ----------训练参数设置----------
  2. loss_fn = nn.CrossEntropyLoss() # 损失函数设置
  3. optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) # 学习率设置
  4. epochs = 5 # 训练迭代次数设置

5.训练&验证

训练和测试的步骤可以总结为下图,具体代码如下:

  1. # 训练函数
  2. def train(train_dataloader, model, loss_fn, optimizer):
  3. """
  4. 训练网络
  5. 输入:
  6. train_dataloader: 训练集的dataloader
  7. model: 网络模型
  8. loss_fn: 损失函数
  9. optimizer: 优化器
  10. """
  11. # 切换到train模式
  12. model.train()
  13. # 遍历dataloader
  14. for images, labels in train_dataloader:
  15. # 将数据和标签加载到device上
  16. images, labels = images.to(device), labels.to(device)
  17. # 输入数据到模型里得到输出
  18. pred = model(images)
  19. # 计算输出和标签的loss
  20. loss = loss_fn(pred, labels)
  21. # 反向推导
  22. optimizer.zero_grad()
  23. loss.backward()
  24. # 步进优化器
  25. optimizer.step()
  26. # 测试函数
  27. def test(test_dataloader, model, loss_fn):
  28. """
  29. 测试网络
  30. 输入:
  31. test_dataloader: 测试集的dataloader
  32. model: 网络模型
  33. loss_fn: 损失函数
  34. """
  35. # 测试集大小
  36. size = len(test_dataloader.dataset)
  37. # 测试集的batch数量
  38. num_batches = len(test_dataloader)
  39. # 切换到测试模型
  40. model.eval()
  41. # 记录loss和准确率
  42. test_loss, correct = 0, 0
  43. # 梯度截断
  44. with torch.no_grad():
  45. for images, labels in test_dataloader: # 遍历batch
  46. # 加载到device
  47. images, labels = images.to(device), labels.to(device)
  48. # 输入数据到模型里得到输出
  49. pred = model(images)
  50. # 累加loss
  51. test_loss += loss_fn(pred, labels).item()
  52. # 累加正确率
  53. correct += (pred.argmax(1) == labels).sum().item()
  54. # 计算平均loss和准确率
  55. test_loss /= num_batches
  56. correct /= size
  57. print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

6.保存&加载模型

在训练好网络后,我们可以将训练好的权重保存下来,这样下次要使用的时候就不用再次训练,直接加载就好,代码如下:

  1. # 保存模型
  2. torch.save(model.state_dict(), "model.pth")
  3. # 加载模型
  4. model = NeuralNetwork()
  5. model.load_state_dict(torch.load("model.pth"))

业务合作/学习交流+v:lizhiTechnology

如果想要了解更多深度学习相关知识,可以参考我的其他文章:

【优化器】(一) SGD原理 & pytorch代码解析_sgd优化器-CSDN博客

【损失函数】(一) L1Loss原理 & pytorch代码解析_l1 loss-CSDN博客

【图像生成】(一) DNN 原理 & pytorch代码实例_pytorch dnn代码-CSDN博客

深度学习_Lcm_Tech的博客-CSDN博客

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/263781
推荐阅读
相关标签
  

闽ICP备14008679号