当前位置:   article > 正文

深度学习——COCO全身关键点提取部分指定的关键点

深度学习——COCO全身关键点提取部分指定的关键点

使用yolov8训练人体关键点模型;
一个模型多个类别,不同类别关键点个数不一致;
我目前了解到的好像只有COCO是有全身关键点;
COCO全身关键点链接:https://github.com/jin-s13/COCO-WholeBody
在这里插入图片描述
以下代码能从COCO全身标注的json中提取出想要的关键点和对应的类别,并且直接转换成yolov8可用的训练txt格式。
注意:其中图片使用的是os.link,类似Linux中的硬链接,并非复制;如果磁盘空间充足,可以用shutil.copy替换。

20240408-测试版本代码

# -*- coding: UTF-8 -*-
"""
@Project :ultralytics 
@IDE     :PyCharm 
@Author  :沐枫
@Date    :2024/4/8 15:11 
"""
import os
import json
import shutil
from typing import Dict, List
from concurrent import futures

from tqdm import tqdm
import cv2

COCO_URL_ROOT = "http://images.cocodataset.org"


class DecodeWholeBodyImage:
    """
    Decode one entry of the COCO-WholeBody ``images`` list into attributes.
    """

    def __init__(self, image_info: Dict):
        info = image_info
        self.license = info['license']
        self.date_captured = info['date_captured']

        # Same value twice: `image_id` mirrors `id` so it can be matched
        # directly against annotation['image_id'].
        self.id = info['id']
        self.image_id = info['id']
        self.file_name = info['file_name']

        self.height = info['height']
        self.width = info['width']

        # e.g. 'http://images.cocodataset.org/val2017/000000397133.jpg'
        flickr = info['flickr_url']
        coco = info['coco_url']
        self.flickr_url = flickr
        self.coco_url = coco
        # Prefer the flickr url only when it is an actual http link.
        self.url = flickr if 'http' in flickr else coco


class DecodeWholeBodyAnnotation:
    """
    Decode one COCO-WholeBody annotation dict (a single object instance).
    All boxes are in COCO ltwh format (left, top, width, height).
    """

    def __init__(self, annotation: Dict):
        # Linkage / meta information.
        self.image_id = annotation['image_id']  # matches DecodeWholeBodyImage.id
        self.id = annotation['id']              # unique object id
        self.category_id = annotation['category_id']
        self.iscrowd = annotation['iscrowd']    # 0: a single instance, not a crowd
        self.segmentation = annotation['segmentation']

        # Body keypoints and box.
        self.body_box = annotation['bbox']
        self.body_points = annotation['keypoints']
        self.num_keypoints = annotation['num_keypoints']  # labelled body-point count

        # Foot keypoints.
        self.foot_points = annotation['foot_kpts']
        self.foot_valid = annotation['foot_valid']

        # Face keypoints and box.
        self.face_box = annotation['face_box']
        self.face_points = annotation['face_kpts']
        self.face_valid = annotation['face_valid']

        # Hand keypoints and boxes.
        self.lefthand_box = annotation['lefthand_box']
        self.lefthand_points = annotation['lefthand_kpts']
        self.lefthand_valid = annotation['lefthand_valid']
        self.righthand_box = annotation['righthand_box']
        self.righthand_points = annotation['righthand_kpts']
        self.righthand_valid = annotation['righthand_valid']

        # Whole-body keypoints, concatenated in the canonical order:
        # body + feet + face + left hand + right hand.
        self.all_points = (list(self.body_points)
                           + list(self.foot_points)
                           + list(self.face_points)
                           + list(self.lefthand_points)
                           + list(self.righthand_points))


def clip(value, min_v, max_v):
    """Clamp *value* into the closed interval [min_v, max_v]."""
    return max(min_v, min(value, max_v))


def ltwh2xywhn(bbox, img_h, img_w):
    """
    Convert a COCO ltwh box (left, top, width, height) into normalized xywh
    (center x, center y, width, height, each in [0, 1]) for YOLO training.

    Args:
        bbox: (left, top, width, height) in pixels
        img_h: image height in pixels
        img_w: image width in pixels

    Returns:
        tuple: (x, y, w, h) normalized by the image size
    """
    x1, y1, w, h = bbox  # ltwh

    # BUG FIX: compute the right/bottom edge from the RAW corner before
    # clipping it.  The original clipped x1/y1 first and then added w/h, so a
    # box sticking out on the left/top was widened to the right/bottom
    # instead of being cropped to the image.
    x2 = max(0, min(x1 + w, img_w))
    y2 = max(0, min(y1 + h, img_h))
    x1 = max(0, min(x1, img_w))
    y1 = max(0, min(y1, img_h))

    w = x2 - x1
    h = y2 - y1

    # Box center, then normalize everything by the image size.
    x = (x1 + w / 2) / img_w
    y = (y1 + h / 2) / img_h
    return x, y, w / img_w, h / img_h


def get_point(point_index, all_points, img_shape_wh=None, max_point_num=0):
    """
    Pick the requested keypoints out of the flat whole-body keypoint list and
    format them as a YOLO pose label fragment: "x y v x y v ...".

    Args:
        point_index: indices of the wanted keypoints (into the whole-body list)
        all_points: flat list [x0, y0, v0, x1, y1, v1, ...]
        img_shape_wh: (img_w, img_h); if None the coordinates are not normalized
        max_point_num: keypoint count every object must have; missing points
            are padded with "0 0 0" triplets so all rows have equal length

    Returns:
        str: space-separated keypoint fragment (no trailing space)
    """
    current_point_num = len(point_index)
    parts = []
    for index in point_index:
        x, y, v = all_points[index * 3:(index + 1) * 3]
        # Snap the visibility flag to the YOLO convention {0, 1, 2}.
        if 0 < v <= 1:
            v = 1
        elif 1 < v <= 2:
            v = 2

        # Optionally clip into the image and normalize to [0, 1].
        if img_shape_wh is not None:
            img_w, img_h = img_shape_wh
            x = max(0, min(x, img_w)) / img_w
            y = max(0, min(y, img_h)) / img_h

        parts.append(f"{x:.6f} {y:.6f} {int(v)}")

    # Pad up to max_point_num with invisible dummy points.
    # BUG FIX: the original referenced the global MAX_POINT_NUM here when
    # point_index was empty, instead of the max_point_num parameter (a
    # NameError anywhere outside the original script's __main__ scope).
    if current_point_num < max_point_num:
        parts.append(" ".join(["0"] * (max_point_num - current_point_num) * 3))

    return " ".join(parts)


if __name__ == '__main__':
    data_root = r"Z:\Datasets\Detection\COCO2017"
    if data_root == "":
        raise ValueError(f"{data_root} should not be empty string")
    data_root = os.path.abspath(data_root)

    # Output sub-project name: results go to <data_root>/<project>/{images,labels}.
    project = "FallAndSit"

    # Source boxes supported inside one COCO-WholeBody annotation.
    BOX_TYPE = ("body_box", "face_box", "lefthand_box", "righthand_box",)
    # Upper bound for usable whole-body keypoint indices (body+feet+face+hands).
    POINT_INDEX_MAX = 129

    # One entry per output class:
    #   cls_index   - class id written into the YOLO txt
    #   box_type    - which annotation box becomes this class's bounding box
    #   point_index - whole-body keypoint indices to keep, in this order
    Object_info: List[Dict] = [
        {"cls_index": 0,
         "box_type": "body_box",
         "point_index": (6, 5, 12, 11, 14, 13, 16, 15)},

        {"cls_index": 1,
         "box_type": "face_box",
         "point_index": (2, 1, 4, 3, 71, 77)},
    ]
    if len(Object_info) == 0:
        raise ValueError("Object_dict is empty")
    # Validate the configuration once up-front instead of asserting inside the
    # annotation loop (asserts are stripped under `python -O`).
    for value in Object_info:
        if value["box_type"] not in BOX_TYPE:
            raise ValueError(f"{value['box_type']} not in {BOX_TYPE}")

    # Every object is padded with [0, 0, 0] triplets up to the largest
    # keypoint count so all label rows have the same length.
    MAX_POINT_NUM = max(len(value["point_index"]) for value in Object_info)

    image_root = os.path.join(data_root, project, "images")
    txt_root = os.path.join(data_root, project, "labels")
    # Start from clean output directories.
    for _root in (image_root, txt_root):
        if os.path.exists(_root):
            shutil.rmtree(_root)
        os.makedirs(_root)

    json_path_list = [
        os.path.join(data_root, "annotations", "coco-wholebody", "coco_wholebody_val_v1.0.json"),
        # os.path.join(data_root, "annotations", "coco-wholebody", "coco_wholebody_train_v1.0.json"),
    ]

    for json_path in json_path_list:
        # image_id -> metadata needed later when writing the labels.
        information = dict()

        print(f"read {json_path}")
        with open(json_path, 'r', encoding="utf-8") as rFile:
            json_data = json.load(rFile)
        print(f"read {json_path} finish ...")

        # Pass 1: collect image sizes plus source/destination paths.
        print(f"deal images ...")
        image_list = json_data['images']
        for step in tqdm(range(len(image_list)), desc=f"deal {os.path.basename(json_path)}"):
            img_info = DecodeWholeBodyImage(image_list[step])

            # coco_url looks like 'http://images.cocodataset.org/val2017/xxx.jpg';
            # map it onto the local '<data_root>/images/val2017/xxx.jpg' layout.
            img_path = os.path.join(
                data_root,
                img_info.coco_url.replace(COCO_URL_ROOT, "images").replace("/", os.sep))

            # Read the image to get its real size (and skip missing/broken files).
            img = cv2.imread(img_path)
            if img is None:
                continue
            h, w = img.shape[:2]

            dst_img_path = img_path.replace(os.path.join(data_root, "images"), image_root)
            information[img_info.id] = {
                "file_name": img_info.file_name,  # image file name
                'h': h,                           # image height
                'w': w,                           # image width
                "src_path": img_path,             # original image path
                "dst_path": dst_img_path,         # destination path inside the project
            }

        print("deal image information finish ...")

        # Pass 2: convert every annotation into YOLO pose label rows.
        print("deal annotation ...")
        annotations = json_data['annotations']
        for step in tqdm(range(len(annotations)), desc=f"deal {os.path.basename(json_path)}"):
            annotation = DecodeWholeBodyAnnotation(annotations[step])

            # BUG FIX: an annotation may reference an image that was skipped in
            # pass 1 (unreadable file); the original raised KeyError here.
            image_info = information.get(annotation.image_id)
            if image_info is None:
                continue

            src_image_path = image_info["src_path"]
            dst_image_path = image_info["dst_path"]
            # Label path mirrors the image path: swap root, swap extension.
            # (splitext is safer than str.replace(suffix, ...), which could
            # also match the suffix substring earlier in the path.)
            txt_path = os.path.splitext(dst_image_path.replace(image_root, txt_root))[0] + ".txt"

            img_h = image_info['h']
            img_w = image_info['w']

            # box_type -> (ltwh box, whether this box is usable here).
            # BUG FIX: the original tested `lefthand_valid` for the RIGHT hand.
            box_lookup = {
                "body_box": (annotation.body_box, not annotation.iscrowd),
                "face_box": (annotation.face_box, annotation.face_valid),
                "lefthand_box": (annotation.lefthand_box, annotation.lefthand_valid),
                "righthand_box": (annotation.righthand_box, annotation.righthand_valid),
            }

            results = list()
            for value in Object_info:
                box_ltwh, usable = box_lookup[value["box_type"]]
                if not usable:
                    continue
                box = ltwh2xywhn(box_ltwh, img_h=img_h, img_w=img_w)
                res = f"{value['cls_index']} {box[0]:.6f} {box[1]:.6f} {box[2]:.6f} {box[3]:.6f} "
                res += get_point(point_index=value["point_index"],
                                 all_points=annotation.all_points,
                                 img_shape_wh=(img_w, img_h),
                                 max_point_num=MAX_POINT_NUM)
                results.append(res)

            # Append mode: several annotations can belong to the same image/txt.
            os.makedirs(os.path.dirname(txt_path), exist_ok=True)
            with open(txt_path, "a", encoding="utf-8") as wFile:
                for line in results:
                    wFile.write(f"{line}\n")

            # Hard-link the image into the project (no copy, saves disk space);
            # use shutil.copy instead when the target is on another filesystem.
            if not os.path.exists(dst_image_path):
                os.makedirs(os.path.dirname(dst_image_path), exist_ok=True)
                os.link(src_image_path, dst_image_path)


示例:【人脸7个关键点,身体8个关键点】
在这里插入图片描述
在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/394312
推荐阅读
相关标签
  

闽ICP备14008679号