当前位置:   article > 正文

李沐-动手学深度学习-LeNet_李沐深度学习实验

李沐深度学习实验

1.LeNet的实现

  1. import torch
  2. from torch import nn
  3. from d2l import torch as d2l
  4. net = nn.Sequential(
  5. nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
  6. nn.AvgPool2d(kernel_size=2,stride=2),
  7. nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
  8. nn.AvgPool2d(kernel_size=2,stride=2),
  9. nn.Flatten(),
  10. nn.Linear(16*5*5,120),nn.Sigmoid(),
  11. nn.Linear(120,84),nn.Sigmoid(),
  12. nn.Linear(84,10))
  13. X = torch.rand(size=(1,1,28,28),dtype=torch.float32)
  14. for layer in net:
  15. X = layer(X)
  16. print(layer.__class__.__name__,'output shape: \t',X.shape)

2.模型训练

LeNet在Fashion-MNIST数据集上的表现,all代码

  1. import torch
  2. from torch import nn
  3. from d2l import torch as d2l
  4. net = nn.Sequential(
  5. nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
  6. nn.AvgPool2d(kernel_size=2,stride=2),
  7. nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
  8. nn.AvgPool2d(kernel_size=2,stride=2),
  9. nn.Flatten(),
  10. nn.Linear(16*5*5,120),nn.Sigmoid(),
  11. nn.Linear(120,84),nn.Sigmoid(),
  12. nn.Linear(84,10))
  13. X = torch.rand(size=(1,1,28,28),dtype=torch.float32)
  14. for layer in net:
  15. X = layer(X)
  16. print(layer.__class__.__name__,'output shape: \t',X.shape)
  17. batch_size = 256
  18. train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
  19. #对 3.6节中描述的evaluate_accuracy函数进行轻微的修改
  20. def evaluate_accuracy_gpu(net,data_iter,device=None):
  21. """使用GPU计算模型在数据集上的精度"""
  22. if isinstance(net,nn.Module):
  23. net.eval() #设置为评估模式
  24. if not device:
  25. device = next(iter(net.parameters())).device
  26. #正确预测的数量,总预测的数量
  27. metric = d2l.Accumulator(2)
  28. with torch.no_grad():
  29. for X,y in data_iter:
  30. if isinstance(X,list):
  31. #BERT微调所需的(之后介绍)
  32. X = [x.to(device) for x in X]
  33. else:
  34. X = X.to(device)
  35. y = y.to(device)
  36. metric.add(d2l.accuracy(net(X),y),y.numel())
  37. return metric[0] / metric[1]
  38. #与 3.6节中定义的train_epoch_ch3不同,在进行正向和反向传播之前,将每一小批量数据移动到指定的设备(例如GPU)上
  39. #训练函数train_ch6也类似于 3.6节中定义的train_ch3
  40. #以下训练函数假定从高级API创建的模型作为输入,并进行相应的优化。
  41. # 使用在 4.8.2.2节中介绍的Xavier随机初始化模型参数。与全连接层一样,使用交叉熵损失函数和小批量随机梯度下降。
  42. #@save
  43. def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):
  44. """用GPU训练模型(在第六章定义)"""
  45. def init_weights(m):
  46. if type(m) == nn.Linear or type(m) == nn.Conv2d:
  47. nn.init.xavier_uniform_(m.weight)
  48. net.apply(init_weights)
  49. print('training on', device)
  50. net.to(device)
  51. optimizer = torch.optim.SGD(net.parameters(), lr=lr)
  52. loss = nn.CrossEntropyLoss()
  53. animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
  54. legend=['train loss', 'train acc', 'test acc'])
  55. timer, num_batches = d2l.Timer(), len(train_iter)
  56. for epoch in range(num_epochs):
  57. # 训练损失之和,训练准确率之和,样本数
  58. metric = d2l.Accumulator(3)
  59. net.train()
  60. for i, (X, y) in enumerate(train_iter):
  61. timer.start()
  62. optimizer.zero_grad()
  63. X, y = X.to(device), y.to(device)
  64. y_hat = net(X)
  65. l = loss(y_hat, y)
  66. l.backward()
  67. optimizer.step()
  68. with torch.no_grad():
  69. metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])
  70. timer.stop()
  71. train_l = metric[0] / metric[2]
  72. train_acc = metric[1] / metric[2]
  73. if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
  74. animator.add(epoch + (i + 1) / num_batches,
  75. (train_l, train_acc, None))
  76. test_acc = evaluate_accuracy_gpu(net, test_iter)
  77. animator.add(epoch + 1, (None, None, test_acc))
  78. print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, '
  79. f'test acc {test_acc:.3f}')
  80. print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec '
  81. f'on {str(device)}')
  82. lr,num_epochs = 0.9,10
  83. print(train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu()))
  84. d2l.plt.show()

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/正经夜光杯/article/detail/939107
推荐阅读
相关标签
  

闽ICP备14008679号