
Building an AlexNet Object Detection Platform with PyTorch



I. Introduction to the AlexNet Network

1. Network Structure

The AlexNet model consists of five convolutional layers and three pooling layers, plus three fully connected layers. AlexNet is structurally similar to LeNet, but it uses more convolutional layers and a much larger parameter space in order to fit the large-scale ImageNet dataset. It marks the dividing line between shallow and deep neural networks.

(Figure: AlexNet network architecture)
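As a quick sanity check on the feature-map sizes annotated in model.py below, the spatial size of a convolution or pooling output follows floor((W - K + 2P) / S) + 1, with input size W, kernel size K, padding P and stride S. The short sketch below is not part of the downloadable source; it simply works through the first few layers:

    def out_size(w, k, s=1, p=0):
        # floor((W - K + 2P) / S) + 1
        return (w - k + 2 * p) // s + 1

    w = out_size(224, k=11, s=4, p=2)   # conv1:    (224 - 11 + 4) // 4 + 1 = 55
    w = out_size(w, k=3, s=2)           # maxpool1: (55 - 3) // 2 + 1 = 27
    w = out_size(w, k=5, s=1, p=2)      # conv2:    (27 - 5 + 4) // 1 + 1 = 27
    print(w)                            # 27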

2. Key Features:

(1) A ReLU activation follows every convolution, which avoids the vanishing-gradient problem of Sigmoid and makes training converge faster;

(2) Dropout randomly ignores individual neurons during training, which reduces overfitting (data augmentation is also used against overfitting);

(3) LRN (Local Response Normalization) layers are added, which improves accuracy;

(4) Overlapping max pooling is used, i.e. the pooling window size z and the stride s satisfy z > s, which avoids the blurring effect of average pooling. A short PyTorch sketch of these components follows this list.
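Here is a minimal, self-contained sketch of how these four ingredients look in PyTorch. It is illustrative only and not part of the downloadable source; note that the simplified model.py below keeps ReLU, Dropout and overlapping max pooling but omits the LRN layer.

    import torch
    import torch.nn as nn

    # (1) ReLU, (3) LRN and (4) overlapping max pooling (window 3 > stride 2) after a convolution
    block = nn.Sequential(
        nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),
        nn.ReLU(inplace=True),
        nn.LocalResponseNorm(size=5, alpha=1e-4, beta=0.75, k=2.0),
        nn.MaxPool2d(kernel_size=3, stride=2),
    )
    # (2) Dropout in front of a fully connected layer
    head = nn.Sequential(
        nn.Dropout(p=0.5),
        nn.Linear(48 * 27 * 27, 10),
    )

    x = torch.randn(1, 3, 224, 224)
    print(block(x).shape)                                     # torch.Size([1, 48, 27, 27])
    print(head(torch.flatten(block(x), start_dim=1)).shape)   # torch.Size([1, 10])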

II. Preparation

I have uploaded the source code for the AlexNet network to Baidu Netdisk:

Link: https://pan.baidu.com/s/1k_Rbb27ksykMdpnc-0P9zQ
Extraction code: 0x2s

(Figure: file structure in VS Code)
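The screenshot is not reproduced here; based on the files referenced in this article, the project layout is roughly the following (the root folder name is an assumption, and flower_data has to be prepared separately):

    AlexNet/
    ├── model.py             # network definition
    ├── train.py             # training script
    ├── predict.py           # prediction script
    ├── class_indices.json   # written by train.py
    ├── AlexNet.pth          # best weights saved by train.py
    └── flower_data/
        ├── train/           # one sub-folder per class (daisy, dandelion, roses, sunflower, tulips)
        └── val/

The test image used by predict.py sits one directory above the scripts (img_path = "../tulip.jpg").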

 

III. Implementation

(1) The AlexNet network structure: model.py

import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)   # flatten all dimensions except the batch dimension
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
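To confirm that the shapes in the comments line up, you can push a dummy input through the network. This is a minimal checking sketch and is not part of the original source:

    import torch
    from model import AlexNet

    net = AlexNet(num_classes=5, init_weights=True)
    x = torch.randn(1, 3, 224, 224)   # a dummy batch containing one RGB image
    print(net.features(x).shape)      # torch.Size([1, 128, 6, 6])  -> 128 * 6 * 6 = 4608 inputs to the classifier
    print(net(x).shape)               # torch.Size([1, 5])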

(2) Training the network: train.py

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm

from model import AlexNet


def main():
    # use the GPU if one is available, otherwise fall back to the CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        # training set: random crop + horizontal flip for data augmentation
        "train": transforms.Compose([transforms.RandomResizedCrop(224),   # random crop to 224x224
                                     transforms.RandomHorizontalFlip(),   # random horizontal flip
                                     transforms.ToTensor(),               # convert to tensor
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        # validation set: deterministic resize only
        "val": transforms.Compose([transforms.Resize((224, 224)),         # cannot be 224, must be (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    # path to the dataset
    image_path = './flower_data'
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    # root is the directory the dataset is loaded from, transform is the preprocessing pipeline
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    # number of images in the training set
    train_num = len(train_dataset)

    # mapping from class name to index
    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write the index-to-class mapping into a json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers per process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    net = AlexNet(num_classes=5, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    # pata = list(net.parameters())
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    epochs = 5
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate the number of correct predictions per epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        # keep the weights that achieve the best validation accuracy so far
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

These results come from only 5 training epochs; you can train for more epochs to reach better accuracy:
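If you do continue training, one possible approach (an illustrative sketch, not part of the original train.py) is to load the previously saved best weights before entering the training loop, so the extra epochs start from the checkpoint instead of from scratch:

    import os
    import torch
    from model import AlexNet

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = AlexNet(num_classes=5, init_weights=True)
    save_path = './AlexNet.pth'
    if os.path.exists(save_path):
        # resume from the best weights saved by a previous run
        net.load_state_dict(torch.load(save_path, map_location=device))
    net.to(device)
    # ...then reuse the loss function, optimizer and training loop from train.py,
    # with a larger value for `epochs`.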

(3) Prediction: predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import AlexNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = AlexNet(num_classes=5).to(device)

    # load model weights
    weights_path = "./AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path))

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()
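predict.py classifies a single image; if you want to run the model over a whole folder of test images, a possible extension looks like the sketch below. It is not part of the original source: the folder ./test_images is hypothetical, while class_indices.json and AlexNet.pth are the files produced by train.py.

    import os
    import json

    import torch
    from PIL import Image
    from torchvision import transforms

    from model import AlexNet

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    with open('./class_indices.json', 'r') as f:
        class_indict = json.load(f)

    model = AlexNet(num_classes=5).to(device)
    model.load_state_dict(torch.load('./AlexNet.pth', map_location=device))
    model.eval()

    img_dir = './test_images'   # hypothetical folder with images to classify
    with torch.no_grad():
        for name in os.listdir(img_dir):
            img = Image.open(os.path.join(img_dir, name)).convert('RGB')
            img = torch.unsqueeze(data_transform(img), dim=0).to(device)
            probs = torch.softmax(torch.squeeze(model(img)), dim=0)
            cla = torch.argmax(probs).item()
            print('{}: class: {}  prob: {:.3f}'.format(name, class_indict[str(cla)], probs[cla].item()))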

IV. Results

Prediction result: the image is recognized as a sunflower with a confidence of 0.894.

 
