赞
踩
图像分类器。
学习资源:https://www.youtube.com/watch?v=Z-65nqxUdl4
@努力的小巴掌 记录计算机视觉学习道路上的所思所得。
划分数据集:train,val,test
知道怎么划分数据集很重要。
文件夹下面有不同类别的图片。
train
-----dog
-----cat
val
-----dog
-----cat
test
-----dog
-----cat
方法1:
在python写脚本
首先,确保自己已经安装了ultralytics和numpy。
可以直接创建requirements.txt文件,写上这个:
ultralytics==8.0.58
numpy==1.24.2
然后pip install requirements.txt
参考官网给的文档:
Classify - Ultralytics YOLO Docs
创建main.py
from ultralytics import YOLO
# Load a model
# model = YOLO("yolov8n-cls.yaml") # build a new model from YAML
model = YOLO("yolov8n-cls.pt") # load a pretrained model (recommended for training)
# model = YOLO("yolov8n-cls.yaml").load("yolov8n-cls.pt") # build from YAML and transfer weights
# Train the model
results = model.train(data="数据集的的绝对路径", epochs=1, imgsz=64)
在本地运行时候,只是为了看看train.py能不能正常运行,所以,epocha设置成1;
data="数据集的的绝对路径",这里是放所有图片的那个总文件夹,就是train/val/test上面一级的,然后注意一定是绝对路径。
方法2
命令行
yolo classify train data='绝对路径' model=yolov8n-cls.pt epochs=1 imgsz=64
结果保存在runs/classify下
结果有3个,
weights:best.pt和last.pt 模型文件
args.yaml: 类似于配置文件,列出了我们训练时候的所有参数
results.csv:所有epochs的训练结果
其中我们重点关注,loss和accuracy。
我们要保证其损失是一直下降的。
数字不好看,我们用每个epoch的loss值画一个图像,可以直观的看。
创建画图脚本plot_metrics.py
代码:
- import os
- import pandas as pd
- import matplotlib.pyplot as plt
-
-
- results_path = './runs/classify/train14/results.csv'
-
- results = pd.read_csv(results_path)
-
- plt.figure()
- plt.plot(results[' epoch'], results[' train/loss'], label='train loss')
- plt.plot(results[' epoch'], results[' val/loss'], label='val loss', c='red')
- plt.grid()
- plt.title('Loss vs epochs')
- plt.ylabel('loss')
- plt.xlabel('epochs')
- plt.legend()
-
-
- plt.figure()
- plt.plot(results[' epoch'], results[' metrics/accuracy_top1'] * 100)
- plt.grid()
- plt.title('Validation accuracy vs epochs')
- plt.ylabel('accuracy (%)')
- plt.xlabel('epochs')
-
- plt.show()
结果类似于:
创建predict.py
from ultralytics import YOLO
# Load a model
model = YOLO("path/to/best.pt") # load a custom model
# Predict with the model
results = model("图片位置") # predict on an image
names_dict = results[0].names
probs = results[0].probs.tolist()
print(names_dict)
print(probs)
print(names_dict[np.argmax(probs)])
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。