当前位置:   article > 正文

昇思MindSpore 25天学习打卡营|day7

昇思MindSpore 25天学习打卡营|day7

模型训练

模型训练的步骤:

1. 构建数据集

2. 定义神经网络模型

3. 定义超参数、损失函数及优化器

4. 输入数据集进行训练与评估

构建数据集

首先从数据集 Dataset加载代码,构建数据集。

  1. import mindspore
  2. from mindspore import nn
  3. from mindspore.dataset import vision, transforms
  4. from mindspore.dataset import MnistDataset
  5. # Download data from open datasets
  6. from download import download
  7. url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
  8. "notebook/datasets/MNIST_Data.zip"
  9. path = download(url, "./", kind="zip", replace=True)
  10. def datapipe(path, batch_size):
  11. image_transforms = [
  12. vision.Rescale(1.0 / 255.0, 0),
  13. vision.Normalize(mean=(0.1307,), std=(0.3081,)),
  14. vision.HWC2CHW()
  15. ]
  16. label_transform = transforms.TypeCast(mindspore.int32)
  17. dataset = MnistDataset(path)
  18. # map操作是数据预处理的关键操作,可以针对数据集指定列(column)添加数据变换(Transforms),将数据变换应用于该列数据的每个元素,并返回包含变换后元素的新数据集。
  19. dataset = dataset.map(image_transforms, 'image')
  20. dataset = dataset.map(label_transform, 'label')
  21. dataset = dataset.batch(batch_size)
  22. return dataset
  23. train_dataset = datapipe('MNIST_Data/train', batch_size=64)
  24. test_dataset = datapipe('MNIST_Data/test', batch_size=64)
file_sizes: 100%|███████████████████████████| 10.8M/10.8M [00:00<00:00, 148MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./

定义神经网络模型

从 06-网络构建中 加载代码,构建一个神经网络模型

  1. class Network(nn.Cell):
  2. def __init__(self):
  3. super().__init__()
  4. self.flatten = nn.Flatten()
  5. self.dense_relu_sequential = nn.SequentialCell(
  6. nn.Dense(28*28, 512),
  7. nn.ReLU(),
  8. nn.Dense(512, 512),
  9. nn.ReLU(),
  10. nn.Dense(512, 10)
  11. )
  12. def construct(self, x):
  13. x = self.flatten(x)
  14. logits = self.dense_relu_sequential(x)
  15. return logits
  16. model = Network()

定义超参、损失函数及优化器

超参

超参(Hyperparameters)是可以调整的参数,可以控制模型训练优化的过程,不同的超参数值可能会影响模型训练和收敛速度。目前深度学习模型多采用批量随机梯度下降算法进行优化,随机梯度下降算法的原理如下:

                        ​​​​​​​        ​​​​​​​        ​​​​​​​        w_{t+1}=w_{t}-\eta \frac{1}{n} \sum_{x \in \mathcal{B}} \nabla l\left(x, w_{t}\right)

公式中,

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/在线问答5/article/detail/844446
推荐阅读
相关标签