当前位置:   article > 正文

paddlepaddle(四)训练与预测验证_model.prepare

model.prepare

目录

1.内置高级API封装训练

1.1使用paddle.Model()封装模型

1.2用Model.prepare()配置模型

1.3用Model.fit()训练模型

1.4用Model.evaluate()评估模型

1.5用Model.predict()预测模型

2.基础API实现训练与预测

2.1初始化优化器、模型、超参数

 2.2训练

2.3用基础API验证模型

2.5预测


1.内置高级API封装训练

1.1使用paddle.Model()封装模型

  • network (paddle.nn.Layer): 是 paddle.nn.Layer 的一个实例

  • inputs (InputSpec|list|dict|None, 可选): network 的输入,可以是 InputSpec 的实例,或者是一个 InputSpec 的 list ,或者是格式为 {name: InputSpec} 的 dict ,或者为 None . 默认值为 None .

  • labels (InputSpec|list|None, 可选): network 的标签,可以是 InputSpec 的实例,或者是一个 InputSpec 的 list ,或者为 None. 默认值为 None .

  1. # 定义网络结构(采用 Sequential组网方式 )
  2. mnist = paddle.nn.Sequential(
  3. paddle.nn.Flatten(1, -1),
  4. paddle.nn.Linear(784, 512),
  5. paddle.nn.ReLU(),
  6. paddle.nn.Dropout(0.2),
  7. paddle.nn.Linear(512, 10)
  8. )
  9. model = paddle.Model(mnist)

1.2用Model.prepare()配置模型

  1. # 为模型训练做准备,设置优化器,损失函数和精度计算方式
  2. model.prepare(optimizer=paddle.optimizer.Adam(parameters=model.parameters()),
  3. loss=paddle.nn.CrossEntropyLoss(),
  4. metrics=paddle.metric.Accuracy())

1.3用Model.fit()训练模型

  1. # 启动模型训练,指定训练数据集,设置训练轮次,设置每次数据集计算的批次大小,设置日志格式
  2. model.fit(train_dataset,
  3. epochs=5,
  4. batch_size=64,
  5. verbose=1)
  6. #####
  7. Epoch 2/5
  8. step 938/938 [==============================] - loss: 0.0555 - acc: 0.9689 - 22ms/step
  9. Epoch 3/5
  10. step 938/938 [==============================] - loss: 0.0487 - acc: 0.9781 - 22ms/step
  11. Epoch 4/5
  12. step 938/938 [==============================] - loss: 0.0061 - acc: 0.9837 - 22ms/step
  13. Epoch 5/5
  14. step 938/938 [==============================] - loss: 0.0900 - acc: 0.9866 - 22ms/step

1.4用Model.evaluate()评估模型

对于训练好的模型进行评估可以使用evaluate接口,事先定义好用于评估使用的数据集后,直接调用evaluate接口即可完成模型评估操作,结束后根据在preparelossmetric的定义来进行相关评估结果计算返回。

  1. # 用 evaluate 在测试集上对模型进行验证
  2. eval_result = model.evaluate(test_dataset, verbose=1)
  3. ####
  4. Eval begin...
  5. step 10000/10000 [==============================] - loss: 3.5763e-07 - acc: 0.9794 - 2ms/step
  6. Eval samples: 10000

1.5用Model.predict()预测模型

  1. # 用 predict 在测试集上对模型进行测试
  2. test_result = model.predict(test_dataset)
  3. print(test_result[0:10])
  4. #####
  5. Predict begin...
  6. step 10000/10000 [==============================] - 2ms/step
  7. Predict samples: 10000
  8. [(array([[ -4.894136 , -5.6600327 , -3.0495958 , 2.1171691 ,
  9. -11.302892 , -5.483228 , -14.45905 , 11.745611 ,
  10. -3.6458166 , -0.17910889]], dtype=float32),
  11. array([[ -5.714625 , 4.4638777, 13.364223 , 2.4783947, -20.74056 ,
  12. -0.6489969, -4.878751 , -15.684789 , -1.269482 , -15.168117 ]],
  13. dtype=float32),
  14. array([[-7.527398 , 8.128369 , -3.2447886 , -5.72277 , -1.6012611 ,
  15. -4.228486 , -2.8881369 , -0.27592087, -1.8160758 , -8.265086 ]],
  16. dtype=float32),
  17. array([[ 9.048569 , -7.1954203, -2.1176505, -3.7064955, -7.256583 ,
  18. -1.7620262, 0.7889428, -2.5328965, -6.511542 , -1.1649128]],
  19. dtype=float32)]

2.基础API实现训练与预测

  1. # 定义网络结构( 采用SubClass 组网 )
  2. class Mnist(paddle.nn.Layer):
  3. def __init__(self):
  4. super(Mnist, self).__init__()
  5. self.flatten = paddle.nn.Flatten()
  6. self.linear_1 = paddle.nn.Linear(784, 512)
  7. self.linear_2 = paddle.nn.Linear(512, 10)
  8. self.relu = paddle.nn.ReLU()
  9. self.dropout = paddle.nn.Dropout(0.2)
  10. def forward(self, inputs):
  11. y = self.flatten(inputs)
  12. y = self.linear_1(y)
  13. y = self.relu(y)
  14. y = self.dropout(y)
  15. y = self.linear_2(y)
  16. return y

 

2.1初始化优化器、模型、超参数

  1. # 用 DataLoader 实现数据加载
  2. train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)
  3. #初始化模型,设为train模式
  4. mnist=Mnist()
  5. mnist.train()
  1. # 设置迭代次数
  2. epochs = 5
  3. # 设置优化器
  4. optim = paddle.optimizer.Adam(parameters=mnist.parameters())
  5. # 设置损失函数
  6. loss_fn = paddle.nn.CrossEntropyLoss()

 2.2训练

核心步骤

  1. acc = paddle.metric.accuracy(predicts, y_data)
  2. loss.backward()
  3. optim.step()
  4. optim.clear_grad()
  1. for epoch in range(epochs):
  2. for batch_id, data in enumerate(train_loader()):
  3. x_data = data[0] # 训练数据
  4. y_data = data[1] # 训练数据标签
  5. predicts = mnist(x_data) # 预测结果
  6. # 计算损失 等价于 prepare 中loss的设置
  7. loss = loss_fn(predicts, y_data)
  8. # 计算准确率 等价于 prepare 中metrics的设置
  9. acc = paddle.metric.accuracy(predicts, y_data)
  10. # 下面的反向传播、打印训练信息、更新参数、梯度清零都被封装到 Model.fit() 中
  11. # 反向传播
  12. loss.backward()
  13. if (batch_id+1) % 900 == 0:
  14. print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id+1, loss.numpy(), acc.numpy()))
  15. # 更新参数
  16. optim.step()
  17. # 梯度清零
  18. optim.clear_grad()
  19. #####
  20. epoch: 0, batch_id: 900, loss is: [0.13427278], acc is: [0.96875]
  21. ...
  22. epoch: 4, batch_id: 900, loss is: [0.05088746], acc is: [0.96875]

2.3用基础API验证模型

初始化数据和评估函数

  1. # 加载测试数据集
  2. test_loader = paddle.io.DataLoader(test_dataset, batch_size=64, drop_last=True)
  3. loss_fn = paddle.nn.CrossEntropyLoss()
  4. mnist.eval()

评估 

  1. for batch_id, data in enumerate(test_loader()):
  2. x_data = data[0] # 测试数据
  3. y_data = data[1] # 测试数据标签
  4. predicts = mnist(x_data) # 预测结果
  5. # 计算损失与精度
  6. loss = loss_fn(predicts, y_data)
  7. acc = paddle.metric.accuracy(predicts, y_data)
  8. # 打印信息
  9. if (batch_id+1) % 30 == 0:
  10. print("batch_id: {}, loss is: {}, acc is: {}".format(batch_id+1, loss.numpy(), acc.numpy()))
  11. #######
  12. batch_id: 30, loss is: [0.13358988], acc is: [0.96875]
  13. batch_id: 60, loss is: [0.17515801], acc is: [0.9375]
  14. batch_id: 90, loss is: [0.05868918], acc is: [0.96875]
  15. batch_id: 120, loss is: [0.00206844], acc is: [1.]
  16. batch_id: 150, loss is: [0.07185207], acc is: [0.984375]

2.5预测

  1. # 加载测试数据集
  2. test_loader = paddle.io.DataLoader(test_dataset, batch_size=64, drop_last=True)
  3. mnist.eval()
  4. for batch_id, data in enumerate(test_loader()):
  5. x_data = data[0]
  6. predicts = mnist(x_data)
  7. # 获取预测结果
  8. print("predict finished")

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/136795?site
推荐阅读
相关标签
  

闽ICP备14008679号