
Source Code for a Deep-Learning Chinese Herbal Medicine Recognition and Classification System Based on the EfficientNetV2 Network in PyTorch

Step 1: Prepare the data

Data for 5 classes of Chinese herbs: `self.class_indict = ["百合", "党参", "山魈", "枸杞", "槐花", "金银花"]`, 900 images in total, with each class in its own folder.
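Before training, it is worth verifying the one-folder-per-class layout described above. A minimal sketch (the helper name and the image-extension filter are my own, not part of the project's `read_split_data`):

```python
import os

def count_images_per_class(root):
    """Count image files in each class sub-folder (one folder per class)."""
    counts = {}
    for cls in sorted(os.listdir(root)):
        cls_dir = os.path.join(root, cls)
        if os.path.isdir(cls_dir):
            counts[cls] = len([f for f in os.listdir(cls_dir)
                               if f.lower().endswith((".jpg", ".jpeg", ".png"))])
    return counts
```

Summing the returned counts should give the 900 images mentioned above, and the keys should match `class_indict`.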

Step 2: Build the model

This article uses an EfficientNetV2 network; its main ideas are as follows:

EfficientNetV2 combines training-aware neural architecture search with model scaling. Building on EfficientNetV1, it adds the Fused-MBConv block to the search space, and it introduces a progressive learning strategy with an adaptive regularization-strength adjustment mechanism that makes training faster. The design pays attention to inference speed as well as training speed.
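The progressive learning idea can be sketched as a schedule that grows the image size and the regularization strength together over training stages, so early stages train fast on small, weakly regularized inputs. The concrete minimum/maximum values below are illustrative, not the paper's:

```python
def progressive_schedule(stage, num_stages, size_min=128, size_max=300,
                         dropout_min=0.1, dropout_max=0.3):
    """Linearly interpolate image size and regularization strength
    from easy (small images, weak regularization) to hard."""
    t = stage / max(num_stages - 1, 1)
    size = int(size_min + t * (size_max - size_min))
    dropout = dropout_min + t * (dropout_max - dropout_min)
    return size, dropout
```

Each stage would then rebuild its data transforms with the returned image size and set the model's dropout rate accordingly.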

Compared with EfficientNetV1, the main differences are:

  1. V2 uses Fused-MBConv blocks in addition to MBConv blocks.
  2. V2 prefers smaller expansion ratios, whereas V1 almost always uses 6; this reduces memory-access overhead.
  3. V2 prefers smaller 3 x 3 kernels, whereas V1 uses many 5 x 5 kernels; since a 3 x 3 kernel has a smaller receptive field than a 5 x 5 one, more layers must be stacked to compensate.
  4. V2 removes the last stride-1 stage of V1.
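The trade-off behind point 1 can be made concrete with a rough parameter count (bias and BatchNorm parameters omitted; the channel numbers in the test are illustrative). Fused-MBConv replaces the 1 x 1 expansion plus k x k depthwise convolution with a single dense k x k convolution: more parameters and FLOPs per block, but no memory-bound depthwise step, which tends to run faster on accelerators:

```python
def mbconv_params(cin, cout, expansion, k=3):
    """Parameter count of an MBConv block:
    1x1 expand + kxk depthwise + 1x1 project (bias/BN omitted)."""
    mid = cin * expansion
    return cin * mid + k * k * mid + mid * cout

def fused_mbconv_params(cin, cout, expansion, k=3):
    """Fused-MBConv merges the expand 1x1 and depthwise kxk
    into one regular kxk convolution, followed by the 1x1 project."""
    mid = cin * expansion
    return k * k * cin * mid + mid * cout
```

This is why V2 uses Fused-MBConv only in the early, small-channel stages, where the extra parameters are cheap.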

Step 3: Training code

1) Loss function: cross-entropy loss
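For reference, the cross-entropy loss for a single sample is the negative log-softmax probability of the true class. A minimal pure-Python version (PyTorch's `nn.CrossEntropyLoss` computes the same quantity from raw logits, averaged over the batch):

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log softmax(logits)[target]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]
```

With uniform logits the loss is log(num_classes), a useful sanity check for the first training batches.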

2) Training code:

```python
import os
import math
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler

from model import efficientnetv2_s as create_model
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    print(args)
    print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
    tb_writer = SummaryWriter()
    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    img_size = {"s": [300, 384],  # train_size, val_size
                "m": [384, 480],
                "l": [384, 480]}
    num_model = "s"

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size[num_model][0]),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([transforms.Resize(img_size[num_model][1]),
                                   transforms.CenterCrop(img_size[num_model][1]),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}

    # instantiate the training dataset
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # instantiate the validation dataset
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    # load pre-trained weights if they exist
    model = create_model(num_classes=args.num_classes).to(device)
    if args.weights != "":
        if os.path.exists(args.weights):
            weights_dict = torch.load(args.weights, map_location=device)
            load_weights_dict = {k: v for k, v in weights_dict.items()
                                 if model.state_dict()[k].numel() == v.numel()}
            print(model.load_state_dict(load_weights_dict, strict=False))
        else:
            raise FileNotFoundError("not found weights file: {}".format(args.weights))

    # optionally freeze weights
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # freeze everything except the head
            if "head" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1E-4)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)
        scheduler.step()

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)

        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch-size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--lrf', type=float, default=0.01)
    # root directory of the dataset
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default=r"G:\demo\data\ChineseMedicine")
    # download model weights
    # link: https://pan.baidu.com/s/1uZX36rvrfEss-JGj4yfzbQ  password: 5gu1
    parser.add_argument('--weights', type=str, default='./pre_efficientnetv2-s.pth',
                        help='initial weights path')
    parser.add_argument('--freeze-layers', type=bool, default=True)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()
    main(opt)
```
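The learning-rate schedule in the training code can be checked in isolation: the `lf` lambda decays the multiplier from 1.0 at epoch 0 down to `lrf` at the final epoch along a cosine curve.

```python
import math

def cosine_factor(epoch, epochs, lrf):
    """LR multiplier used by the LambdaLR scheduler: cosine decay from 1 to lrf."""
    return ((1 + math.cos(epoch * math.pi / epochs)) / 2) * (1 - lrf) + lrf
```

With the defaults (`epochs=100`, `lrf=0.01`), the effective learning rate therefore falls from 0.01 to 0.0001 over the run.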

Step 4: Measure the accuracy
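The training script above only logs overall validation accuracy; a minimal sketch of the per-class statistics this step computes (the function name and the parallel prediction/label list format are assumptions, not the project's API):

```python
def accuracy_stats(preds, labels, class_names):
    """Overall and per-class accuracy from parallel prediction/label index lists."""
    correct = {c: 0 for c in class_names}
    total = {c: 0 for c in class_names}
    for p, y in zip(preds, labels):
        total[class_names[y]] += 1
        if p == y:
            correct[class_names[y]] += 1
    overall = sum(correct.values()) / max(len(labels), 1)
    per_class = {c: correct[c] / total[c] for c in class_names if total[c]}
    return overall, per_class
```

Per-class numbers matter here because visually similar herbs (e.g. dried flowers) often dominate the error count while the overall accuracy still looks high.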

Step 5: Build the GUI

Step 6: Contents of the whole project

The project includes the training code, the trained model and its training logs, the dataset, and the GUI code.

Code download link (opens in a new window): Source Code for a Deep-Learning Chinese Herbal Medicine Recognition and Classification System Based on the EfficientNetV2 Network in PyTorch

If you have any questions, feel free to send a private message or leave a comment; every question will be answered.
