赞
踩
今天主要的学习内容是MindSpore的快速入门,主要是跟着教程过了一遍图像分类问题的流程,主要包含数据获取和处理、网络构建、模型训练、保存模型和加载模型、预测推理五个流程,详细内容可以参考官方文档,这里只说一下我的心得。
这里主要使用download模块获取数据集,并使用MindSpore自带的mindspore.dataset.MnistDataset获取数据对象并处理,如果之前学过Pytorch或Tensorflow的应该可以很快上手这部分内容。
这里就是使用几个全连接层和Relu激活函数搭建的一个MLP网络,其搭建流程个人感觉和Pytorch差不多,其中construct函数类似于forword函数。
这里首先定义了损失函数、优化器、训练流程、测试流程等,如果学过神经网络相关的基础知识,总体上理解这部分的流程还是比较容易的,这里附上我的训练效果图
保存模型一行代码的事,没什么好说的,保存格式为.ckpt
加载模型时是先使用load_checkpoint函数导入模型,再使用load_param_into_net函数将模型参数导入到网络中,该函数返回值包括未被加载的参数列表,可以看到哪些参数未被加载,这是一个很实用的功能。
使用加载模型的网络对测试集中的图像进行预测得到最终结果
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。