当前位置:   article > 正文

3.AlexNet--CNN经典网络模型详解(pytorch实现)_cnn模型的pytorch代码

cnn模型的pytorch代码

        参考博客《AlexNet--CNN经典网络模型详解(pytorch实现)》(CSDN),该博客作者写得很详细。本文是一个简单的图像分类代码,可以通过它了解图像分类任务的基本流程(数据划分、训练、预测)。这里不再赘述原理,如果想更深入地了解,可以参考另一篇博客《pytorch实现MobileNet-v2(CNN经典网络模型详解)》(知乎,zhihu.com)。

在这里,直接写AlexNet--CNN的代码。

1.首先建立一个model.py文件,用来写神经网络,代码如下:

  # model.py
import torch.nn as nn
import torch


class AlexNet(nn.Module):
    """AlexNet with half the channels of the original paper.

    Args:
        num_classes: number of output logits.
        init_weights: if True, re-initialise Conv2d layers with Kaiming-normal
            (fan_out, relu) and Linear layers with N(0, 0.01), all biases 0;
            otherwise PyTorch's default initialisation is kept.

    ``forward`` takes a float tensor of shape (N, 3, H, W) and returns raw
    logits of shape (N, num_classes) (no softmax). The shape comments below
    are for H = W = 224; other sizes work too thanks to ``self.avgpool``.
    """

    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        # NOTE: the Sequential indices define the state_dict keys
        # (features.0.weight, classifier.1.weight, ...). Reordering layers
        # breaks previously saved .pth files.
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55] (fractional part floored)
            nn.ReLU(inplace=True),  # inplace overwrites the input tensor, saving memory
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]; channel counts are half of the paper's
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        # Bring any feature-map size to the fixed 6x6 grid expected by the first
        # Linear layer. For 224x224 inputs the map is already 6x6 and this is an
        # exact identity; it has no parameters, so existing checkpoints still load.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            # fully connected layers
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Return class logits of shape (N, num_classes) for images x of shape (N, 3, H, W)."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, start_dim=1)  # flatten to (N, 128*6*6); x.view(x.size(0), -1) also works
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        """Kaiming-normal for conv weights (He et al.), N(0, 0.01) for linear weights, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)  # normal distribution, mean 0, std 0.01
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

2.下载数据集

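DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'

下载后,将压缩包解压到项目目录下的 flower_data 文件夹中,得到 flower_data/flower_photos(其中每个类别一个子文件夹),下面的 spile_data.py 默认从这个路径读取数据。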

3.下载完后写一个spile_data.py文件,将数据集划分为训练集(train)和验证集(val):每个类别随机抽取10%的图片作为验证集,其余作为训练集。

  # spile_data.py
# Split flower_data/flower_photos into flower_data/train and flower_data/val,
# keeping one sub-folder per class.
import os
from shutil import copy, rmtree
import random


def mkfile(file):
    """Create directory *file* (including parents) if it does not exist yet."""
    if not os.path.exists(file):
        os.makedirs(file)


def _reset_dir(path):
    """Delete *path* if it exists, then recreate it empty.

    Without this, re-running the script keeps the files copied by the previous
    (differently random) split, so one photo could sit in both train and val.
    """
    if os.path.exists(path):
        rmtree(path)
    os.makedirs(path)


def split_dataset(src_root='flower_data/flower_photos', dst_root='flower_data',
                  split_rate=0.1, seed=0):
    """Copy every class folder of *src_root* into ``<dst_root>/train`` and ``<dst_root>/val``.

    Args:
        src_root: folder containing one sub-folder of images per class.
        dst_root: folder in which ``train`` and ``val`` are (re)created; any
            previous contents of those two folders are deleted.
        split_rate: fraction of each class copied to val
            (``int(num * split_rate)`` images, rounded down); the rest go to train.
        seed: seed for the random split so runs are reproducible; pass None
            to get a different split on every run.

    Returns:
        The sorted list of class names (sub-folder names of *src_root*).

    Raises:
        FileNotFoundError: if *src_root* is not a directory.
    """
    if not os.path.isdir(src_root):
        raise FileNotFoundError("dataset folder not found: {}".format(src_root))
    rng = random.Random(seed)

    # Only sub-directories are classes (skips LICENSE.txt and any other stray file).
    flower_class = sorted(cla for cla in os.listdir(src_root)
                          if os.path.isdir(os.path.join(src_root, cla)))

    train_root = os.path.join(dst_root, 'train')
    val_root = os.path.join(dst_root, 'val')
    _reset_dir(train_root)
    _reset_dir(val_root)

    for cla in flower_class:
        cla_path = os.path.join(src_root, cla)
        os.makedirs(os.path.join(train_root, cla))
        os.makedirs(os.path.join(val_root, cla))
        # Sorted so that the same seed always picks the same files
        # (os.listdir order is arbitrary).
        images = sorted(img for img in os.listdir(cla_path)
                        if os.path.isfile(os.path.join(cla_path, img)))
        num = len(images)
        # A set gives O(1) membership tests; a list made the loop below O(n^2).
        eval_index = set(rng.sample(images, k=int(num * split_rate)))
        for index, image in enumerate(images):
            target_root = val_root if image in eval_index else train_root
            copy(os.path.join(cla_path, image), os.path.join(target_root, cla))
            print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # processing bar
        print()
    print("processing done!")
    return flower_class


if __name__ == "__main__":
    split_dataset()

之后应该是这样:
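(原文此处为目录结构截图)运行完成后,flower_data 的目录结构如下:

    flower_data/
    ├── flower_photos/   # 解压得到的原始数据
    ├── train/           # 每个类别约 90% 的图片
    │   ├── daisy/
    │   ├── dandelion/
    │   ├── roses/
    │   ├── sunflowers/
    │   └── tulips/
    └── val/             # 每个类别约 10% 的图片,子文件夹与 train 相同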

4.再写一个train.py文件,用来训练模型

  # train.py
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time

# Device: the first GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Image preprocessing pipelines, each chained with transforms.Compose.
# "train": 1. random crop, resized to 224x224
#          2. random horizontal flip (probability 0.5 by default)
#          3. PIL image -> tensor with values in [0, 1]
#          4. normalize each channel with mean 0.5 / std 0.5, i.e. to [-1, 1]
# "val":   deterministic resize to exactly 224x224, then the same steps 3 and 4.
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    # Resize(224) would only scale the shorter side to 224 (keeping the aspect
    # ratio), so images of different shapes could not be stacked into one
    # batch; (224, 224) forces a square output.
    "val": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}


def _evaluate(net, loader):
    """Return how many samples in *loader* the network classifies correctly."""
    net.eval()  # eval mode disables the Dropout layers
    correct = 0
    with torch.no_grad():
        for images, labels in loader:
            outputs = net(images.to(device))
            predict_y = torch.argmax(outputs, dim=1)
            correct += (predict_y == labels.to(device)).sum().item()
    return correct


def main(epochs=10, batch_size=32, lr=0.0002, save_path='./AlexNet.pth'):
    """Train AlexNet on flower_data/train and validate on flower_data/val.

    Validation runs after every epoch. The weights with the best validation
    accuracy are written to *save_path*, which predict.py loads. The
    index -> class-name mapping is written to class_indices.json.
    """
    image_path = os.path.join(os.getcwd(), "flower_data")  # flower data set path
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    train_num = len(train_dataset)
    val_num = len(validate_dataset)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    # class_to_idx maps folder name -> label, e.g. {'daisy': 0, 'dandelion': 1, ...}.
    # Invert it to label -> name for predict.py; json.dumps turns the int keys into strings.
    cla_dict = {idx: name for name, idx in train_dataset.class_to_idx.items()}
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json.dumps(cla_dict, indent=4))

    # Each training batch: images [batch_size, 3, 224, 224], labels [batch_size].
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=0)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=0)

    net = AlexNet(num_classes=len(cla_dict), init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=lr)

    # Start below any reachable accuracy so the first epoch's weights are always saved.
    best_acc = -1.0
    # Alternate: train one epoch, then validate once.
    for epoch in range(epochs):
        # train
        net.train()  # train mode enables the Dropout layers
        running_loss = 0.0
        t1 = time.perf_counter()
        for step, (images, labels) in enumerate(train_loader):
            optimizer.zero_grad()
            outputs = net(images.to(device))  # logits [batch, num_classes]
            loss = loss_function(outputs, labels.to(device))
            loss.backward()   # back-propagate to get every parameter's gradient
            optimizer.step()  # one optimizer update
            running_loss += loss.item()
            # text progress bar
            rate = (step + 1) / len(train_loader)
            a = "*" * int(rate * 50)
            b = "." * int((1 - rate) * 50)
            print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss.item()), end="")
        print()
        print(time.perf_counter() - t1)

        # validate
        val_accurate = _evaluate(net, validate_loader) / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / len(train_loader), val_accurate))
    print('Finished Training')


if __name__ == "__main__":
    main()

5.写一个预测的predict.py文件

代码如下:

  # predict.py
import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json
import sys

# Same deterministic preprocessing as the "val" transform used for validation in training.
data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


def main(img_path="./sunflower.jpg",
         json_path='./class_indices.json',
         model_weight_path="./AlexNet.pth"):
    """Classify one image with the trained AlexNet and print the class name and its probability.

    Use img_path="./roses.jpg" to check a rose image instead of the sunflower.
    Exits with status -1 if the class-index JSON cannot be read.
    """
    # load image; convert('RGB') so RGBA PNGs and grayscale images also give 3 channels
    img = Image.open(img_path).convert('RGB')
    plt.imshow(img)
    img = data_transform(img)           # [C, H, W]
    img = torch.unsqueeze(img, dim=0)   # add the batch dimension -> [N, C, H, W]

    # read class_indict: label (as a string) -> class name
    try:
        with open(json_path, 'r') as json_file:
            class_indict = json.load(json_file)
    except (OSError, json.JSONDecodeError) as e:
        print(e)
        sys.exit(-1)

    # create model with one output per class listed in the JSON
    model = AlexNet(num_classes=len(class_indict))
    # load model weights; map_location='cpu' lets weights saved from a GPU
    # run load on a CPU-only machine
    model.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img))
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).item()
    print(class_indict[str(predict_cla)], predict[predict_cla].item())
    plt.show()


if __name__ == "__main__":
    main()

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/789946
推荐阅读
相关标签
  

闽ICP备14008679号