当前位置:   article > 正文

MindSpore端侧手机应用实战:AI垃圾分类应用_基于mindspore端侧的ai 垃圾分类应用结果图

基于mindspore端侧的ai 垃圾分类应用结果图

本文来源于:知乎

作者:李锐锋

MindSpore作为一个端边云协同的开源的全场景AI框架,今年3月份开源以来,受到了业界的广泛关注和应用,Gitee指数99分,在码云所有项目中排名第一,这个成绩很了不起,感谢所有开发者和AI ISV的贡献!

MindSpore Lite是一个极速、极智、极简、超轻量级的AI引擎,已随HMS应用于数百个手机侧应用程序,日调用量2.5亿次/天以上。欢迎大家参与开源贡献、模型众筹合作、行业创新与应用、学术合作等,贡献您在云侧、端侧、边侧以及安全领域的应用案例,也欢迎各位AI专业的套件公司与MindSpore社区合作共赢。

垃圾分类是近年全国人民关心的一个热点话题,本案例AI垃圾分类应用是基于MindSpore框架预置的图像分类预训练模型,用垃圾分类数据集在PC上进行迁移学习;然后用迁移学习训练好的模型基于MindSpore Lite部署到个人手机上实现垃圾分类应用,将分类图片进一步分类成可回收物、干垃圾、有害垃圾和湿垃圾四大类。

通过这个简单示例分享,大家能快速了解如何利用MindSpore框架进行端侧个性化应用开发的端对端过程。具体介绍前,我们先看看用MindSpore框架实现端侧垃圾分类个性化应用的效果。可以看到,我们把识别出的纸张分成可回收物、一次性筷子分成干垃圾、易拉罐分成可回收物、蛋壳分成湿垃圾, 这个应用可以辅助我们日常的垃圾分类。

端侧垃圾分类结果图

PC侧迁移学习步骤

在深度学习中,大部分任务的数据和网络模型规模较大,训练网络模型时,如果不使用预训练模型,从头开始训练,需要消耗大量的时间。因此,大部分任务都会选择预训练模型,在其上做微调。本端侧垃圾分类应用是基于MobileNetV2预训练模型,在其基础上进行迁移学习。

环境安装

1. 安装MindSpore框架

为方便大家快速上手了解,本用例选用Windows系统的MindSpore版本。当然,用户可根据个人系统和处理器架构安装对应版本MindSpore框架。

2. 下载代码

在Gitee中克隆MindSpore开源项目仓库,进入

./model_zoo/official/cv/mobilenetv2/

预训练模型准备

用户可以从CPU预训练模型下载提前训练好的预训练模型, 当然也可以参照model zoo自行训练预训练模型。为快速进行迁移学习,本demo直接选用提前训练好的预训练模型。

数据集准备

MobileNetV2的代码默认使用ImageFolder格式管理数据集,每一类图片整理成单独的一个文件夹, 数据集结构如下:

  1. └─ImageFolder
  2. ├─train
  3. class1Folder
  4. class2Folder
  5. │ ......
  6. └─eval
  7. class1Folder
  8. class2Folder
  9. ......

也可以参考MindSpore支持图像领域常用的数据集, 修改src/dataset.py

换成适合的数据集接口。

src/dataset.py中默认使用随机裁剪,随机水平翻转,随机色彩增强,正则化等基本的数据增强方法, 也可以参考数据增强添加新的数据处理算子。

模型训练和推理

一般一个分类网络包含两部分:backbone和head,其中backbone部分通常是一系列卷积层,负责提取图片特征;head部分通常是一组全连接层,用于分类,一般最后一个全连接层的输出对应数据集的分类数。同数据集和任务中特征提取层(卷积层)分布趋于一致,但是特征向量的组合(全连接层)不相同,分类数量(全连接层output_size)通常也不一致。这里我们选择冻结backbone部分的参数,只训练head的参数进行微调。

定义网络模型及加载预训练参数

首先按照代码第1行,构建MobileNetV2的backbone网络,head网络,并且构建包含这两个子网络的MobileNetV2网络。代码第3-10行展示了如何定义backbone_net与head_net,以及将两个子网络置入mobilenet_v2中。 代码第12-23行,展示了在微调训练模式下,需要将预训练模型加载入backbone_net子网络,并且冻结backbone_net中的参数,不参与训练。 代码第21-23行展示了如何冻结网络参数。

  1. backbone_net, head_net, net = define_net(args_opt, config)
  2. ...
  3. def define_net(config, is_training):
  4. backbone_net = MobileNetV2Backbone()
  5. activation = config.activation if not is_training else "None"
  6. head_net = MobileNetV2Head(input_channel=backbone_net.out_channels,num_classes=config.num_classes,activation=activation)
  7. net = mobilenet_v2(backbone_net, head_net)
  8. return backbone_net, head_net, net
  9. ...
  10. if args_opt.pretrain_ckpt and args_opt.freeze_layer == "backbone":
  11. load_ckpt(backbone_net, args_opt.pretrain_ckpt, trainable=False)
  12. ...
  13. def load_ckpt(network, pretrain_ckpt_path, trainable=True):
  14. """
  15. train the param weight or not
  16. """
  17. param_dict = load_checkpoint(pretrain_ckpt_path)
  18. load_param_into_net(network, param_dict)
  19. if not trainable:
  20. for param in network.get_parameters():
  21. param.requires_grad = False

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/222333
推荐阅读
相关标签
  

闽ICP备14008679号