当前位置:   article > 正文

基于cnn卷积神经网络的yolov8动物姿态估计识别(训练+代码)_姿态估计神经网络代码

姿态估计神经网络代码

往期热门博客项目回顾:

计算机视觉项目大集合

改进的yolo目标检测-测距测速

路径规划算法

图像去雨去雾+目标检测+测距项目

交通标志识别项目

yolo系列-重磅yolov9界面-最新的yolo

姿态识别-3d姿态识别

深度学习小白学习路线

  • 基于CNN(卷积神经网络)的YOLOv8模型在动物姿态估计识别方面是一种有效的解决方案。
  • YOLO(You Only Look
    Once)系列模型因其在实时目标检测中的高效性能而广受欢迎,YOLOv8是在YOLO家族中的一个更新版本,继承并改进了其前身YOLOv5的优点,增强了对于姿态估计任务的支持


在这里插入图片描述

以下是一些关键步骤和考虑因素,用于基于YOLOv8进行动物姿态估计识别的训练与代码实现:

数据准备:

  1. 数据集构建:首先,你需要一个包含标注好的动物姿态数据集,比如COCO、voc等,其中包含了动物图像及其对应的关节点坐标。

  2. 数据预处理:将数据集按照YOLOv8格式进行整理,确保图像尺寸适配模型要求,同时关键点数据转换为YOLOv8可接受的形式。
    在这里插入图片描述

模型配置与训练:

  1. 模型加载与配置:根据动物姿态识别任务调整模型配置文件(如.yaml文件),设置合适的anchor boxes、输入尺寸和输出层结构以适应姿态估计的需求。

  2. 关键点检测模块:YOLOv8在设计时可能整合了关键点检测的头结构,用于预测每个目标框内的关节点坐标。

  3. 训练流程:使用类似如下命令启动训练过程:

    python train.py --model yolov8n-pose --data your_dataset.yaml --hyp hyp.finetune.yaml --weights yolov8_base_weights.pt --epochs 200 --batch-size 16
    
    • 1

    其中,your_dataset.yaml是您的自定义数据集配置文件,yolov8_base_weights.pt是YOLOv8的基础权重,hyp.finetune.yaml是超参数配置文件。

代码实践:

  1. 训练脚本:参照ultralytics提供的训练脚本,修改和添加必要的代码以支持关键点的损失函数计算和后处理。

  2. 评估与验证:训练过程中会定期在验证集上评估模型的表现,检查关键点预测的准确性,可以利用内置的评价指标如平均绝对误差(MAE)或PCKh(Percentage of Correct Keypoints)等。

  3. 测试与推理:训练完成后,使用YOLOv8的推理脚本来加载训练好的模型,对新的未标注动物图像进行姿态估计。

注意事项:

  • YOLOv8对于姿态估计任务的改进可能包括但不限于:优化关键点预测的头结构、加入额外的损失函数来约束关键点之间的相对关系、以及优化训练策略以提高模型对姿态变化的鲁棒性。
  • 请确保在训练之前熟悉并正确配置所需的硬件资源,如GPU加速,并且合理调整训练参数以达到最佳性能。

为了获得详细的代码指导和训练教程,您可以查阅ultralytics团队发布的官方文档、GitHub仓库中的说明文件以及CSDN技术社区、阿里云等平台上的相关教程和技术博客。随着技术的发展,具体的训练代码和教程可能会有所更新,请务必查阅最新的官方指南和社区资源。

PR曲线

PR曲线(Precision-Recall Curve)和F1曲线虽然名称不常见,但在实际应用中,我们更多是指的是Precision-Recall曲线和F1分数(F1-score)的关系,而不是一个单独的“F1曲线”。

Precision-Recall Curve(PR曲线)
PR曲线是用来衡量二分类或多分类模型性能的一种可视化工具。它展示了模型在不同阈值设定下,Precision(查准率)和Recall(查全率)之间的权衡关系。Precision表示模型预测为正类的样本中有多少确实是正类,而Recall表示所有真实的正类样本中有多少被模型成功找了出来。

在PR曲线上,横坐标是Recall,纵坐标是Precision。当Recall增大时,Precision可能会下降,这是因为放宽预测条件会导致更多的样本被标记为正类,其中包括了真正的正类和假正类(False Positives)。理想情况下,希望模型在保持高Recall的同时,Precision也尽可能高。

在分析PR曲线时,人们常寻找能使F1-score最大化的阈值点,该点对应的Precision和Recall组合被认为是模型的最佳性能点。虽然PR曲线本身不是“F1曲线”,但是在分析PR曲线的过程中,我们常常会关注F1-score在各个阈值下的表现,尤其是在实践中寻找最合适的模型决策阈值时。而在某些场合下,为了综合考量模型性能,研究者可能会通过计算PR曲线下面积(Average Precision, AP)来得到一个整体评估值。
在这里插入图片描述

F1指数

F1-score
F1分数是Precision和Recall的一个综合度量,它是一个调和平均数,旨在给出一个单一数值来反映模型在这两个指标上的表现。F1-score的计算公式如下:

[ F1 = 2 \cdot \frac{Precision \times Recall}{Precision + Recall} ]

F1-score最大值为1,最小值为0。当Precision和Recall都很高的时候,F1-score也会很高。
在这里插入图片描述

代码

def prepare_predictions(
    image_dir_path,
    image_filename,
    model,
    BOX_IOU_THRESH = 0.55,
    BOX_CONF_THRESH=0.30,
    KPT_CONF_THRESH=0.68):

    image_path = os.path.join(image_dir_path, image_filename)
    image = cv2.imread(image_path).copy()

    results = model.predict(image_path, conf=BOX_CONF_THRESH, iou=BOX_IOU_THRESH)[0].cpu()

    if not len(results.boxes.xyxy):
        return image

    # Get the predicted boxes, conf scores and keypoints.
    pred_boxes = results.boxes.xyxy.numpy()
    pred_box_conf = results.boxes.conf.numpy()
    pred_kpts_xy = results.keypoints.xy.numpy()
    pred_kpts_conf = results.keypoints.conf.numpy()

    # Draw predicted bounding boxes, conf scores and keypoints on image.
    for boxes, score, kpts, confs in zip(pred_boxes, pred_box_conf, pred_kpts_xy, pred_kpts_conf):
        kpts_ids = np.where(confs > KPT_CONF_THRESH)[0]
        filter_kpts = kpts[kpts_ids]
        filter_kpts = np.concatenate([filter_kpts, np.expand_dims(kpts_ids, axis=-1)], axis=-1)
        image = draw_boxes(image, boxes, score=score)
        image = draw_landmarks(image, filter_kpts)

    return image

@dataclass(frozen=True)
class TrainingConfig:
    DATASET_YAML:   str = "animal-keypoints.yaml"
    MODEL:          str = "yolov8m-pose.pt"
    EPOCHS:         int = 100
    KPT_SHAPE:    tuple = (24,3)
    PROJECT:        str = "Animal_Keypoints"
    NAME:           str = f"{MODEL.split('.')[0]}_{EPOCHS}_epochs"
    CLASSES_DICT:  dict = field(default_factory = lambda:{0 : "dog"})

DATA_DIR = "animal-pose-data"


TRAIN_DIR         = f"train"
TRAIN_FOLDER_IMG    = f"images"
TRAIN_FOLDER_LABELS = f"labels"

TRAIN_IMG_PATH   = os.path.join(DATA_DIR, TRAIN_DIR, TRAIN_FOLDER_IMG)
TRAIN_LABEL_PATH = os.path.join(DATA_DIR, TRAIN_DIR, TRAIN_FOLDER_LABELS)

VALID_DIR           = f"valid"
VALID_FOLDER_IMG    = f"images"
VALID_FOLDER_LABELS = f"labels"
CLASS_ID = 0
# create_yolo_txt_files(train_json_data, TRAIN_LABEL_PATH)
# create_yolo_txt_files(val_json_data, VALID_LABEL_PATH)
ann_meta_data = pd.read_csv("keypoint_definitions.csv")
COLORS = ann_meta_data["Hex colour"].values.tolist()

COLORS_RGB_MAP = []
VALID_IMG_PATH   = os.path.join(DATA_DIR, VALID_DIR, VALID_FOLDER_IMG)
for color in COLORS:
    R, G, B = int(color[:2], 16), int(color[2:4], 16), int(color[4:], 16)
    COLORS_RGB_MAP.append({color: (R,G,B)})
VAL_IMAGE_FILES = os.listdir(VALID_IMG_PATH)

num_samples = 9
num_rows = 3
num_cols = num_samples//num_rows

fig, ax = plt.subplots(
        nrows=num_rows,
        ncols=num_cols,
        figsize=(25, 15),
    )

random.seed(90)
random.shuffle(VAL_IMAGE_FILES)
train_config = TrainingConfig()
ckpt_path  = os.path.join(train_config.PROJECT, train_config.NAME, "weights", "best.pt")
print("ckpt_pth",ckpt_path)
model_pose = YOLO(ckpt_path)

for idx, (file, axis) in enumerate(zip(VAL_IMAGE_FILES[:num_samples], ax.flat)):

    image_pred = prepare_predictions(VALID_IMG_PATH, file, model_pose)
    axis.imshow(image_pred[...,::-1])
    axis.axis("off")

plt.tight_layout(h_pad=4., w_pad=4.)
plt.show();
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93

最后:计算机视觉、图像处理、毕业辅导、作业帮助、代码获取,远程协助,代码定制,私聊会回复!

#code全部代码:qq1309399183

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/502726
推荐阅读
相关标签
  

闽ICP备14008679号