赞
踩
使用在某个较大的数据集训练好的预训练模型,即被预置了参数的权重,可以帮助模型在新的数据集上更快收敛。尤其是对一些训练数据比较稀缺的任务,在神经网络参数十分庞大的情况下,仅仅依靠任务自身的训练数据可能无法训练充分,加载预训练模型的方法可以认为是让模型基于一个更好的初始状态进行学习,从而能够达到更好的性能。
由于使用 CPU 来进行模型训练,计算速度较慢,因此,此处以 ShuffleNetV2_x0_25
为例。此模型计算量较小,在 CPU 上计算速度较快。但是也因为模型较小,训练好的模型精度也不会太高。
# windows 在 cmd 中进入 PaddleClas 根目录,执行此命令
python tools/train.py \
-c ./ppcls/configs/quick_start/new_user/ShuffleNetV2_x0_25.yaml
-c
参数是指定训练的配置文件路径,训练的具体超参数可查看yaml
文件yaml
文件中Global.device
参数设置为cpu
,即使用 CPU 进行训练(若不设置,此参数默认为True
)yaml
文件中epochs
参数设置为 20,说明对整个数据集进行 20 个 epoch 迭代,预计训练 20 分钟左右(不同 CPU,训练时间略有不同),此时训练模型不充分。若提高训练模型精度,请将此参数设大,如 40,训练时间也会相应延长。python tools/train.py \
-c ./ppcls/configs/quick_start/new_user/ShuffleNetV2_x0_25.yaml
-o Arch.pretrained=True
-o
参数可以选择为 True 或 False,也可以是预训练模型存放路径,当选择为 True 时,预训练权重会自动下载到本地。注意:若为预训练模型路径,则不要加上:.pdparams
可以将使用与不使用预训练模型训练进行对比,观察 loss
的下降情况。
由于 GPU 训练速度更快,可以使用更复杂模型,因此以 ResNet50_vd
为例。与 ShuffleNetV2_x0_25
相比,此模型计算量较大, 训练好的模型精度也会更高。
首先要设置环境变量,使用 0 号 GPU 进行训练
linux 或者 mac
export CUDA_VISIBLE_DEVICES=0
windows
set CUDA_VISIBLE_DEVICES=0
python3 tools/train.py \
-c ./ppcls/configs/quick_start/ResNet50_vd.yaml
训练完成后,验证集的 Top1 Acc
曲线如下所示,最高准确率为0.2735。训练精度曲线下图所示
基于 ImageNet1k
分类预训练模型进行微调,训练脚本如下所示
python3 tools/train.py \
-c ./ppcls/configs/quick_start/ResNet50_vd.yaml \
-o Arch.pretrained=True
验证集的 Top1 Acc
曲线如下所示,最高准确率为 0.9402,加载预训练模型之后,flowers102 数据集精度大幅提升,绝对精度涨幅超过 65%。
如果训练任务因为其他原因被终止,也可以加载断点权重文件,继续训练:
python3 tools/train.py \
-c ./ppcls/configs/quick_start/MobileNetV1_retrieval.yaml \
-o Global.checkpoints="./output/RecModel/epoch_5" \
-o Global.device=gpu
其中配置文件不需要做任何修改,只需要在继续训练时设置 Global.checkpoints
参数即可,表示加载的断点权重文件路径,使用该参数会同时加载保存的断点权重和学习率、优化器等信息。
注意:
-o Global.checkpoints
参数无需包含断点权重文件的后缀名,上述训练命令会在训练过程中生成如下所示的断点权重文件,若想从断点 5
继续训练,则 Global.checkpoints
参数只需设置为 "./output/RecModel/epoch_5"
,PaddleClas 会自动补充后缀名。
output/
└── RecModel
├── best_model.pdopt
├── best_model.pdparams
├── best_model.pdstates
├── epoch_1.pdopt
├── epoch_1.pdparams
├── epoch_1.pdstates
.
.
.
可以通过以下命令进行模型评估。
python3 tools/eval.py \
-c ./ppcls/configs/quick_start/MobileNetV1_retrieval.yaml \
-o Global.pretrained_model=./output/RecModel/best_model
上述命令将使用 ./configs/quick_start/MobileNetV1_retrieval.yaml
作为配置文件,对上述训练得到的模型 ./output/RecModel/best_model
进行评估。你也可以通过更改配置文件中的参数来设置评估,也可以通过 -o
参数更新配置,如上所示。
可配置的部分评估参数说明如下:
Arch.name
:模型名称Global.pretrained_model
:待评估的模型的预训练模型文件路径,不同于Global.Backbone.pretrained
,此处的预训练模型是整个模型的权重,而Global.Backbone.pretrained
只是Backbone部分的权重。当需要做模型评估时,需要加载整个模型的权重。Metric.Eval
:待评估的指标,默认评估recall@1、recall@5、mAP。当你不准备评测某一项指标时,可以将对应的试标从配置文件中删除;当你想增加某一项评测指标时,也可以参考Metric部分在配置文件Metric.Eval
中添加相关的指标。注意:
.pdparams
的后缀。如果机器环境为 Linux+GPU,那么推荐使用paddle.distributed.launch
启动模型训练脚本(tools/train.py
)、评估脚本(tools/eval.py
),可以更方便地启动多卡训练与评估。
# PaddleClas 通过 launch 方式启动多卡多进程训练
export CUDA_VISIBLE_DEVICES=0,1,2,3
python3 -m paddle.distributed.launch \
--gpus="0,1,2,3" \
tools/train.py \
-c ./ppcls/configs/quick_start/MobileNetV3_large_x1_0.yaml
根据自己的数据集配置好配置文件之后,可以加载预训练模型进行微调,如下所示。
export CUDA_VISIBLE_DEVICES=0,1,2,3
python3 -m paddle.distributed.launch \
--gpus="0,1,2,3" \
tools/train.py \
-c ./ppcls/configs/quick_start/MobileNetV3_large_x1_0.yaml \
-o Arch.pretrained=True
其中Arch.pretrained
为True
或False
,当然也可以设置加载预训练权重文件的路径,使用时需要换成自己的预训练模型权重文件路径,也可以直接在配置文件中修改该路径。
如果训练任务因为其他原因被终止,也可以加载断点权重文件继续训练。
export CUDA_VISIBLE_DEVICES=0,1,2,3
python3 -m paddle.distributed.launch \
--gpus="0,1,2,3" \
tools/train.py \
-c ./ppcls/configs/quick_start/MobileNetV3_large_x1_0.yaml \
-o Global.checkpoints="./output/MobileNetV3_large_x1_0/epoch_5" \
-o Global.device=gpu
其中配置文件不需要做任何修改,只需要在训练时设置Global.checkpoints
参数即可,该参数表示加载的断点权重文件路径,使用该参数会同时加载保存的模型参数权重和学习率、优化器等信息。
可以通过以下命令进行模型评估。
export CUDA_VISIBLE_DEVICES=0,1,2,3
python3 -m paddle.distributed.launch \
tools/eval.py \
-c ./ppcls/configs/quick_start/MobileNetV3_large_x1_0.yaml \
-o Global.pretrained_model=./output/MobileNetV3_large_x1_0/best_model
训练完成后预测代码如下
cd $path_to_PaddleClas
# `Infer.infer_imgs`:待预测的图片文件路径或者批量预测时的图片文件夹
# `Global.pretrained_model`:模型权重文件路径
python3 tools/infer.py \
-c ./ppcls/configs/quick_start/new_user/ShuffleNetV2_x0_25.yaml \
-o Infer.infer_imgs=dataset/flowers102/jpg/image_00001.jpg \
-o Global.pretrained_model=output/ShuffleNetV2_x0_25/best_model
-i
输入为单张图像路径,运行成功后,示例结果如下:
[{'class_ids': [76, 65, 34, 9, 69], 'scores': [0.91762, 0.01801, 0.00833, 0.0071, 0.00669], 'file_name': 'dataset/flowers102/jpg/image_00001.jpg', 'label_names': []}]
-i
输入为图像集所在目录,运行成功后,示例结果如下:
[{'class_ids': [76, 65, 34, 9, 69], 'scores': [0.91762, 0.01801, 0.00833, 0.0071, 0.00669], 'file_name': 'dataset/flowers102/jpg/image_00001.jpg', 'label_names': []}, {'class_ids': [76, 69, 34, 28, 9], 'scores': [0.77122, 0.06295, 0.02537, 0.02531, 0.0251], 'file_name': 'dataset/flowers102/jpg/image_00002.jpg', 'label_names': []}, {'class_ids': [99, 76, 81, 85, 16], 'scores': [0.26374, 0.20423, 0.07818, 0.06042, 0.05499], 'file_name': 'dataset/flowers102/jpg/image_00003.jpg', 'label_names': []}, {'class_ids': [9, 37, 34, 24, 76], 'scores': [0.17784, 0.16651, 0.14539, 0.12096, 0.04816], 'file_name': 'dataset/flowers102/jpg/image_00004.jpg', 'label_names': []}, {'class_ids': [76, 66, 91, 16, 13], 'scores': [0.95494, 0.00688, 0.00596, 0.00352, 0.00308], 'file_name': 'dataset/flowers102/jpg/image_00005.jpg', 'label_names': []}, {'class_ids': [76, 66, 34, 8, 43], 'scores': [0.44425, 0.07487, 0.05609, 0.05609, 0.03667], 'file_name': 'dataset/flowers102/jpg/image_00006.jpg', 'label_names': []}, {'class_ids': [86, 93, 81, 22, 21], 'scores': [0.44714, 0.13582, 0.07997, 0.0514, 0.03497], 'file_name': 'dataset/flowers102/jpg/image_00007.jpg', 'label_names': []}, {'class_ids': [13, 76, 81, 18, 97], 'scores': [0.26771, 0.1734, 0.06576, 0.0451, 0.03986], 'file_name': 'dataset/flowers102/jpg/image_00008.jpg', 'label_names': []}, {'class_ids': [34, 76, 8, 5, 9], 'scores': [0.67224, 0.31896, 0.00241, 0.00227, 0.00102], 'file_name': 'dataset/flowers102/jpg/image_00009.jpg', 'label_names': []}, {'class_ids': [76, 34, 69, 65, 66], 'scores': [0.95185, 0.01101, 0.00875, 0.00452, 0.00406], 'file_name': 'dataset/flowers102/jpg/image_00010.jpg', 'label_names': []}]
其中,列表的长度为 batch_size 的大小。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。