当前位置:   article > 正文

【技术文档】centernet(姿态估计)_kps.unsqueeze(3).expand(batch, num_joints, k, k, 2

kps.unsqueeze(3).expand(batch, num_joints, k, k, 2)

模型结构

backbone dla34

dla(Deep Layer Aggregation)
We introduce two structures for deep layer aggregation (DLA): iterative deep aggrega-
tion (IDA) and hierarchical deep aggregation (HDA).
Hierarchical deep aggregation merges blocks and stages in a tree to preserve and combine feature channels.
我们介绍两种结构深层聚合(DLA):迭代深层聚合 (IDA)和层次深度聚合(HDA)。
IDA focuses on fusing resolutions and scales while HDA focuses on merging features from all modules and channels.
IDA主要负责不同空间尺度信息的融合,HDA侧重于合并来自所有模块和通道的特性。
HDA

在这里插入图片描述
基本卷积结构block

class BasicBlock(nn.Module):
    """Residual unit: two 3x3 conv + batch-norm stages and an additive skip.

    Args:
        inplanes: number of input channels.
        planes: number of output channels of both convolutions.
        stride: stride of the first convolution (the second always uses 1).
        dilation: dilation used by both convolutions; also used as padding.
    """

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(BasicBlock, self).__init__()
        # Options common to both convolutions. With padding == dilation a 3x3
        # kernel keeps the spatial size, so only ``stride`` can shrink it.
        shared = dict(kernel_size=3, padding=dilation, bias=False,
                      dilation=dilation)
        self.conv1 = nn.Conv2d(inplanes, planes, stride=stride, **shared)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, stride=1, **shared)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.stride = stride

    def forward(self, x, residual=None):
        """Run the block; ``residual`` defaults to the input ``x`` itself."""
        shortcut = x if residual is None else residual

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # In-place add of the skip branch, then the final activation.
        out += shortcut
        return self.relu(out)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

构成具备树形结构的模块

class Root(nn.Module):
    """Aggregation node: concatenate inputs on channels, then conv + BN + ReLU.

    Args:
        in_channels: total channel count of the concatenated inputs.
        out_channels: channel count produced by the node.
        kernel_size: only used to derive the padding; the convolution kernel
            itself is fixed to 1.
        residual: when truthy, the first input is added before the activation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, residual):
        super(Root, self).__init__()
        pad = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_channels, out_channels, 1,
                              stride=1, bias=False, padding=pad)
        self.bn = nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.residual = residual

    def forward(self, *x):
        """Fuse every positional tensor in ``x`` into a single feature map."""
        inputs = x
        fused = self.bn(self.conv(torch.cat(inputs, 1)))
        if self.residual:
            # Skip connection from the first child (in-place add).
            fused += inputs[0]
        return self.relu(fused)


class Tree(nn.Module):
    """Recursive aggregation tree.

    With ``levels == 1`` the node owns two ``block`` instances and a ``Root``
    that fuses their outputs. With more levels it owns two sub-trees; the
    first sub-tree's output is appended to ``children`` and handed to the
    second one, whose innermost ``Root`` finally fuses everything.

    Args:
        levels: depth of the tree (1 means a leaf pair plus root).
        block: callable building a leaf module, called as
            ``block(in_channels, out_channels, stride, dilation=...)``.
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by the tree.
        stride: stride of the first branch; values > 1 also add a max-pool
            on the shortcut path.
        level_root: when truthy, the (down-sampled) input is appended to
            ``children`` as well.
        root_dim: input channel count of the root; 0 means
            ``2 * out_channels``.
        root_kernel_size: forwarded to ``Root``.
        dilation: forwarded to the leaf blocks.
        root_residual: forwarded to ``Root``.
    """

    def __init__(self, levels, block, in_channels, out_channels, stride=1,
                 level_root=False, root_dim=0, root_kernel_size=1,
                 dilation=1, root_residual=False):
        super(Tree, self).__init__()
        if root_dim == 0:
            root_dim = 2 * out_channels
        if level_root:
            root_dim += in_channels

        if levels == 1:
            # Leaf case: two plain blocks joined by a root node.
            self.tree1 = block(in_channels, out_channels, stride,
                               dilation=dilation)
            self.tree2 = block(out_channels, out_channels, 1,
                               dilation=dilation)
            self.root = Root(root_dim, out_channels, root_kernel_size,
                             root_residual)
        else:
            # Recursive case: options shared by both sub-trees.
            subtree_kwargs = dict(root_kernel_size=root_kernel_size,
                                  dilation=dilation,
                                  root_residual=root_residual)
            self.tree1 = Tree(levels - 1, block, in_channels, out_channels,
                              stride, root_dim=0, **subtree_kwargs)
            # The second sub-tree's root additionally receives tree1's output.
            self.tree2 = Tree(levels - 1, block, out_channels, out_channels,
                              root_dim=root_dim + out_channels,
                              **subtree_kwargs)

        self.level_root = level_root
        self.root_dim = root_dim
        self.levels = levels

        # Shortcut path: optional pooling, then optional 1x1 projection.
        self.downsample = (nn.MaxPool2d(stride, stride=stride)
                           if stride > 1 else None)
        self.project = None
        if in_channels != out_channels:
            self.project = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM)
            )

    def forward(self, x, residual=None, children=None):
        """Run the tree.

        The incoming ``residual`` argument is ignored: it is always recomputed
        from ``x``. ``children`` (if given) is extended in place.
        """
        if children is None:
            children = []
        bottom = x if self.downsample is None else self.downsample(x)
        residual = bottom if self.project is None else self.project(bottom)
        if self.level_root:
            children.append(bottom)

        x1 = self.tree1(x, residual)
        if self.levels != 1:
            children.append(x1)
            return self.tree2(x1, children=children)

        x2 = self.tree2(x1)
        return self.root(x2, x1, *children)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76

HDA
IDA
IDA
IDAUp
IDAUp是IDA模块需要不断重复使用的操作,每一次操作都建立在上一次操作结果的基础上。如上图所示,每一层的结果都是由IDAUp生成,连续使用就形成了现在的金字塔结构。


class IDAUp(nn.Module):
    """One iterative-aggregation step over a list of feature maps.

    For every index ``i >= 1`` three sub-modules are registered:
    ``proj_i`` (``DeformConv`` mapping ``channels[i]`` to ``o`` channels),
    ``up_i`` (transposed convolution with stride ``up_f[i]``) and
    ``node_i`` (``DeformConv`` applied after summing with the previous layer).

    Args:
        o: output channel count shared by all sub-modules.
        channels: per-layer input channel counts; entry 0 is not used here.
        up_f: per-layer up-sampling factors; entry 0 is not used here.
    """

    def __init__(self, o, channels, up_f):
        super(IDAUp, self).__init__()
        for idx in range(1, len(channels)):
            factor = int(up_f[idx])
            proj = DeformConv(channels[idx], o)
            node = DeformConv(o, o)

            # groups=o makes this a per-channel transposed convolution; its
            # weights are then overwritten by fill_up_weights.
            up = nn.ConvTranspose2d(o, o, factor * 2, stride=factor,
                                    padding=factor // 2, output_padding=0,
                                    groups=o, bias=False)
            fill_up_weights(up)

            for prefix, module in (('proj_', proj), ('up_', up),
                                   ('node_', node)):
                setattr(self, prefix + str(idx), module)

    def forward(self, layers, startp, endp):
        """Update ``layers[startp + 1:endp]`` in place; returns nothing."""
        for i in range(startp + 1, endp):
            suffix = str(i - startp)
            upsample = getattr(self, 'up_' + suffix)
            project = getattr(self, 'proj_' + suffix)
            node = getattr(self, 'node_' + suffix)

            # Project and up-sample this layer, then fuse it with the
            # (already updated) previous layer.
            layers[i] = upsample(project(layers[i]))
            layers[i] = node(layers[i] + layers[i - 1])

class DLAUp(nn.Module):
    """Stack of ``IDAUp`` modules that progressively merges deeper layers.

    Module ``ida_i`` aggregates the last ``i + 2`` layers. After each module
    is built, the bookkeeping lists are updated so that the layers it touched
    are treated as having the channel count and scale of the layer it merged
    into.

    Args:
        startp: index of the first layer taking part in the up-sampling.
        channels: per-layer channel counts (any sequence).
        scales: per-layer scale factors (any sequence convertible by numpy).
        in_channels: optional per-layer input channel counts; defaults to
            ``channels``.
    """

    def __init__(self, startp, channels, scales, in_channels=None):
        super(DLAUp, self).__init__()
        self.startp = startp
        self.channels = channels
        channels = list(channels)
        # Work on private copies: the slice assignments below must not write
        # through to the caller's sequences (previously ``in_channels``
        # aliased the caller's ``channels`` list, which was silently
        # modified, and a tuple argument raised TypeError).
        in_channels = list(channels if in_channels is None else in_channels)
        scales = np.array(scales, dtype=int)
        for i in range(len(channels) - 1):
            j = -i - 2
            setattr(self, 'ida_{}'.format(i),
                    IDAUp(channels[j], in_channels[j:],
                          scales[j:] // scales[j]))
            # Everything after position j now lives at layer j's scale and
            # channel count.
            scales[j + 1:] = scales[j]
            in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]]

    def forward(self, layers):
        """Run all aggregation steps.

        ``layers`` is modified in place by the ``IDAUp`` modules. Returns a
        list built by prepending ``layers[-1]`` after every step, so the most
        recently produced map comes first.
        """
        out = [layers[-1]]  # start with 32
        for i in range(len(layers) - self.startp - 1):
            ida = getattr(self, 'ida_{}'.format(i))
            ida(layers, len(layers) - i - 2, len(layers))
            out.insert(0, layers[-1])
        return out
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53

DLA总体结构

class DLASeg(nn.Module):
    """Backbone wrapper: base network -> ``DLAUp`` -> final ``IDAUp``.

    Args:
        base_name: name of a module-level factory, looked up in ``globals()``
            and called with ``pretrained=pretrained``.
        pretrained: forwarded to the base factory.
        down_ratio: output stride; must be one of 2, 4, 8, 16.
        final_kernel: accepted but not used in this class.
        last_level: index (exclusive) of the last layer fed to the final
            ``IDAUp``.
        out_channel: channel count of the final map; 0 means "use the channel
            count of the first selected level".
    """

    def __init__(self, base_name, pretrained, down_ratio, final_kernel,
                 last_level, out_channel=0):
        super(DLASeg, self).__init__()
        assert down_ratio in [2, 4, 8, 16]
        self.first_level = int(np.log2(down_ratio))  # down_ratio=4
        self.last_level = last_level
        self.base = globals()[base_name](pretrained=pretrained)

        channels = self.base.channels
        selected = channels[self.first_level:]
        scales = [2 ** i for i in range(len(selected))]
        self.dla_up = DLAUp(self.first_level, selected, scales)

        if out_channel == 0:
            out_channel = channels[self.first_level]

        num_levels = self.last_level - self.first_level
        self.ida_up = IDAUp(out_channel,
                            channels[self.first_level:self.last_level],
                            [2 ** i for i in range(num_levels)])

    def forward(self, x):
        """Return the last feature map produced by the final ``IDAUp``."""
        features = self.dla_up(self.base(x))

        # Clone because ida_up overwrites list entries in place.
        pyramid = [features[i].clone()
                   for i in range(self.last_level - self.first_level)]
        self.ida_up(pyramid, 0, len(pyramid))

        return pyramid[-1]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30

head

class KeypointHead(nn.Module):
    """Six parallel prediction branches sharing the same input feature map.

    Every branch is ``3x3 conv -> ReLU -> 1x1 conv``; only the number of
    output channels differs: ``hm`` 1, ``wh`` 2, ``hps`` 34, ``reg`` 2,
    ``hm_hp`` 17 and ``hp_offset`` 2.

    Args:
        intermediate_channel: channel count of the incoming feature map.
        head_conv: hidden channel count inside every branch.
    """

    def __init__(self, intermediate_channel, head_conv):
        super(KeypointHead, self).__init__()

        def make_branch(out_channels):
            # One prediction branch ending in ``out_channels`` maps.
            return nn.Sequential(
                nn.Conv2d(intermediate_channel, head_conv, kernel_size=3,
                          padding=1, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(head_conv, out_channels, kernel_size=1, stride=1,
                          padding=0))

        self.hm = make_branch(1)
        self.wh = make_branch(2)
        self.hps = make_branch(34)
        self.reg = make_branch(2)
        self.hm_hp = make_branch(17)
        self.hp_offset = make_branch(2)
        # NOTE(review): init_weights is not defined in this snippet; it is
        # presumably provided elsewhere in the class -- confirm.
        self.init_weights()

    def forward(self, x):
        """Return ``[hm, wh, hps, reg, hm_hp, hp_offset]`` predictions."""
        branches = (self.hm, self.wh, self.hps,
                    self.reg, self.hm_hp, self.hp_offset)
        return [branch(x) for branch in branches]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

train

dataloader

模型的输入大小(input_res)是512×512,模型输出特征图(output_res)的大小是128×128。
需要根据input_res做一次仿射变换,根据output_res做一次仿射变换,并得到两个仿射变换的矩阵。

groundtruth根据输出大小进行仿射变换后得到新的bounding box坐标点,该bounding box计算目标的中心点为正样本点,其他位置都是负样本。
通过目标中心点找到对应其他任务的输出结果(比如wh),计算loss。

目标的中心点和骨骼点对应的模型输出都是热力图形式的,根据目标的中心点进行高斯过滤的。

def __getitem__(self, index):
        """Build one sample (network input plus all training targets).

        Loads the image at position ``index`` together with its COCO
        annotations, applies the random scale/shift, rotation and flip
        augmentations when ``self.split == 'train'``, warps the image to
        ``cfg.MODEL.INPUT_RES`` and rasterizes the targets at
        ``cfg.MODEL.OUTPUT_RES``.

        Returns:
            dict: always contains 'input', 'hm', 'reg_mask', 'ind', 'wh' and
            either 'hps'/'hps_mask' or, when ``cfg.LOSS.DENSE_HP`` is set,
            'dense_hps'/'dense_hps_mask'. 'reg', 'hm_hp',
            'hp_offset'/'hp_ind'/'hp_mask' and 'meta' are added when the
            corresponding config flags (or debug / non-train split) are set.
        """
        # Image id stored at this dataset position.
        img_id = self.images[index]
        # File name taken from the COCO image record.
        file_name = self.coco.loadImgs(ids=[img_id])[0]['file_name']
        # Full image path = image directory + file name.
        img_path = os.path.join(self.img_dir, file_name)
        # Ids of all annotations attached to this image.
        ann_ids = self.coco.getAnnIds(imgIds=[img_id])
        # The annotation records themselves.
        anns = self.coco.loadAnns(ids=ann_ids)

        # Keep annotations whose category is in self._valid_ids and that are not flagged as crowd.
        anns = list(filter(lambda x:x['category_id'] in self._valid_ids and x['iscrowd']!= 1 , anns))
        # Cap the number of objects used per image.
        num_objs = min(len(anns), self.max_objs)
        # Read the image.
        # NOTE(review): cv2.imread returns None for unreadable files; that case is not checked here.
        img = cv2.imread(img_path)
        # Original image height and width.
        height, width = img.shape[0], img.shape[1]
        # Image center as (x, y).
        c = np.array([img.shape[1] / 2., img.shape[0] / 2.], dtype=np.float32)
        # Scale is the longer image edge.
        s = max(img.shape[0], img.shape[1]) * 1.0
        # Rotation angle; stays 0 unless the rotation augmentation fires below.
        rot = 0

        flipped = False
        if self.split == 'train':
            if self.cfg.DATASET.RANDOM_CROP: # true (author's note)
                # Random scale factor from 0.6 .. 1.3 and a random center kept
                # away from the image border by _get_border.
                s = s * np.random.choice(np.arange(0.6, 1.4, 0.1))
                w_border = self._get_border(128, img.shape[1])
                h_border = self._get_border(128, img.shape[0])
                c[0] = np.random.randint(low=w_border, high=img.shape[1] - w_border)
                c[1] = np.random.randint(low=h_border, high=img.shape[0] - h_border)
            else:
                # Randomly jitter center and scale, clipped to +/- 2*cf and 1 +/- sf.
                sf = self.cfg.DATASET.SCALE
                cf = self.cfg.DATASET.SHIFT
                c[0] += s * np.clip(np.random.randn()*cf, -2*cf, 2*cf)
                c[1] += s * np.clip(np.random.randn()*cf, -2*cf, 2*cf)
                s = s * np.clip(np.random.randn()*sf + 1, 1 - sf, 1 + sf)
            # Rotation augmentation, applied with probability AUG_ROT.
            if np.random.random() < self.cfg.DATASET.AUG_ROT:
                rf = self.cfg.DATASET.ROTATE
                rot = np.clip(np.random.randn()*rf, -rf*2, rf*2)

            # Horizontal flip, applied with probability FLIP; the center is mirrored too.
            if np.random.random() < self.cfg.DATASET.FLIP:
                flipped = True
                img = img[:, ::-1, :]
                c[0] = width - c[0] - 1
        
        # Affine matrix mapping the original image to the network input frame.
        trans_input = get_affine_transform(
          c, s, rot, [self.cfg.MODEL.INPUT_RES, self.cfg.MODEL.INPUT_RES])
        # Warp the image to the input resolution.
        inp = cv2.warpAffine(img, trans_input, 
                             (self.cfg.MODEL.INPUT_RES, self.cfg.MODEL.INPUT_RES),
                             flags=cv2.INTER_LINEAR)
        # Scale pixel values to [0, 1].
        inp = (inp.astype(np.float32) / 255.)
        if self.split == 'train' and not self.cfg.DATASET.NO_COLOR_AUG:
            color_aug(self._data_rng, inp, self._eig_val, self._eig_vec)
        # Normalize with the configured mean and std.
        inp = (inp - np.array(self.cfg.DATASET.MEAN).astype(np.float32)) / np.array(self.cfg.DATASET.STD).astype(np.float32)
        # HWC -> CHW.
        inp = inp.transpose(2, 0, 1)

        output_res = self.cfg.MODEL.OUTPUT_RES
        num_joints = self.num_joints
        # Affine matrix to the output frame including the rotation (used for keypoints).
        trans_output_rot = get_affine_transform(c, s, rot, [output_res, output_res])
        # Affine matrix to the output frame without rotation (used for boxes).
        trans_output = get_affine_transform(c, s, 0, [output_res, output_res])
        # Affine matrix for the segmentation mask; built with the same arguments as trans_output.
        trans_seg_output = get_affine_transform(c, s, 0, [output_res, output_res])
        # Center heatmap target, one channel per class.
        hm = np.zeros((self.num_classes, output_res, output_res), dtype=np.float32)
        # Keypoint heatmap target, one channel per joint.
        hm_hp = np.zeros((num_joints, output_res, output_res), dtype=np.float32)
        # Dense keypoint regression targets and mask (only filled when LOSS.DENSE_HP is set).
        dense_kps = np.zeros((num_joints, 2, output_res, output_res), 
                              dtype=np.float32)
        dense_kps_mask = np.zeros((num_joints, output_res, output_res), 
                                   dtype=np.float32)
        # Width/height of every object.
        wh = np.zeros((self.max_objs, 2), dtype=np.float32)
        # Keypoint offsets relative to the integer object center, in output coordinates.
        kps = np.zeros((self.max_objs, num_joints * 2), dtype=np.float32)
        # Sub-pixel offset between the float center and its integer cell.
        reg = np.zeros((self.max_objs, 2), dtype=np.float32)
        # Flattened index of every object center in the output map.
        ind = np.zeros((self.max_objs), dtype=np.int64)
        # 1 for slots holding a real object, 0 for padding slots.
        reg_mask = np.zeros((self.max_objs), dtype=np.uint8)
        # 1 for keypoint coordinates that are labeled and fall inside the output map.
        kps_mask = np.zeros((self.max_objs, self.num_joints * 2), dtype=np.uint8)
        # Sub-pixel offset of every keypoint relative to its integer cell.
        hp_offset = np.zeros((self.max_objs * num_joints, 2), dtype=np.float32)
        # Flattened index of every keypoint in the output map.
        hp_ind = np.zeros((self.max_objs * num_joints), dtype=np.int64)
        # Per-keypoint validity mask, analogous to kps_mask.
        hp_mask = np.zeros((self.max_objs * num_joints), dtype=np.int64)
        # Gaussian drawing routine selected by the loss type.
        draw_gaussian = draw_msra_gaussian if self.cfg.LOSS.MSE_LOSS else \
                        draw_umich_gaussian

        gt_det = []
        for k in range(num_objs):
            ann = anns[k]
            bbox = self._coco_box_to_bbox(ann['bbox'])
            cls_id = int(ann['category_id']) - 1
            # Keypoints as rows of (x, y, flag); the flag is tested with > 0 below.
            pts = np.array(ann['keypoints'], np.float32).reshape(num_joints, 3)
            segment = self.coco.annToMask(ann)      
            if flipped:
                # Mirror box, keypoints and mask, then swap the joint pairs listed in flip_idx.
                bbox[[0, 2]] = width - bbox[[2, 0]] - 1
                pts[:, 0] = width - pts[:, 0] - 1
                for e in self.flip_idx:
                    pts[e[0]], pts[e[1]] = pts[e[1]].copy(), pts[e[0]].copy()
                segment = segment[:, ::-1]     
             
            # Map both box corners to the output frame and clip them to the map.
            bbox[:2] = affine_transform(bbox[:2], trans_output)
            bbox[2:] = affine_transform(bbox[2:], trans_output)
            bbox = np.clip(bbox, 0, output_res - 1)
            # NOTE(review): the warped mask is not used anywhere else in this method.
            segment= cv2.warpAffine(segment, trans_seg_output,
                                 (output_res, output_res),
                                 flags=cv2.INTER_LINEAR)
            segment = segment.astype(np.float32)      
            h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
            if (h > 0 and w > 0) or (rot != 0):
                # Gaussian radius derived from the box size.
                radius = gaussian_radius((math.ceil(h), math.ceil(w)))
                radius = self.cfg.hm_gauss if self.cfg.LOSS.MSE_LOSS else max(0, int(radius)) # the latter branch is used (author's note)
                # Float object center in output coordinates.
                ct = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
                # Integer cell containing the center.
                ct_int = ct.astype(np.int32)
                # Width/height label for object k.
                wh[k] = 1. * w, 1. * h
                # Flattened (row-major) index of the center cell.
                ind[k] = ct_int[1] * output_res + ct_int[0]  # object location idx
                # Residual between float and integer center, to reduce discretization error.
                reg[k] = ct - ct_int       # offset between centerpoint and centerpoint_init
                reg_mask[k] = 1
                           
                # Keypoints: objects whose flag column sums to 0 are masked
                # out of the regression losses.
                num_kpts = pts[:, 2].sum()
                if num_kpts == 0:
                    hm[cls_id, ct_int[1], ct_int[0]] = 0.9999
                    reg_mask[k] = 0

                # Radius for the per-joint heatmaps; computed the same way as ``radius``.
                hp_radius = gaussian_radius((math.ceil(h), math.ceil(w)))
                hp_radius = self.cfg.hm_gauss \
                            if self.cfg.LOSS.MSE_LOSS else max(0, int(hp_radius)) 
                for j in range(num_joints):
                    if pts[j, 2] > 0:
                        pts[j, :2] = affine_transform(pts[j, :2], trans_output_rot)
                        if pts[j, 0] >= 0 and pts[j, 0] < output_res and \
                            pts[j, 1] >= 0 and pts[j, 1] < output_res:
                            # Offset from the integer object center to the keypoint.
                            kps[k, j * 2: j * 2 + 2] = pts[j, :2] - ct_int
                            kps_mask[k, j * 2: j * 2 + 2] = 1
                            pt_int = pts[j, :2].astype(np.int32)
                            # Sub-pixel offset of the keypoint inside its cell.
                            hp_offset[k * num_joints + j] = pts[j, :2] - pt_int
                            hp_ind[k * num_joints + j] = pt_int[1] * output_res + pt_int[0]
                            hp_mask[k * num_joints + j] = 1
                            if self.cfg.LOSS.DENSE_HP:
                                # Must run before the center gaussian is drawn into hm.
                                draw_dense_reg(dense_kps[j], hm[cls_id], ct_int, 
                                               pts[j, :2] - ct_int, radius, is_offset=True)
                                draw_gaussian(dense_kps_mask[j], ct_int, radius)
                            draw_gaussian(hm_hp[j], pt_int, hp_radius)
                # Center gaussian is drawn after the keypoints (see the DENSE_HP note above).
                draw_gaussian(hm[cls_id], ct_int, radius)
                # Ground-truth row: box (4), score 1, keypoint coordinates, class id.
                gt_det.append([ct[0] - w / 2, ct[1] - h / 2, 
                               ct[0] + w / 2, ct[1] + h / 2, 1] + 
                               pts[:, :2].reshape(num_joints * 2).tolist() + [cls_id])
        if rot != 0:
            # With rotation active the center heatmap is filled with 0.9999
            # and the object/keypoint regression masks are zeroed.
            hm = hm * 0 + 0.9999
            reg_mask *= 0
            kps_mask *= 0
        ret = {'input': inp, 'hm': hm, 'reg_mask': reg_mask, 'ind': ind, 'wh': wh,
               'hps': kps, 'hps_mask': kps_mask}
        if self.cfg.LOSS.DENSE_HP:
            # Flatten the (joint, xy) axes and duplicate the mask so both
            # tensors have num_joints * 2 channels; they replace 'hps'/'hps_mask'.
            dense_kps = dense_kps.reshape(num_joints * 2, output_res, output_res)
            dense_kps_mask = dense_kps_mask.reshape(
                num_joints, 1, output_res, output_res)
            dense_kps_mask = np.concatenate([dense_kps_mask, dense_kps_mask], axis=1)
            dense_kps_mask = dense_kps_mask.reshape(
                num_joints * 2, output_res, output_res)
            ret.update({'dense_hps': dense_kps, 'dense_hps_mask': dense_kps_mask})
            del ret['hps'], ret['hps_mask']
        if self.cfg.LOSS.REG_OFFSET:
            ret.update({'reg': reg})
        if self.cfg.LOSS.HM_HP:
            ret.update({'hm_hp': hm_hp})
        if self.cfg.LOSS.REG_HP_OFFSET:
            ret.update({'hp_offset': hp_offset, 'hp_ind': hp_ind, 'hp_mask': hp_mask})
        if self.cfg.DEBUG > 0 or not self.split == 'train':
            # Fallback row width 40 matches 4 + 1 + 34 + 1, i.e. 17 joints.
            gt_det = np.array(gt_det, dtype=np.float32) if len(gt_det) > 0 else \
                   np.zeros((1, 40), dtype=np.float32)
            meta = {'c': c, 's': s, 'gt_det': gt_det, 'img_id': img_id}
            ret['meta'] = meta
        return ret
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201

loss
损失函数包括六个部分:目标中心点的heatmap、中心点的离散误差、目标的宽和高、骨骼点的heatmap、骨骼点的离散误差、骨骼点的偏移量。

class MultiPoseLoss(torch.nn.Module):
    """Weighted sum of the six multi-person pose losses.

    Args:
        cfg: configuration object; ``cfg.MODEL.NUM_STACKS`` and the
            ``cfg.LOSS.*`` flags/weights are read in ``forward``.
        local_rank: stored on the instance; not used by ``forward``.
    """

    def __init__(self, cfg, local_rank):
        super(MultiPoseLoss, self).__init__()
        self.crit = FocalLoss()             # center heatmap (hm)
        self.crit_hm_hp = FocalLoss()       # keypoint heatmap (hm_hp)
        self.crit_kp = RegWeightedL1Loss()  # keypoint offsets (hps)
        self.crit_reg = RegL1Loss()         # wh, reg, hp_offset
        self.cfg = cfg
        self.local_rank = local_rank

    def forward(self, outputs, batch):
        """Compute the total loss.

        Args:
            outputs: sequence ``(hm, wh, hps, reg, hm_hp, hp_offset)``.
            batch: dict of targets; keys read here are 'hm', 'hps',
                'hps_mask', 'ind', 'wh', 'reg_mask', 'reg', 'hp_offset',
                'hp_ind', 'hp_mask' and 'hm_hp' (some only when the matching
                config flag is enabled).

        Returns:
            tuple: ``(loss, loss_stats)`` where ``loss_stats`` maps each loss
            name to its value.
        """
        cfg = self.cfg
        loss_cfg = cfg.LOSS
        num_stacks = cfg.MODEL.NUM_STACKS

        hm_loss = wh_loss = off_loss = 0
        hp_loss = hm_hp_loss = hp_offset_loss = 0
        hm, wh, hps, reg, hm_hp, hp_offset = outputs

        # NOTE(review): the same tensors are reused on every iteration, so
        # with NUM_STACKS > 1 the sigmoid is applied repeatedly -- confirm
        # that NUM_STACKS is always 1 for this model.
        for _ in range(num_stacks):
            hm = _sigmoid(hm)      # (16,1,128,128) per the author's note
            if loss_cfg.HM_HP and not loss_cfg.MSE_LOSS:
                hm_hp = _sigmoid(hm_hp)  # (16,17,128,128) per the author's note

            # Center heatmap: focal loss.
            hm_loss += self.crit(hm, batch['hm']) / num_stacks

            # Keypoint offsets gathered at the object centers.
            hp_loss += self.crit_kp(hps, batch['hps_mask'],
                                    batch['ind'], batch['hps']) / num_stacks

            if loss_cfg.WH_WEIGHT > 0:
                # Box size gathered at the object centers.
                wh_loss += self.crit_reg(wh, batch['reg_mask'],
                                         batch['ind'], batch['wh']) / num_stacks

            if loss_cfg.REG_OFFSET and loss_cfg.OFF_WEIGHT > 0:
                # Center discretization offset.
                off_loss += self.crit_reg(reg, batch['reg_mask'],
                                          batch['ind'], batch['reg']) / num_stacks

            if loss_cfg.REG_HP_OFFSET and loss_cfg.OFF_WEIGHT > 0:
                # Keypoint discretization offset gathered at keypoint indices.
                hp_offset_loss += self.crit_reg(
                    hp_offset, batch['hp_mask'],
                    batch['hp_ind'], batch['hp_offset']) / num_stacks

            if loss_cfg.HM_HP and loss_cfg.HM_HP_WEIGHT > 0:
                # Keypoint heatmap: focal loss.
                hm_hp_loss += self.crit_hm_hp(
                    hm_hp, batch['hm_hp']) / num_stacks

        loss = (loss_cfg.HM_WEIGHT * hm_loss
                + loss_cfg.WH_WEIGHT * wh_loss
                + loss_cfg.OFF_WEIGHT * off_loss
                + loss_cfg.HP_WEIGHT * hp_loss
                + loss_cfg.HM_HP_WEIGHT * hm_hp_loss
                + loss_cfg.OFF_WEIGHT * hp_offset_loss)

        loss_stats = {'loss': loss, 'hm_loss': hm_loss, 'hp_loss': hp_loss,
                      'hm_hp_loss': hm_hp_loss,
                      'hp_offset_loss': hp_offset_loss,
                      'wh_loss': wh_loss, 'off_loss': off_loss}
        return loss, loss_stats
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50

heatmap损失计算

def _neg_loss(pred, gt):
    ''' Modified focal loss. Exactly the same as CornerNet.
      Runs faster and costs a little bit more memory
    Arguments:
      pred (batch x c x h x w)
      gt_regr (batch x c x h x w)
    '''
    pos_inds = gt.eq(1).float()
    neg_inds = gt.lt(1).float()

    neg_weights = torch.pow(1 - gt, 4)

    loss = 0

    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds

    num_pos  = pos_inds.float().sum()
    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()

    if num_pos == 0:
        loss = loss - neg_loss
    else:
        loss = loss - (pos_loss + neg_loss) / num_pos
    return loss
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

eval

topk

先选择出每个类别得分topk的点,然后再把这些点放在一起选出topk

def _topk(scores, K=40):
    """Select the K best-scoring heatmap cells across all categories.

    First the top K cells of every category are taken, then the top K of
    those ``cat * K`` candidates are kept.

    Args:
        scores: heatmap tensor of size (batch, cat, height, width).
        K: number of cells to keep.

    Returns:
        tuple: ``(topk_score, topk_inds, topk_clses, topk_ys, topk_xs)``,
        each viewed/selected to (batch, K): the scores, the flattened
        ``y * width + x`` indices, the category index of every pick and the
        row / column coordinates (as floats).
    """
    batch, cat, height, width = scores.size()
    # Top K per category; topk_inds => batch x cat x K
    topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)

    topk_inds = topk_inds % (height * width)
    # Recover the row / column of every pick from its flattened index.
    topk_ys   = (topk_inds / width).int().float()
    topk_xs   = (topk_inds % width).int().float()
    # Top K over all categories together; topk_ind => batch x K
    topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)
    # Candidates are laid out category-major, so ind // K is the category.
    topk_clses = (topk_ind / K).int()
    topk_inds = _gather_feat(topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)

    # The snippet previously ended here without returning anything, although
    # callers unpack five values (scores, inds, clses, ys, xs).
    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

post process
1,先由hm找到目标的中心点,然后再找到对应的reg和wh计算出bbox。
2,骨骼点由kps和hmhp共同确定:kps知道自己属于哪个目标但是精度不高,而hmhp精度高,但不知道自己属于哪个个体。

def whole_body_decode(
    heat, wh, kps, seg_feat=None, seg=None, reg=None, hm_hp=None, hp_offset=None, K=100):
    """Decode network outputs into boxes, keypoints and segmentation maps.

    Object centers are the top-K peaks of ``heat``. Boxes come from the
    ``wh`` (and optional ``reg``) values gathered at those centers, and
    keypoints from the center-relative ``kps`` offsets. When ``hm_hp`` is
    given, each regressed keypoint is replaced by the nearest keypoint
    heatmap peak unless that peak is outside the box, scores below the
    threshold, or lies too far from the regressed position.

    Args:
        heat: center heatmap, size (batch, cat, height, width).
        wh: box size map.
        kps: keypoint offset map with ``num_joints * 2`` channels.
        seg_feat: feature map convolved with the per-object kernels.
        seg: map from which the per-object convolution kernels are gathered.
        reg: optional center offset map.
        hm_hp: optional keypoint heatmap.
        hp_offset: optional keypoint offset map.
        K: number of detections to keep.

    Returns:
        tuple: ``(detections, pred_seg)``; ``detections`` concatenates, along
        dim 2, the boxes (4), scores (1), keypoints (num_joints * 2) and the
        per-joint heatmap scores (num_joints).

    NOTE(review): ``seg`` and ``seg_feat`` default to None but are used
    unconditionally, and ``hm_score`` is only defined inside the
    ``hm_hp is not None`` branch although the final ``torch.cat`` always
    reads it -- calling with ``hm_hp=None`` would raise NameError.
    """
    batch, cat, height, width = heat.size()

    num_joints = kps.shape[1] // 2

    # Keep only local maxima of the center heatmap, then take the top-K peaks.
    heat = _nms(heat)
    scores, inds, clses, ys, xs = _topk(heat, K=K)

    # Gather the keypoint offsets at each center and turn them into absolute
    # coordinates by adding the (integer) center x / y.
    kps = _transpose_and_gather_feat(kps, inds)
    kps = kps.view(batch, K, num_joints * 2)
    kps[..., ::2] += xs.view(batch, K, 1).expand(batch, K, num_joints)
    kps[..., 1::2] += ys.view(batch, K, 1).expand(batch, K, num_joints)
    if reg is not None:
        # Refine the centers with the regressed sub-pixel offsets.
        reg = _transpose_and_gather_feat(reg, inds)
        reg = reg.view(batch, K, 2)
        xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
        ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
    else:
        # Without an offset head, use the cell middle.
        xs = xs.view(batch, K, 1) + 0.5
        ys = ys.view(batch, K, 1) + 0.5
    wh = _transpose_and_gather_feat(wh, inds)
    wh = wh.view(batch, K, 2)
    
    # Per-object convolution kernels gathered at the centers.
    weight = _transpose_and_gather_feat(seg, inds)
    ## you can write  (if weight.size(1)!=seg_feat.size(1): 3x3conv  else 1x1conv ) here to select seg conv.
    ## for 3x3
    # NOTE(review): the leading dim is weight.size(1) (presumably K), which
    # only lines up with a batch of 1 -- confirm.
    weight = weight.view([weight.size(1), -1, 3, 3])
    pred_seg = F.conv2d(seg_feat, weight, stride=1, padding=1)
        
    clses  = clses.view(batch, K, 1).float()
    scores = scores.view(batch, K, 1)

    # Boxes as (left, top, right, bottom) around the refined centers.
    bboxes = torch.cat([xs - wh[..., 0:1] / 2, 
                      ys - wh[..., 1:2] / 2,
                      xs + wh[..., 0:1] / 2, 
                      ys + wh[..., 1:2] / 2], dim=2)
    if hm_hp is not None:
        hm_hp = _nms(hm_hp)
        thresh = 0.1
        kps = kps.view(batch, K, num_joints, 2).permute(
          0, 2, 1, 3).contiguous() #  b x K x 34 => b x J x K x 2
        # reg_kps repeats each regressed keypoint along a new axis 3
        # (b, J, K, 1, 2) -> (b, J, K, K, 2); hm_kps below is repeated along
        # axis 2 instead, so the pair forms an all-pairs comparison grid.
        reg_kps = kps.unsqueeze(3).expand(batch, num_joints, K, K, 2)
        # Top-K peaks of every joint channel: scores, indices, ys, xs.
        hm_score, hm_inds, hm_ys, hm_xs = _topk_channel(hm_hp, K=K) # b x J x K
        # Refine the peak positions with hp_offset when available.
        if hp_offset is not None:
            hp_offset = _transpose_and_gather_feat(
              hp_offset, hm_inds.view(batch, -1))
            hp_offset = hp_offset.view(batch, num_joints, K, 2)
            hm_xs = hm_xs + hp_offset[:, :, :, 0]
            hm_ys = hm_ys + hp_offset[:, :, :, 1]
        else:
            hm_xs = hm_xs + 0.5
            hm_ys = hm_ys + 0.5
        # 1 where the peak score exceeds the threshold.
        mask = (hm_score > thresh).float()
        # Rejected peaks get score -1 and coordinates -10000 so they can
        # never be the nearest match.
        hm_score = (1 - mask) * -1 + mask * hm_score
        hm_ys = (1 - mask) * (-10000) + mask * hm_ys
        hm_xs = (1 - mask) * (-10000) + mask * hm_xs
        # hm_kps: keypoint candidates taken from the joint heatmaps.
        hm_kps = torch.stack([hm_xs, hm_ys], dim=-1).unsqueeze(
          2).expand(batch, num_joints, K, K, 2)
        # Euclidean distance between every regressed keypoint and every peak.
        dist = (((reg_kps - hm_kps) ** 2).sum(dim=4) ** 0.5)
        # For each regressed keypoint, pick the closest heatmap peak.
        min_dist, min_ind = dist.min(dim=3) # b x J x K
        hm_score = hm_score.gather(2, min_ind).unsqueeze(-1) # b x J x K x 1
        min_dist = min_dist.unsqueeze(-1)
        min_ind = min_ind.view(batch, num_joints, K, 1, 1).expand(
          batch, num_joints, K, 1, 2)
        hm_kps = hm_kps.gather(3, min_ind)
        hm_kps = hm_kps.view(batch, num_joints, K, 2)
        # Box edges broadcast to b x J x K x 1.
        l = bboxes[:, :, 0].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
        t = bboxes[:, :, 1].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
        r = bboxes[:, :, 2].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
        b = bboxes[:, :, 3].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
        # Reject the matched peak when it lies outside the box, scores below
        # the threshold, or is farther than 0.3 x the longer box side.
        mask = (hm_kps[..., 0:1] < l) + (hm_kps[..., 0:1] > r) + \
             (hm_kps[..., 1:2] < t) + (hm_kps[..., 1:2] > b) + \
             (hm_score < thresh) + (min_dist > (torch.max(b - t, r - l) * 0.3))
        mask = (mask > 0).float().expand(batch, num_joints, K, 2)
        # Use the heatmap peak where accepted, otherwise keep the regression.
        kps = (1 - mask) * hm_kps + mask * kps
        kps = kps.permute(0, 2, 1, 3).contiguous().view(
          batch, K, num_joints * 2)
    detections = torch.cat([bboxes, scores, kps, torch.transpose(hm_score.squeeze(dim=3), 1, 2)], dim=2)
    return detections, pred_seg
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/342780
推荐阅读
  

闽ICP备14008679号