We introduce two structures for deep layer aggregation (DLA): iterative deep aggregation (IDA) and hierarchical deep aggregation (HDA).
Hierarchical deep aggregation merges blocks and stages in a tree to preserve and combine feature channels.
我们为深层聚合（DLA）引入两种结构：迭代深层聚合（IDA）和层次深层聚合（HDA）。
IDA focuses on fusing resolutions and scales while HDA focuses on merging features from all modules and channels.
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolution + BatchNorm stages.

    Only the first convolution is strided; both convolutions use the same
    ``dilation`` with padding equal to it. The shortcut is the block input
    unless the caller passes an explicit ``residual`` tensor (e.g. ``Tree``
    passes a downsampled/projected shortcut).
    """

    def __init__(self, inplanes, planes, stride=1, dilation=1):
        super(BasicBlock, self).__init__()
        # Registration order (conv1, bn1, relu, conv2, bn2) is kept so that
        # parameter ordering matches existing checkpoints/optimizer states.
        self.conv1 = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            bias=False,
            dilation=dilation,
        )
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=1,
            padding=dilation,
            bias=False,
            dilation=dilation,
        )
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        # Stored for introspection only; forward() never reads it.
        self.stride = stride

    def forward(self, x, residual=None):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)."""
        shortcut = x if residual is None else residual
        features = self.relu(self.bn1(self.conv1(x)))
        features = self.bn2(self.conv2(features))
        features += shortcut
        return self.relu(features)
class Root(nn.Module):
    """Aggregation node of a DLA tree: concat -> conv -> BN (-> +child0) -> ReLU.

    Args:
        in_channels: total channels of all concatenated inputs.
        out_channels: output channels.
        kernel_size: convolution kernel size; padding keeps the spatial size
            unchanged for odd kernel sizes.
        residual: if True, add the first input back before the ReLU.
    """

    def __init__(self, in_channels, out_channels, kernel_size, residual):
        super(Root, self).__init__()
        # The kernel size used to be hard-coded to 1 while the padding was
        # derived from ``kernel_size``; any kernel_size > 1 then enlarged the
        # output and broke the residual add. For kernel_size == 1 (the
        # default everywhere) this is identical, including weight shapes.
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride=1, bias=False,
            padding=(kernel_size - 1) // 2)
        self.bn = nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.residual = residual

    def forward(self, *x):
        """Fuse the given feature maps (concatenated along channels)."""
        children = x
        x = self.conv(torch.cat(x, 1))
        x = self.bn(x)
        if self.residual:
            x += children[0]
        x = self.relu(x)
        return x


class Tree(nn.Module):
    """Hierarchical deep aggregation (HDA) tree of ``block`` modules.

    A tree of ``levels`` depth has two subtrees; at ``levels == 1`` the
    subtrees are two ``block`` instances whose outputs (plus any
    ``children`` forwarded from enclosing trees) are merged by a ``Root``.

    Args:
        levels: depth of the tree (1 = two blocks plus a root).
        block: block class called as ``block(in, out, stride, dilation=...)``.
        in_channels, out_channels: input/output channel counts.
        stride: stride applied by the first block; a max-pool of the same
            stride produces the matching shortcut (``bottom``).
        level_root: if True, ``bottom`` is also fed into the root.
        root_dim: root input channels; 0 means ``2 * out_channels``
            (plus ``in_channels`` when ``level_root``).
        root_kernel_size, root_residual: forwarded to ``Root``.
        dilation: forwarded to every block.
    """

    def __init__(self, levels, block, in_channels, out_channels, stride=1,
                 level_root=False, root_dim=0, root_kernel_size=1,
                 dilation=1, root_residual=False):
        super(Tree, self).__init__()
        if root_dim == 0:
            root_dim = 2 * out_channels
        if level_root:
            root_dim += in_channels
        if levels == 1:
            self.tree1 = block(in_channels, out_channels, stride,
                               dilation=dilation)
            self.tree2 = block(out_channels, out_channels, 1,
                               dilation=dilation)
        else:
            self.tree1 = Tree(levels - 1, block, in_channels, out_channels,
                              stride, root_dim=0,
                              root_kernel_size=root_kernel_size,
                              dilation=dilation, root_residual=root_residual)
            # The second subtree's root also receives tree1's output (and
            # whatever this tree forwards), hence the extra out_channels.
            self.tree2 = Tree(levels - 1, block, out_channels, out_channels,
                              root_dim=root_dim + out_channels,
                              root_kernel_size=root_kernel_size,
                              dilation=dilation, root_residual=root_residual)
        if levels == 1:
            self.root = Root(root_dim, out_channels, root_kernel_size,
                             root_residual)
        self.level_root = level_root
        self.root_dim = root_dim
        self.downsample = None
        self.project = None
        self.levels = levels
        if stride > 1:
            self.downsample = nn.MaxPool2d(stride, stride=stride)
        if in_channels != out_channels:
            # 1x1 projection so the shortcut matches the block output width.
            self.project = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM)
            )

    def forward(self, x, residual=None, children=None):
        """Run the tree.

        Args:
            x: input feature map.
            residual: ignored; the shortcut is always recomputed from ``x``
                (kept for call compatibility with ``block.forward``).
            children: extra feature maps from enclosing trees to feed into
                the root. This list is appended to in place.
        """
        children = [] if children is None else children
        bottom = self.downsample(x) if self.downsample is not None else x
        residual = self.project(bottom) if self.project is not None else bottom
        if self.level_root:
            children.append(bottom)
        x1 = self.tree1(x, residual)
        if self.levels == 1:
            x2 = self.tree2(x1)
            x = self.root(x2, x1, *children)
        else:
            children.append(x1)
            x = self.tree2(x1, children=children)
        return x
class IDAUp(nn.Module):
    """Iterative deep aggregation: progressively fuse a list of feature maps.

    For every input level ``i >= 1`` three modules are registered:
    ``proj_i`` (``DeformConv``, ``channels[i] -> o``), ``up_i`` (a depthwise
    ``ConvTranspose2d`` with stride ``up_f[i]``) and ``node_i``
    (``DeformConv``, ``o -> o``) that fuses the upsampled map with the
    previous level.
    """

    def __init__(self, o, channels, up_f):
        super(IDAUp, self).__init__()
        for i in range(1, len(channels)):
            c = channels[i]
            f = int(up_f[i])
            proj = DeformConv(c, o)
            node = DeformConv(o, o)
            # groups=o makes the transposed convolution depthwise.
            up = nn.ConvTranspose2d(o, o, f * 2, stride=f,
                                    padding=f // 2, output_padding=0,
                                    groups=o, bias=False)
            fill_up_weights(up)  # NOTE(review): presumably bilinear init
            setattr(self, 'proj_' + str(i), proj)
            setattr(self, 'up_' + str(i), up)
            setattr(self, 'node_' + str(i), node)

    def forward(self, layers, startp, endp):
        """Fuse ``layers[startp:endp]`` in place; returns None.

        After the call ``layers[i]`` (for ``startp < i < endp``) holds the
        aggregation of levels ``startp..i``; ``layers[endp - 1]`` is the
        final result.
        """
        for i in range(startp + 1, endp):
            upsample = getattr(self, 'up_' + str(i - startp))
            project = getattr(self, 'proj_' + str(i - startp))
            layers[i] = upsample(project(layers[i]))
            node = getattr(self, 'node_' + str(i - startp))
            layers[i] = node(layers[i] + layers[i - 1])


class DLAUp(nn.Module):
    """Chain of ``IDAUp`` stages that merges deeper levels into shallower ones.

    Args:
        startp: index of the first (shallowest) level to produce.
        channels: output channels per level.
        scales: relative scale per level (e.g. ``[1, 2, 4, 8]``).
        in_channels: input channels per level; defaults to ``channels``.
            Neither list is modified.
    """

    def __init__(self, startp, channels, scales, in_channels=None):
        super(DLAUp, self).__init__()
        self.startp = startp
        if in_channels is None:
            in_channels = channels
        self.channels = channels
        channels = list(channels)
        # Work on a private copy. The loop below overwrites entries of
        # ``in_channels``; previously that mutated the caller's list and,
        # when ``in_channels`` defaulted to ``channels``, also
        # ``self.channels`` (and a tuple argument crashed).
        in_channels = list(in_channels)
        scales = np.array(scales, dtype=int)
        for i in range(len(channels) - 1):
            j = -i - 2
            setattr(self, 'ida_{}'.format(i),
                    IDAUp(channels[j], in_channels[j:],
                          scales[j:] // scales[j]))
            # After stage i, every level deeper than j has been brought to
            # level j's scale and channel width.
            scales[j + 1:] = scales[j]
            in_channels[j + 1:] = [channels[j] for _ in channels[j + 1:]]

    def forward(self, layers):
        """Return the fused maps, shallowest first; ``layers`` is mutated in place."""
        out = [layers[-1]]  # deepest level (presumably stride 32)
        for i in range(len(layers) - self.startp - 1):
            ida = getattr(self, 'ida_{}'.format(i))
            ida(layers, len(layers) - i - 2, len(layers))
            out.insert(0, layers[-1])
        return out
class DLASeg(nn.Module):
    """DLA backbone followed by DLAUp + IDAUp fusion into one feature map.

    Args:
        base_name: name of a backbone constructor found in this module's
            globals; it is called as ``ctor(pretrained=pretrained)`` and must
            expose ``.channels``.
        pretrained: forwarded to the backbone constructor.
        down_ratio: output stride; one of 2, 4, 8, 16.
        final_kernel: accepted for interface compatibility (unused here).
        last_level: exclusive upper level index fused by ``ida_up``.
        out_channel: fused output width; 0 means the width of the first used
            backbone level.
    """

    def __init__(self, base_name, pretrained, down_ratio, final_kernel,
                 last_level, out_channel=0):
        super(DLASeg, self).__init__()
        assert down_ratio in [2, 4, 8, 16]
        # e.g. down_ratio=4 -> first_level=2
        self.first_level = int(np.log2(down_ratio))
        self.last_level = last_level
        self.base = globals()[base_name](pretrained=pretrained)
        all_channels = self.base.channels

        upper_channels = all_channels[self.first_level:]
        level_scales = [2 ** k for k in range(len(upper_channels))]
        self.dla_up = DLAUp(self.first_level, upper_channels, level_scales)

        if out_channel == 0:
            out_channel = all_channels[self.first_level]
        num_fused = self.last_level - self.first_level
        self.ida_up = IDAUp(out_channel,
                            all_channels[self.first_level:self.last_level],
                            [2 ** k for k in range(num_fused)])

    def forward(self, x):
        """Return the single fused feature map at ``down_ratio`` stride."""
        pyramid = self.dla_up(self.base(x))
        num_fused = self.last_level - self.first_level
        # Clone so IDAUp's in-place fusion does not overwrite DLAUp's outputs.
        fused = [pyramid[k].clone() for k in range(num_fused)]
        self.ida_up(fused, 0, len(fused))
        return fused[-1]
class KeypointHead(nn.Module):
    """Prediction heads for CenterNet-style multi-person pose estimation.

    Every head is ``3x3 conv -> ReLU -> 1x1 conv`` on the shared feature map.
    Output channels: ``hm`` 1 (person centre heatmap), ``wh`` 2 (box size),
    ``hps`` 2*num_joints (centre-to-joint offsets), ``reg`` 2 (centre
    sub-pixel offset), ``hm_hp`` num_joints (joint heatmaps), ``hp_offset``
    2 (joint sub-pixel offset).

    Args:
        intermediate_channel: channels of the input feature map.
        head_conv: hidden width of each head.
        num_joints: number of keypoints (17 = COCO; the previous hard-coded
            value, so the default is backward compatible).
    """

    def __init__(self, intermediate_channel, head_conv, num_joints=17):
        super(KeypointHead, self).__init__()
        make = self._make_head
        # Attribute names and Sequential layout are unchanged, so existing
        # state_dict keys (e.g. ``hm.0.weight``, ``hm.2.bias``) still load.
        self.hm = make(intermediate_channel, head_conv, 1)
        self.wh = make(intermediate_channel, head_conv, 2)
        self.hps = make(intermediate_channel, head_conv, 2 * num_joints)
        self.reg = make(intermediate_channel, head_conv, 2)
        self.hm_hp = make(intermediate_channel, head_conv, num_joints)
        self.hp_offset = make(intermediate_channel, head_conv, 2)
        self.init_weights()

    @staticmethod
    def _make_head(in_channels, mid_channels, out_channels):
        """Build one ``3x3 conv -> ReLU -> 1x1 conv`` prediction head."""
        return nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1,
                      bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=1, stride=1,
                      padding=0))

    def init_weights(self, prior_bias=-2.19):
        """Initialise head biases.

        ``__init__`` always called this method, but it was never defined,
        so constructing the head raised ``AttributeError``.

        The final bias of the heatmap heads (``hm``, ``hm_hp``) is set to
        ``prior_bias``. The default of -2.19 is about ``-log((1 - p) / p)``
        with p = 0.1, the focal-loss prior, so initial sigmoid outputs start
        near 0.1. Conv biases of the regression heads are zeroed.
        """
        for head in (self.hm, self.hm_hp):
            nn.init.constant_(head[-1].bias, prior_bias)
        for head in (self.wh, self.hps, self.reg, self.hp_offset):
            for module in head.modules():
                if isinstance(module, nn.Conv2d) and module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return ``[hm, wh, hps, reg, hm_hp, hp_offset]`` raw (pre-sigmoid) maps."""
        return [self.hm(x), self.wh(x), self.hps(x), self.reg(x),
                self.hm_hp(x), self.hp_offset(x)]
将 ground truth 的 bounding box 按输出特征图的尺寸做仿射变换，得到新的 bounding box 坐标；由该 bounding box 计算出的目标中心点作为正样本点，其余位置均为负样本（中心点附近的负样本会按高斯热图的值降低损失权重）。
def __getitem__(self, index):
    """Build one multi-person-pose sample (input image plus training targets).

    Loads image ``index`` and its COCO annotations. In the ``train`` split it
    applies random crop/scale, rotation, horizontal flip and colour
    augmentation. The ground truth is then rasterised at
    ``cfg.MODEL.OUTPUT_RES``: centre heatmaps, joint heatmaps, box sizes,
    sub-pixel offsets and centre-to-joint offsets.

    Returns:
        dict with ``input`` plus the targets enabled by ``self.cfg.LOSS``
        flags. A ``meta`` entry (centre, scale, gt detections, image id) is
        added when debugging or outside the ``train`` split.

    Raises:
        FileNotFoundError: if the image cannot be read (``cv2.imread``
            returns None instead of raising).
    """
    img_id = self.images[index]
    file_name = self.coco.loadImgs(ids=[img_id])[0]['file_name']
    img_path = os.path.join(self.img_dir, file_name)
    ann_ids = self.coco.getAnnIds(imgIds=[img_id])
    anns = self.coco.loadAnns(ids=ann_ids)
    # Keep annotations of valid categories that are not crowd regions.
    anns = list(filter(lambda x: x['category_id'] in self._valid_ids and x['iscrowd'] != 1, anns))
    # Cap the number of objects per image.
    num_objs = min(len(anns), self.max_objs)

    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(img_path)
    height, width = img.shape[0], img.shape[1]
    # Centre (x, y) of the image and scale = longest edge.
    c = np.array([img.shape[1] / 2., img.shape[0] / 2.], dtype=np.float32)
    s = max(img.shape[0], img.shape[1]) * 1.0
    rot = 0
    flipped = False
    if self.split == 'train':
        if self.cfg.DATASET.RANDOM_CROP:
            s = s * np.random.choice(np.arange(0.6, 1.4, 0.1))
            w_border = self._get_border(128, img.shape[1])
            h_border = self._get_border(128, img.shape[0])
            c[0] = np.random.randint(low=w_border, high=img.shape[1] - w_border)
            c[1] = np.random.randint(low=h_border, high=img.shape[0] - h_border)
        else:
            # Randomly jitter the centre and scale.
            sf = self.cfg.DATASET.SCALE
            cf = self.cfg.DATASET.SHIFT
            c[0] += s * np.clip(np.random.randn() * cf, -2 * cf, 2 * cf)
            c[1] += s * np.clip(np.random.randn() * cf, -2 * cf, 2 * cf)
            s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
        if np.random.random() < self.cfg.DATASET.AUG_ROT:
            rf = self.cfg.DATASET.ROTATE
            rot = np.clip(np.random.randn() * rf, -rf * 2, rf * 2)
        if np.random.random() < self.cfg.DATASET.FLIP:
            flipped = True
            img = img[:, ::-1, :]
            c[0] = width - c[0] - 1

    # Affine map from the original image to the network input.
    trans_input = get_affine_transform(
        c, s, rot, [self.cfg.MODEL.INPUT_RES, self.cfg.MODEL.INPUT_RES])
    inp = cv2.warpAffine(img, trans_input,
                         (self.cfg.MODEL.INPUT_RES, self.cfg.MODEL.INPUT_RES),
                         flags=cv2.INTER_LINEAR)
    inp = (inp.astype(np.float32) / 255.)
    if self.split == 'train' and not self.cfg.DATASET.NO_COLOR_AUG:
        color_aug(self._data_rng, inp, self._eig_val, self._eig_vec)
    inp = (inp - np.array(self.cfg.DATASET.MEAN).astype(np.float32)) / np.array(self.cfg.DATASET.STD).astype(np.float32)
    inp = inp.transpose(2, 0, 1)  # HWC -> CHW

    output_res = self.cfg.MODEL.OUTPUT_RES
    num_joints = self.num_joints
    # Keypoints use the rotated transform; boxes and masks use rot=0.
    trans_output_rot = get_affine_transform(c, s, rot, [output_res, output_res])
    trans_output = get_affine_transform(c, s, 0, [output_res, output_res])
    trans_seg_output = get_affine_transform(c, s, 0, [output_res, output_res])

    hm = np.zeros((self.num_classes, output_res, output_res), dtype=np.float32)  # centre heatmap
    hm_hp = np.zeros((num_joints, output_res, output_res), dtype=np.float32)  # joint heatmaps
    dense_kps = np.zeros((num_joints, 2, output_res, output_res), dtype=np.float32)
    dense_kps_mask = np.zeros((num_joints, output_res, output_res), dtype=np.float32)
    wh = np.zeros((self.max_objs, 2), dtype=np.float32)  # box sizes
    kps = np.zeros((self.max_objs, num_joints * 2), dtype=np.float32)  # joint - int(centre)
    reg = np.zeros((self.max_objs, 2), dtype=np.float32)  # centre sub-pixel offset
    ind = np.zeros((self.max_objs), dtype=np.int64)  # flat centre index y*W+x
    reg_mask = np.zeros((self.max_objs), dtype=np.uint8)  # 1 for real objects
    kps_mask = np.zeros((self.max_objs, self.num_joints * 2), dtype=np.uint8)  # visible joints
    hp_offset = np.zeros((self.max_objs * num_joints, 2), dtype=np.float32)  # joint sub-pixel offset
    hp_ind = np.zeros((self.max_objs * num_joints), dtype=np.int64)  # flat joint index
    hp_mask = np.zeros((self.max_objs * num_joints), dtype=np.int64)
    draw_gaussian = draw_msra_gaussian if self.cfg.LOSS.MSE_LOSS else \
        draw_umich_gaussian

    gt_det = []
    for k in range(num_objs):
        ann = anns[k]
        bbox = self._coco_box_to_bbox(ann['bbox'])
        cls_id = int(ann['category_id']) - 1
        pts = np.array(ann['keypoints'], np.float32).reshape(num_joints, 3)
        # NOTE(review): ``segment`` is warped below but never added to the
        # returned dict -- either emit it as a target or drop this work.
        segment = self.coco.annToMask(ann)
        if flipped:
            bbox[[0, 2]] = width - bbox[[2, 0]] - 1
            pts[:, 0] = width - pts[:, 0] - 1
            # Swap left/right joints so semantics survive the mirror.
            for e in self.flip_idx:
                pts[e[0]], pts[e[1]] = pts[e[1]].copy(), pts[e[0]].copy()
            segment = segment[:, ::-1]
        bbox[:2] = affine_transform(bbox[:2], trans_output)
        bbox[2:] = affine_transform(bbox[2:], trans_output)
        bbox = np.clip(bbox, 0, output_res - 1)
        segment = cv2.warpAffine(segment, trans_seg_output,
                                 (output_res, output_res),
                                 flags=cv2.INTER_LINEAR)
        segment = segment.astype(np.float32)
        h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
        if (h > 0 and w > 0) or (rot != 0):
            radius = gaussian_radius((math.ceil(h), math.ceil(w)))
            # NOTE(review): ``cfg.hm_gauss`` is the only lower-case cfg key here -- confirm it exists.
            radius = self.cfg.hm_gauss if self.cfg.LOSS.MSE_LOSS else max(0, int(radius))
            ct = np.array([(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2],
                          dtype=np.float32)
            ct_int = ct.astype(np.int32)
            wh[k] = 1. * w, 1. * h
            ind[k] = ct_int[1] * output_res + ct_int[0]
            # Discretisation error of the centre.
            reg[k] = ct - ct_int
            reg_mask[k] = 1
            num_kpts = pts[:, 2].sum()
            if num_kpts == 0:
                # Person without annotated joints: mark the centre as
                # "almost positive" and drop its regression targets.
                hm[cls_id, ct_int[1], ct_int[0]] = 0.9999
                reg_mask[k] = 0

            hp_radius = gaussian_radius((math.ceil(h), math.ceil(w)))
            hp_radius = self.cfg.hm_gauss \
                if self.cfg.LOSS.MSE_LOSS else max(0, int(hp_radius))
            for j in range(num_joints):
                if pts[j, 2] > 0:
                    pts[j, :2] = affine_transform(pts[j, :2], trans_output_rot)
                    if pts[j, 0] >= 0 and pts[j, 0] < output_res and \
                            pts[j, 1] >= 0 and pts[j, 1] < output_res:
                        kps[k, j * 2: j * 2 + 2] = pts[j, :2] - ct_int
                        kps_mask[k, j * 2: j * 2 + 2] = 1
                        pt_int = pts[j, :2].astype(np.int32)
                        hp_offset[k * num_joints + j] = pts[j, :2] - pt_int
                        hp_ind[k * num_joints + j] = pt_int[1] * output_res + pt_int[0]
                        hp_mask[k * num_joints + j] = 1
                        if self.cfg.LOSS.DENSE_HP:
                            # must be before draw center hm gaussian
                            draw_dense_reg(dense_kps[j], hm[cls_id], ct_int,
                                           pts[j, :2] - ct_int, radius,
                                           is_offset=True)
                            draw_gaussian(dense_kps_mask[j], ct_int, radius)
                        draw_gaussian(hm_hp[j], pt_int, hp_radius)
            draw_gaussian(hm[cls_id], ct_int, radius)
            gt_det.append([ct[0] - w / 2, ct[1] - h / 2,
                           ct[0] + w / 2, ct[1] + h / 2, 1] +
                          pts[:, :2].reshape(num_joints * 2).tolist() + [cls_id])
    if rot != 0:
        # Boxes were mapped with rot=0, so they do not match the rotated
        # input. Set the heatmap target to 0.9999 everywhere, which makes
        # the (1 - gt)^4 focal weight vanish, and disable box/keypoint
        # regression.
        hm = hm * 0 + 0.9999
        reg_mask *= 0
        kps_mask *= 0
    ret = {'input': inp, 'hm': hm, 'reg_mask': reg_mask, 'ind': ind, 'wh': wh,
           'hps': kps, 'hps_mask': kps_mask}
    if self.cfg.LOSS.DENSE_HP:
        dense_kps = dense_kps.reshape(num_joints * 2, output_res, output_res)
        dense_kps_mask = dense_kps_mask.reshape(
            num_joints, 1, output_res, output_res)
        dense_kps_mask = np.concatenate([dense_kps_mask, dense_kps_mask], axis=1)
        dense_kps_mask = dense_kps_mask.reshape(
            num_joints * 2, output_res, output_res)
        ret.update({'dense_hps': dense_kps, 'dense_hps_mask': dense_kps_mask})
        del ret['hps'], ret['hps_mask']
    if self.cfg.LOSS.REG_OFFSET:
        ret.update({'reg': reg})
    if self.cfg.LOSS.HM_HP:
        ret.update({'hm_hp': hm_hp})
    if self.cfg.LOSS.REG_HP_OFFSET:
        ret.update({'hp_offset': hp_offset, 'hp_ind': hp_ind, 'hp_mask': hp_mask})
    if self.cfg.DEBUG > 0 or not self.split == 'train':
        # Each gt_det row is box(4) + score(1) + joints(2*J) + class(1). The
        # empty placeholder used to hard-code 40 columns (valid only for J=17).
        gt_det = np.array(gt_det, dtype=np.float32) if len(gt_det) > 0 else \
            np.zeros((1, 6 + 2 * num_joints), dtype=np.float32)
        meta = {'c': c, 's': s, 'gt_det': gt_det, 'img_id': img_id}
        ret['meta'] = meta
    return ret
class MultiPoseLoss(torch.nn.Module):
    """Weighted sum of the CenterNet multi-person-pose training losses.

    ``forward`` unpacks ``outputs`` as the 6-tuple
    ``(hm, wh, hps, reg, hm_hp, hp_offset)`` of raw head maps. ``batch``
    supplies the matching targets (``hm``, ``ind``, ``reg_mask``, ``wh``,
    ``hps``, ``hps_mask``, and optionally ``reg``, ``hm_hp``, ``hp_ind``,
    ``hp_mask``, ``hp_offset``).

    NOTE(review): ``batch['hps']`` / ``batch['hps_mask']`` are read
    unconditionally, so this loss does not support ``LOSS.DENSE_HP``
    batches (which drop those keys).
    """

    def __init__(self, cfg, local_rank):
        super(MultiPoseLoss, self).__init__()
        self.crit = FocalLoss()  # centre heatmap
        self.crit_hm_hp = FocalLoss()  # joint heatmaps
        self.crit_kp = RegWeightedL1Loss()  # centre-to-joint offsets
        self.crit_reg = RegL1Loss()  # wh, centre offset, joint offset
        self.cfg = cfg
        self.local_rank = local_rank

    def forward(self, outputs, batch):
        """Return ``(loss, loss_stats)`` where ``loss_stats`` holds each term."""
        cfg = self.cfg
        num_stacks = cfg.MODEL.NUM_STACKS
        hm_loss, wh_loss, off_loss = 0, 0, 0
        hp_loss, hm_hp_loss, hp_offset_loss = 0, 0, 0
        hm, wh, hps, reg, hm_hp, hp_offset = outputs

        # Apply the sigmoid exactly once. ``outputs`` holds a single set of
        # head maps, so applying it inside the stack loop (as before)
        # compounded the sigmoid whenever NUM_STACKS > 1.
        hm = _sigmoid(hm)  # e.g. (16, 1, 128, 128)
        if cfg.LOSS.HM_HP and not cfg.LOSS.MSE_LOSS:
            hm_hp = _sigmoid(hm_hp)  # e.g. (16, 17, 128, 128)

        for _ in range(num_stacks):
            hm_loss += self.crit(hm, batch['hm']) / num_stacks
            # hps e.g. (16, 34, 128, 128), gathered at the centre indices.
            hp_loss += self.crit_kp(hps, batch['hps_mask'],
                                    batch['ind'], batch['hps']) / num_stacks
            if cfg.LOSS.WH_WEIGHT > 0:
                # Box size, sampled at each object's centre index.
                wh_loss += self.crit_reg(wh, batch['reg_mask'],
                                         batch['ind'], batch['wh']) / num_stacks
            if cfg.LOSS.REG_OFFSET and cfg.LOSS.OFF_WEIGHT > 0:
                off_loss += self.crit_reg(reg, batch['reg_mask'],
                                          batch['ind'], batch['reg']) / num_stacks
            if cfg.LOSS.REG_HP_OFFSET and cfg.LOSS.OFF_WEIGHT > 0:
                # Joint discretisation error, sampled at each joint index.
                hp_offset_loss += self.crit_reg(
                    hp_offset, batch['hp_mask'],
                    batch['hp_ind'], batch['hp_offset']) / num_stacks
            if cfg.LOSS.HM_HP and cfg.LOSS.HM_HP_WEIGHT > 0:
                hm_hp_loss += self.crit_hm_hp(
                    hm_hp, batch['hm_hp']) / num_stacks

        loss = cfg.LOSS.HM_WEIGHT * hm_loss + cfg.LOSS.WH_WEIGHT * wh_loss + \
            cfg.LOSS.OFF_WEIGHT * off_loss + cfg.LOSS.HP_WEIGHT * hp_loss + \
            cfg.LOSS.HM_HP_WEIGHT * hm_hp_loss + cfg.LOSS.OFF_WEIGHT * hp_offset_loss

        loss_stats = {'loss': loss, 'hm_loss': hm_loss, 'hp_loss': hp_loss,
                      'hm_hp_loss': hm_hp_loss, 'hp_offset_loss': hp_offset_loss,
                      'wh_loss': wh_loss, 'off_loss': off_loss}
        return loss, loss_stats
def _neg_loss(pred, gt): ''' Modified focal loss. Exactly the same as CornerNet. Runs faster and costs a little bit more memory Arguments: pred (batch x c x h x w) gt_regr (batch x c x h x w) ''' pos_inds = gt.eq(1).float() neg_inds = gt.lt(1).float() neg_weights = torch.pow(1 - gt, 4) loss = 0 pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds num_pos = pos_inds.float().sum() pos_loss = pos_loss.sum() neg_loss = neg_loss.sum() if num_pos == 0: loss = loss - neg_loss else: loss = loss - (pos_loss + neg_loss) / num_pos return loss
def _topk(scores, K=40):
    """Pick the K highest-scoring heatmap locations over all categories.

    Args:
        scores: heatmap tensor of shape (batch, cat, height, width).
        K: number of peaks to keep per image.

    Returns:
        ``(topk_score, topk_inds, topk_clses, topk_ys, topk_xs)``, each of
        shape (batch, K): peak scores, flattened spatial indices
        (``y * width + x``), category ids, and the integer-valued row and
        column of each peak (stored as floats).

    The function previously had no ``return`` at all, so callers that
    unpack five values (e.g. ``whole_body_decode``) failed.
    """
    batch, cat, height, width = scores.size()

    # Stage 1: top-K locations within each category.
    topk_scores, topk_inds = torch.topk(scores.view(batch, cat, -1), K)  # batch x cat x K
    topk_inds = topk_inds % (height * width)
    # Convert flat indices to (row, col).
    topk_ys = (topk_inds / width).int().float()
    topk_xs = (topk_inds % width).int().float()

    # Stage 2: top-K over the cat * K candidates of stage 1.
    topk_score, topk_ind = torch.topk(topk_scores.view(batch, -1), K)  # batch x K
    # Candidates are laid out category-major, so index // K is the category.
    topk_clses = (topk_ind / K).int()
    topk_inds = _gather_feat(topk_inds.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_ys = _gather_feat(topk_ys.view(batch, -1, 1), topk_ind).view(batch, K)
    topk_xs = _gather_feat(topk_xs.view(batch, -1, 1), topk_ind).view(batch, K)

    return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
Post-processing
def whole_body_decode(
        heat, wh, kps, seg_feat=None, seg=None, reg=None, hm_hp=None,
        hp_offset=None, K=100):
    """Decode head outputs into person boxes, keypoints and instance masks.

    Args:
        heat: centre heatmap (batch, cat, height, width); assumed to be
            sigmoid-activated already -- TODO confirm at the call site.
        wh: box width/height map.
        kps: centre-to-joint offset map with ``2 * num_joints`` channels.
        seg_feat: shared mask feature map convolved with per-instance kernels.
        seg: map of per-instance dynamic 3x3 conv kernels.
        reg: optional sub-pixel centre offset map.
        hm_hp: optional joint heatmaps used to snap regressed joints to
            nearby heatmap peaks.
        hp_offset: optional sub-pixel joint offset map.
        K: number of detections to keep.

    Returns:
        ``(detections, pred_seg)``.

        ``detections`` has shape ``(batch, K, 4 + 1 + 2*J + J)``: the box
        ``(x1, y1, x2, y2)``, the centre score, J ``(x, y)`` joint pairs and
        J per-joint scores. Per-joint scores are all 0 when ``hm_hp`` is
        None. (That case used to raise NameError.)

        ``pred_seg`` is None unless both ``seg`` and ``seg_feat`` are given.
        (The previous code crashed on the advertised ``None`` defaults.)
    """
    batch, cat, height, width = heat.size()
    num_joints = kps.shape[1] // 2

    # Keep only local maxima of the centre heatmap, then take the top K.
    heat = _nms(heat)
    scores, inds, _clses, ys, xs = _topk(heat, K=K)

    # Centre-relative joint offsets -> absolute joint coordinates.
    kps = _transpose_and_gather_feat(kps, inds)
    kps = kps.view(batch, K, num_joints * 2)
    kps[..., ::2] += xs.view(batch, K, 1).expand(batch, K, num_joints)
    kps[..., 1::2] += ys.view(batch, K, 1).expand(batch, K, num_joints)

    if reg is not None:
        reg = _transpose_and_gather_feat(reg, inds)
        reg = reg.view(batch, K, 2)
        xs = xs.view(batch, K, 1) + reg[:, :, 0:1]
        ys = ys.view(batch, K, 1) + reg[:, :, 1:2]
    else:
        xs = xs.view(batch, K, 1) + 0.5
        ys = ys.view(batch, K, 1) + 0.5
    wh = _transpose_and_gather_feat(wh, inds)
    wh = wh.view(batch, K, 2)

    pred_seg = None
    if seg is not None and seg_feat is not None:
        weight = _transpose_and_gather_feat(seg, inds)
        # Each detection's gathered vector becomes one 3x3 conv kernel, so
        # detection k yields output channel k. The view folds any batch
        # dimension into input channels. NOTE(review): this looks correct
        # only for batch == 1 -- confirm.
        # (A 1x1 kernel could be chosen instead when
        # weight.size(1) == seg_feat.size(1).)
        weight = weight.view([weight.size(1), -1, 3, 3])
        pred_seg = F.conv2d(seg_feat, weight, stride=1, padding=1)

    scores = scores.view(batch, K, 1)
    bboxes = torch.cat([xs - wh[..., 0:1] / 2,
                        ys - wh[..., 1:2] / 2,
                        xs + wh[..., 0:1] / 2,
                        ys + wh[..., 1:2] / 2], dim=2)

    if hm_hp is not None:
        hm_hp = _nms(hm_hp)
        thresh = 0.1
        kps = kps.view(batch, K, num_joints, 2).permute(
            0, 2, 1, 3).contiguous()  # b x K x 2J => b x J x K x 2
        # Regressed joints repeated along dim 3: reg_kps[b, j, k, m] = kps[b, j, k].
        reg_kps = kps.unsqueeze(3).expand(batch, num_joints, K, K, 2)
        # Top-K peaks of every joint heatmap.
        hm_score, hm_inds, hm_ys, hm_xs = _topk_channel(hm_hp, K=K)  # b x J x K
        if hp_offset is not None:
            hp_offset = _transpose_and_gather_feat(
                hp_offset, hm_inds.view(batch, -1))
            hp_offset = hp_offset.view(batch, num_joints, K, 2)
            hm_xs = hm_xs + hp_offset[:, :, :, 0]
            hm_ys = hm_ys + hp_offset[:, :, :, 1]
        else:
            hm_xs = hm_xs + 0.5
            hm_ys = hm_ys + 0.5
        # Peaks at or below the threshold get score -1 and are moved far
        # away so they are never the nearest match.
        mask = (hm_score > thresh).float()
        hm_score = (1 - mask) * -1 + mask * hm_score
        hm_ys = (1 - mask) * (-10000) + mask * hm_ys
        hm_xs = (1 - mask) * (-10000) + mask * hm_xs
        # Heatmap peaks repeated along dim 2: hm_kps[b, j, k, m] = peak m.
        hm_kps = torch.stack([hm_xs, hm_ys], dim=-1).unsqueeze(
            2).expand(batch, num_joints, K, K, 2)
        # For every regressed joint, find the nearest heatmap peak.
        dist = (((reg_kps - hm_kps) ** 2).sum(dim=4) ** 0.5)
        min_dist, min_ind = dist.min(dim=3)  # b x J x K
        hm_score = hm_score.gather(2, min_ind).unsqueeze(-1)  # b x J x K x 1
        min_dist = min_dist.unsqueeze(-1)
        min_ind = min_ind.view(batch, num_joints, K, 1, 1).expand(
            batch, num_joints, K, 1, 2)
        hm_kps = hm_kps.gather(3, min_ind)
        hm_kps = hm_kps.view(batch, num_joints, K, 2)
        left = bboxes[:, :, 0].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
        top = bboxes[:, :, 1].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
        right = bboxes[:, :, 2].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
        bottom = bboxes[:, :, 3].view(batch, 1, K, 1).expand(batch, num_joints, K, 1)
        # Keep the regressed joint when the snapped peak lies outside the box,
        # scores below the threshold, or is farther than 0.3 * the box's longer side.
        mask = (hm_kps[..., 0:1] < left) + (hm_kps[..., 0:1] > right) + \
               (hm_kps[..., 1:2] < top) + (hm_kps[..., 1:2] > bottom) + \
               (hm_score < thresh) + \
               (min_dist > (torch.max(bottom - top, right - left) * 0.3))
        mask = (mask > 0).float().expand(batch, num_joints, K, 2)
        kps = (1 - mask) * hm_kps + mask * kps
        kps = kps.permute(0, 2, 1, 3).contiguous().view(
            batch, K, num_joints * 2)
        joint_scores = torch.transpose(hm_score.squeeze(dim=3), 1, 2)  # b x K x J
    else:
        # No joint heatmap: keep regressed joints, report zero joint scores
        # so the detection layout stays the same.
        joint_scores = kps.new_zeros((batch, K, num_joints))

    detections = torch.cat([bboxes, scores, kps, joint_scores], dim=2)
    return detections, pred_seg
