赞
踩
。
以下是一些关键步骤和考虑因素,用于基于YOLOv8进行动物姿态估计识别的训练与代码实现:
数据集构建:首先,你需要一个包含标注好的动物姿态数据集,比如COCO、voc等,其中包含了动物图像及其对应的关节点坐标。
数据预处理:将数据集按照YOLOv8格式进行整理,确保图像尺寸适配模型要求,同时关键点数据转换为YOLOv8可接受的形式。
模型加载与配置:根据动物姿态识别任务调整模型配置文件(如.yaml
文件),设置合适的anchor boxes、输入尺寸和输出层结构以适应姿态估计的需求。
关键点检测模块:YOLOv8在设计时可能整合了关键点检测的头结构,用于预测每个目标框内的关节点坐标。
训练流程:使用类似如下命令启动训练过程:
python train.py --model yolov8n-pose --data your_dataset.yaml --hyp hyp.finetune.yaml --weights yolov8_base_weights.pt --epochs 200 --batch-size 16
其中,your_dataset.yaml
是您的自定义数据集配置文件,yolov8_base_weights.pt
是YOLOv8的基础权重,hyp.finetune.yaml
是超参数配置文件。
训练脚本:参照ultralytics提供的训练脚本,修改和添加必要的代码以支持关键点的损失函数计算和后处理。
评估与验证:训练过程中会定期在验证集上评估模型的表现,检查关键点预测的准确性,可以利用内置的评价指标如平均绝对误差(MAE)或PCKh(Percentage of Correct Keypoints)等。
测试与推理:训练完成后,使用YOLOv8的推理脚本来加载训练好的模型,对新的未标注动物图像进行姿态估计。
为了获得详细的代码指导和训练教程,您可以查阅ultralytics团队发布的官方文档、GitHub仓库中的说明文件以及CSDN技术社区、阿里云等平台上的相关教程和技术博客。随着技术的发展,具体的训练代码和教程可能会有所更新,请务必查阅最新的官方指南和社区资源。
PR曲线(Precision-Recall Curve)和F1曲线虽然名称不常见,但在实际应用中,我们更多是指的是Precision-Recall曲线和F1分数(F1-score)的关系,而不是一个单独的“F1曲线”。
Precision-Recall Curve(PR曲线):
PR曲线是用来衡量二分类或多分类模型性能的一种可视化工具。它展示了模型在不同阈值设定下,Precision(查准率)和Recall(查全率)之间的权衡关系。Precision表示模型预测为正类的样本中有多少确实是正类,而Recall表示所有真实的正类样本中有多少被模型成功找了出来。
在PR曲线上,横坐标是Recall,纵坐标是Precision。当Recall增大时,Precision可能会下降,这是因为放宽预测条件会导致更多的样本被标记为正类,其中包括了真正的正类和假正类(False Positives)。理想情况下,希望模型在保持高Recall的同时,Precision也尽可能高。
在分析PR曲线时,人们常寻找能使F1-score最大化的阈值点,该点对应的Precision和Recall组合被认为是模型的最佳性能点。虽然PR曲线本身不是“F1曲线”,但是在分析PR曲线的过程中,我们常常会关注F1-score在各个阈值下的表现,尤其是在实践中寻找最合适的模型决策阈值时。而在某些场合下,为了综合考量模型性能,研究者可能会通过计算PR曲线下面积(Average Precision, AP)来得到一个整体评估值。
F1-score:
F1分数是Precision和Recall的一个综合度量,它是一个调和平均数,旨在给出一个单一数值来反映模型在这两个指标上的表现。F1-score的计算公式如下:
[ F1 = 2 \cdot \frac{Precision \times Recall}{Precision + Recall} ]
F1-score最大值为1,最小值为0。当Precision和Recall都很高的时候,F1-score也会很高。
def prepare_predictions(
image_dir_path,
image_filename,
model,
BOX_IOU_THRESH = 0.55,
BOX_CONF_THRESH=0.30,
KPT_CONF_THRESH=0.68):
image_path = os.path.join(image_dir_path, image_filename)
image = cv2.imread(image_path).copy()
results = model.predict(image_path, conf=BOX_CONF_THRESH, iou=BOX_IOU_THRESH)[0].cpu()
if not len(results.boxes.xyxy):
return image
# Get the predicted boxes, conf scores and keypoints.
pred_boxes = results.boxes.xyxy.numpy()
pred_box_conf = results.boxes.conf.numpy()
pred_kpts_xy = results.keypoints.xy.numpy()
pred_kpts_conf = results.keypoints.conf.numpy()
# Draw predicted bounding boxes, conf scores and keypoints on image.
for boxes, score, kpts, confs in zip(pred_boxes, pred_box_conf, pred_kpts_xy, pred_kpts_conf):
kpts_ids = np.where(confs > KPT_CONF_THRESH)[0]
filter_kpts = kpts[kpts_ids]
filter_kpts = np.concatenate([filter_kpts, np.expand_dims(kpts_ids, axis=-1)], axis=-1)
image = draw_boxes(image, boxes, score=score)
image = draw_landmarks(image, filter_kpts)
return image
@dataclass(frozen=True)
class TrainingConfig:
DATASET_YAML: str = "animal-keypoints.yaml"
MODEL: str = "yolov8m-pose.pt"
EPOCHS: int = 100
KPT_SHAPE: tuple = (24,3)
PROJECT: str = "Animal_Keypoints"
NAME: str = f"{MODEL.split('.')[0]}_{EPOCHS}_epochs"
CLASSES_DICT: dict = field(default_factory = lambda:{0 : "dog"})
DATA_DIR = "animal-pose-data"
TRAIN_DIR = f"train"
TRAIN_FOLDER_IMG = f"images"
TRAIN_FOLDER_LABELS = f"labels"
TRAIN_IMG_PATH = os.path.join(DATA_DIR, TRAIN_DIR, TRAIN_FOLDER_IMG)
TRAIN_LABEL_PATH = os.path.join(DATA_DIR, TRAIN_DIR, TRAIN_FOLDER_LABELS)
VALID_DIR = f"valid"
VALID_FOLDER_IMG = f"images"
VALID_FOLDER_LABELS = f"labels"
CLASS_ID = 0
# create_yolo_txt_files(train_json_data, TRAIN_LABEL_PATH)
# create_yolo_txt_files(val_json_data, VALID_LABEL_PATH)
ann_meta_data = pd.read_csv("keypoint_definitions.csv")
COLORS = ann_meta_data["Hex colour"].values.tolist()
COLORS_RGB_MAP = []
VALID_IMG_PATH = os.path.join(DATA_DIR, VALID_DIR, VALID_FOLDER_IMG)
for color in COLORS:
R, G, B = int(color[:2], 16), int(color[2:4], 16), int(color[4:], 16)
COLORS_RGB_MAP.append({color: (R,G,B)})
VAL_IMAGE_FILES = os.listdir(VALID_IMG_PATH)
num_samples = 9
num_rows = 3
num_cols = num_samples//num_rows
fig, ax = plt.subplots(
nrows=num_rows,
ncols=num_cols,
figsize=(25, 15),
)
random.seed(90)
random.shuffle(VAL_IMAGE_FILES)
train_config = TrainingConfig()
ckpt_path = os.path.join(train_config.PROJECT, train_config.NAME, "weights", "best.pt")
print("ckpt_pth",ckpt_path)
model_pose = YOLO(ckpt_path)
for idx, (file, axis) in enumerate(zip(VAL_IMAGE_FILES[:num_samples], ax.flat)):
image_pred = prepare_predictions(VALID_IMG_PATH, file, model_pose)
axis.imshow(image_pred[...,::-1])
axis.axis("off")
plt.tight_layout(h_pad=4., w_pad=4.)
plt.show();
#code全部代码:qq1309399183
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。