当前位置:   article > 正文

论文解读:Fine-grained Visual Classification with High-temperature Refinement and Background Suppression

fine-grained visual classification with high-temperature refinement and back

《Fine-grained Visual Classification with High-temperature Refinement and Background Suppression》

摘要

细粒度的视觉分类是一项具有挑战性的任务,因为不同类别之间存在着高度相似性,同时同一类别内的数据也存在着明显差异。为了解决这些挑战,本文提出了一种新的网络模型,称为“高温细化和背景抑制”(HERBS),该模型由高温细化模块和背景抑制模块组成。高温细化模块允许网络学习适当的特征尺度,并提高各种特征的表征能力,而背景抑制模块则利用分类置信度将特征图分为前景和背景,并抑制低置信度区域中的特征值,从而提高了模型的判别能力

实验结果表明,HERBS模型有效地融合了不同尺度的特征,抑制了背景噪声,以及适当尺度的判别特征对细粒度视觉分类任务具有重要意义。该方法在CUB-200-2011和NABirds数据集上均实现了超过93%的准确率,为解决细粒度视觉分类问题提供了有效的解决方案。

细粒度视觉分类综述

细粒度视觉分类(FGVC)是计算机视觉中的一项具有挑战性的任务,涉及将图像分类为非常具体和详细的​​类别,例如不同种类的鸟类、狗、车辆模型和医学图像。

如下图所示,这四种麻雀看上去几乎一模一样,但从不同的角度看,同一种麻雀看起来也有很大不同。与涉及识别“动物”或“车辆”等宽泛类别的粗粒度分类相反,细粒度分类需要能够识别视觉特征的细微差异,例如颜色、纹理、形状和图案,这些特征通常存在于小区域。这些区域被称为判别区域或前景区域。

在这里插入图片描述

由于深度卷积网络能够学习到非常鲁棒的图像特征表示,对图像进行细粒度分类的方法,大多都是以深度卷积网络为基础的,这些方法大致可以分为以下四个方向:

1.基于常规图像分类网络的微调方法:

这种方法使用预先训练好的深度卷积网络(如在 ImageNet 上训练的模型)作为基础,然后通过微调(fine-tuning)来适应特定的细粒度分类任务。微调通常涉及冻结预训练网络的底层,并在顶层进行训练以适应新数据集。这种方法简单直接,并且通常在训练数据较少的情况下表现良好。

2.基于细粒度特征学习的方法:

这种方法专注于从数据中学习特定类别的细粒度特征。这可能涉及使用卷积网络的中间层特征来捕获更细粒度的信息,或者设计网络结构以更好地区分类别之间的微小差异。通常,这种方法需要更多的数据和计算资源来训练,但在处理复杂的细粒度分类任务时效果往往更好。

3.基于目标块的检测和对齐方法:

这种方法首先通过目标检测算法定位图像中的感兴趣区域(ROI),然后对每个ROI进行对齐和细粒度分类。对齐是指将目标区域与周围环境对齐,以减少背景干扰并突出细微特征。这种方法通常需要明确定位目标并进行有效的对齐,但可以在处理具有复杂背景和多个对象的图像时表现良好。

  1. 基于视觉注意机制的方法:

这种方法受到人类视觉系统的启发,通过模拟注意力机制来选择和聚焦于图像中最重要的区域。这些关注区域可以包含有助于区分类别的重要细节。这种方法可以在网络中引入注意力模块,使网络能够自动学习在分类时关注哪些区域,并且在处理大型图像数据库时具有很强的可扩展性。

本文介绍的HERBS模型通过引入高温细化和背景抑制模块,有效地提高了细粒度视觉分类任务的性能。高温细化模块允许网络学习适当的特征尺度,而背景抑制模块则通过抑制背景噪声提高了模型的判别能力。这些方法有效地解决了传统方法中存在的一些问题,提高了模型的鲁棒性和分类准确性。。

相关工作

细粒度视觉分类

在 FGVC 领域,有两种从细微区域提取判别特征的方法,大致分为基于对象部分的方法和基于注意力的方法。

基于对象部分的方法旨在通过使用模型生成候选区域来找到用于识别的对象局部区域,然后从中提取区分特征。 MA-CNN通过将特征图聚类到对象部分来同时训练定位和分类精度。这种无监督分类通过将模式划分为对象部分来增强特征学习。该方法允许同时学习有区别的特征和位置。 S3N在特征图上找到每个类别响应的局部极值以增强特征。此外,WS-DAN通过消除局部极值来增强数据,以发现其他判别性特征。

另一方面,基于注意力的方法使用注意力机制来增强特征学习并定位对象细节。 MAMC生成多组通过注意力机制增强的特征,Cross-X[23]使用多激励模型的注意力图来学习不同类别的特征。 API-Net[52]和PCA-Net[44]使用两个图像作为输入来计算特征图之间的注意力,以增强判别性表示。 CAP[1]计算输出特征的自注意力图来表达特征像素之间的关系,SR-GNN[2]使用图卷积神经网络来描述部分之间的关​​系。 CAL[25]在注意力图中添加了反事实干预来预测类别。随着Transformer[32]在计算机视觉领域的发展,许多改进的Vision Transformer架构被提出,例如FFVT[35]、SIM-Trans[29]、TransFG[9]和AFTrans[45],这些方法利用转换器层中的自注意力图来增强特征学习并定位对象细节。

方法

HERBS网络结构如下:
在这里插入图片描述

在图中,所提出的高温细化和背景抑制(HERBS)网络由主干网、自上而下的特征融合模块、自下而上的特征融合模块和HERBS组成。主干可以是基于 Transformer 的模型(例如 Swin Transformer)或基于卷积的模型(例如 ResNet)。自上而下和自下而上的特征融合模块类似于路径聚合网络(PA),可以将其视为具有附加自下而上路径的特征金字塔网络(FPN)。

下图为HERBS拆开二个模块的最简答示意图:
在这里插入图片描述

其中PANet的整体结构:
在这里插入图片描述
PANet的网络结构如上图所示,它由5个核心模块组成。

其中(a)是一个FPN,(b)是PAN增加的自底向上的特征融合层,(c)是自适应特征池化层,(d)是PANet的bounding box预测头,(e)是用于预测掩码的全连接融合层。

背景抑制(BS)

背景抑制(Background Suppression,BS)模块的设计旨在从特征图中提取判别性特征,并抑制不重要的背景信息。

在这里插入图片描述

该模块步骤按照上图的结构来看首先

生成分类图: 从主干网络的特征图生成分类图,其中每个位置对应于不同类别的得分。公式如下:

Y i = W i h s i + b i ( 1 ) Y_i = W_ihs_i +b_i(1) Yi=Wihsi+bi(1)

其中:

  • Y i Y_i Yi是分类图,维度用 C g t × H i × W i C_{gt} \times H_i \times W_i Cgt×Hi×Wi,表示第 i i i分类得分图
  • W i W_i Wi是第 i i i层的分类器权重。
  • b i b_i bi是第 i i i层的偏差。
  • h s i hs_i hsi是由第 i i i层主干块生成的特征图 C i × H i × W i C_i \times H_i \times W_i Ci×Hi×Wi
  • C g t C_{gt} Cgt是目标类别的数量
  • H i H_i Hi W i W_i Wi分别表示特征图的高度和宽度。

选择前景特征: 根据分类图选择预测得分最高的一组特征,以聚合并提取判别性特征。

P m a x , i = m a x ( S o f t m a x ( Y i ) ) ( 2 ) P_{max,i}=max(Softmax(Y_i))(2) Pmax,i=max(Softmax(Yi))(2)
其中 P m a x . i P_{max.i} Pmax.i表示第 i i i层的最大分数图。接下来,选择所有预测中得分最高的 K i K_i Ki特征。

合并特征预测: 使用图卷积模块合并所选特征,并基于合并的特征进行分类预测。这一步旨在增强所选区域的判别特征。

在这个阶段,BS模块具有未选择的分类图,称为丢弃图,表示为 Y d Y_d Yd ,如下图所示
在这里插入图片描述
剩下有特征的图合并进行分类预测,表示为 Y m Y_m Ym。如下图黄色的部分

在这里插入图片描述
合并分类预测的目标函数是标准分类目标函数,使用交叉熵来计算预测分布 P m P_m Pm与地面真实标签 y y y之间的相似度。合并损失计算如下:

P m = S o f t m a x ( Y m ) ( 3 ) P_m = Softmax(Y_m)(3) Pm=Softmax(Ym)(3)
l o s s m = − ∑ c i = 1 c g t y c i l o g ( P m , c i ) ( 4 ) loss_m=-\sum_{ci=1}^{c_{gt}}y_cilog(P_m,ci)(4) lossm=ci=1cgtycilog(Pm,ci)(4)
这里, y c i y_{ci} yci i t h i^{th} ith 类的基本事实, P m , c i P_{m,ci} Pm,ci i t h i^{th} ith类的预测概率。对目标类别 C g t C_{gt} Cgt的数量执行求和。这增强了所选区域中的辨别特征。

抑制背景特征:BS 模块的另一个目标是抑制丢弃的地图中的特征并增加前景和背景之间的间隙。对未选择的特征图应用双曲正切函数,以抑制背景信息并增强前景特征的对比度。

具体公式如下:

P d = t a n h ( Y d ) ( 5 ) P_d=tanh(Y_d)(5) Pd=tanh(Yd)(5)

其中 P d P_d Pd 表示抑制后的特征图,维度与 Y d Y_d Yd相同。

丢弃的损失 l o s s d loss_d lossd计算为预测与伪目标 -1 之间的均方误差,如方程(6)中所定义:

l o s s d = ∑ i = 1 c g t ( P d , c i + 1 ) 2 ( 6 ) loss_d=\sum_{i=1}^{c_{gt}}(P_{d,ci}+1)^2(6) lossd=i=1cgt(Pd,ci+1)2(6)

为了防止所有块的特征图仅在相同位置有特征响应,还将每层的预测合并到训练目标中平均下来计算,如下所示:
P l i = S o f t m a x ( W i ( A v g p o o l ( h s i ) ) + b i ) ( 7 ) P_{li}=Softmax(W_i(Avgpool(hs_i))+b_i)(7) Pli=Softmax(Wi(Avgpool(hsi))+bi)(7)
l o s s l = − ∑ i = 1 n ∑ c i = 1 C g t y c i l o g ( P l i , c i ) ( 8 ) loss_l=-\sum_{i=1}^n\sum_{ci=1}^{C_{gt}}y_{ci}log(P_{li,ci})(8) lossl=i=1nci=1Cgtycilog(Pli,ci)(8)
其中 Avgpool 函数聚合每个通道的所有 H i H_i Hi W i W_i Wi ,主干中的块数由 n n n表示。

损失函数设计总的 BS 目标由合并损失 ( l o s s m loss_m lossm)、丢弃损失 ( l o s s d loss_d lossd ) 和平均层损失 ( l o s s l loss_l lossl ) 的加权和给出,

如式(9)所示:
l o s s b s = λ m l o s s m + λ d l o s s d + λ l l o s s i ( 9 ) loss_{bs}=\lambda_mloss_m+\lambda_dloss_d+\lambda_lloss_i(9) lossbs=λmlossm+λdlossd+λllossi(9)
其中 λ m 、 λ d \lambda_m、\lambda_d λmλd λ l \lambda_l λl分别是合并损失、丢弃损失和平均层损失的权重。具体来说,我们将 λ m \lambda_m λm设置为 1, λ d \lambda_d λd设置为 5, λ l \lambda_l λl 设置为 0.3。这些值的设置是为了平衡前景和背景损失,并根据前三个时期的训练损失确定。

高温精炼细化

在高温细化模块中,目标是通过使自上而下的分类器 K 1 K_1 K1学习自下而上的分类器 K 2 K_2 K2 的输出分布来提高模型的性能。如下面代码的意思

if self.use_fpn:
            self.fpn_down = FPN(outs, fpn_size, proj_type, upsample_type)
            self.build_fpn_classifier_down(outs, fpn_size, num_classes)
            self.fpn_up = FPN_UP(outs, fpn_size)
            self.build_fpn_classifier_up(outs, fpn_size, num_classes)
  • 1
  • 2
  • 3
  • 4
  • 5

具体地,我们将分类器 k 1 k_1 k1的输出定义为 Y i 1 Y_{i1} Yi1 ,将分类器 k 2 k_2 k2的输出定义为 Y i 2 Y_{i2} Yi2 。通过细化目标函数,我们希望在早期层中学习更多样化、更强的表示,同时允许后面的层专注于更精细的细节。换句话说,高温细化模块使分类器 K 1 K_1 K1能够探索更广阔的领域,从而使分类器 K 2 K_2 K2专注于学习细粒度和判别性的特征。

细化目标函数的计算包括以下步骤:

计算 k 1 k_1 k1的输出 P i 1 P_i1 Pi1
P i 1 = L o g S o f t m a x ( Y i 1 / T e ) ( 10 ) P_{i1}=LogSoftmax(Y_{i1/T_e})(10) Pi1=LogSoftmax(Yi1/Te)(10)
计算 k 2 k_2 k2的输出 P i 2 P_i2 Pi2
P i 2 = S o f t m a x ( Y i 2 / T e ) ( 11 ) P_{i2}=Softmax(Y_{i2}/T_e)(11) Pi2=Softmax(Yi2/Te)(11)
计算细化损失 l o s s r loss_r lossr
l o s s r = P i 2 l o g ( P i 2 P i 1 ) ( 12 ) loss_r=P_{i2}log(\frac{P_{i2}}{P_{i1}})(12) lossr=Pi2log(Pi1Pi2)(12)
其中 T e T_e Te表示训练时期 e e e 的温度。 T e T_e Te的值随着训练时期的增加而减小,遵循定义为的衰减函数:

在这里插入图片描述
具体实现代码如下:

outs = model(datas)

            loss = 0.
            for name in outs:
                
                if "FPN1_" in name:
                    if args.lambda_b0 != 0:
                        aux_name = name.replace("FPN1_", "")
                        gt_score_map = outs[aux_name].detach()
                        thres = torch.Tensor(model.selector.thresholds[aux_name])
                        gt_score_map = suppression(gt_score_map, thres, temperature)
                        logit = F.log_softmax(outs[name] / temperature, dim=-1)
                        loss_b0 = nn.KLDivLoss()(logit, gt_score_map)
                        loss += args.lambda_b0 * loss_b0
                    else:
                        loss_b0 = 0.0

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

HERBS 的总损失可以表示为:

l o s s h e r b s = l o s s b s + λ r l o s s r ( 14 ) loss_{herbs}=loss_{bs}+\lambda_rloss_r(14) lossherbs=lossbs+λrlossr(14)
其中 λ r \lambda_r λr是细化损失的权重,设置为1。HERBS网络的最终输出是九个分类器结果之和的softmax,其中四个来自自上而下的方法,四个来自自上而下的方法自下而上的方法,以及一种来自组合器的方法。

请注意,在 HERBS 网络中,当 i i i等于 k k k 时, W i W_i Wi b i b_i bi 属于分类器 k 2 k_2 k2。我们单独描述它们是因为BS模块和高温细化模块可以单独应用于主干,这是非常灵活的。实验结果表明,两个模块都能提高准确率。当然,当使用整个 HERBS 网络时,模型的能力将带来更好的性能。

在深层网络中,底层和顶层的特征表征是不同的。底层特征主要包含图像的低级信息,例如边缘和纹理,而顶层特征则包含更加抽象和语义化的信息,例如物体的形状和类别。

通过让自上而下的分类器 k 1 k_1 k1学习自下而上的分类器 k 2 k_2 k2 的输出分布,可以使得底层特征更加丰富和多样化,从而提高模型对于输入图像的表示能力。这样做的好处在于,底层特征的丰富性可以为后续的分类任务提供更多的信息,从而使得模型更加准确地区分不同的类别

代码

主函数加载主干模型,并在下面分支应用HERBS

def build_resnet50(pretrained: str = "./resnet50_miil_21k.pth",
                   return_nodes: Union[dict, None] = None,
                   num_selects: Union[dict, None] = None, 
                   img_size: int = 448,
                   use_fpn: bool = True,
                   fpn_size: int = 512,
                   proj_type: str = "Conv",
                   upsample_type: str = "Bilinear",
                   use_selection: bool = True,
                   num_classes: int = 200,
                   use_combiner: bool = True,
                   comb_proj_size: Union[int, None] = None):
    
    import timm
    
    if return_nodes is None:
        return_nodes = {
            'layer1.2.act3': 'layer1',
            'layer2.3.act3': 'layer2',
            'layer3.5.act3': 'layer3',
            'layer4.2.act3': 'layer4',
        }
    if num_selects is None:
        num_selects = {
            'layer1':32,
            'layer2':32,
            'layer3':32,
            'layer4':32
        }
    
    backbone = timm.create_model('resnet50', pretrained=False, num_classes=11221)
    ### original pretrained path "./models/resnet50_miil_21k.pth"
    if pretrained != "":
        backbone = load_model_weights(backbone, pretrained)

    # print(backbone)
    # print(get_graph_node_names(backbone))
    
    return pim_module.PluginMoodel(backbone = backbone,
                                   return_nodes = return_nodes,
                                   img_size = img_size,
                                   use_fpn = use_fpn,
                                   fpn_size = fpn_size,
                                   proj_type = proj_type,
                                   upsample_type = upsample_type,
                                   use_selection = use_selection,
                                   num_classes = num_classes,
                                   num_selects = num_selects, 
                                   use_combiner = num_selects,
                                   comb_proj_size = comb_proj_size)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50

HERBS代码如下

class PluginMoodel(nn.Module):
    def __init__(self, 
                 backbone: torch.nn.Module,  # 主干网络,推荐在 ImageNet 或 IG-3.5B-17k 上预训练的模型
                 return_nodes: Union[dict, None],  # 返回节点,用于指定输出的键值对应关系,如果使用 Swin-Transformer,则设置为 None
                 img_size: int,  # 图像尺寸
                 use_fpn: bool,  # 是否使用特征金字塔网络
                 fpn_size: Union[int, None],  # 特征金字塔网络的投影维度
                 proj_type: str,  # 投影类型
                 upsample_type: str,  # 上采样类型
                 use_selection: bool,  # 是否使用选择器
                 num_classes: int,  # 类别数量
                 num_selects: dict,  # 每个选择层的特征数量
                 use_combiner: bool,  # 是否使用组合器
                 comb_proj_size: Union[int, None]  # 组合器的投影尺寸
                 ):
        """
        初始化函数,用于创建 PluginModel 的实例。

        Args:
            backbone: 主干网络,推荐在 ImageNet 或 IG-3.5B-17k 上预训练的模型。
            return_nodes: 返回节点,用于指定输出的键值对应关系,如果使用 Swin-Transformer,则设置为 None。
            img_size: 图像尺寸。
            use_fpn: 是否使用特征金字塔网络。
            fpn_size: 特征金字塔网络的投影维度。
            proj_type: 投影类型。
            upsample_type: 上采样类型。
            use_selection: 是否使用选择器。
            num_classes: 类别数量。
            num_selects: 每个选择层的特征数量。
            use_combiner: 是否使用组合器。
            comb_proj_size: 组合器的投影尺寸。
        """
        super(PluginMoodel, self).__init__()
        
        ### = = = = = Backbone = = = = =
        self.return_nodes = return_nodes
        if return_nodes is not None:
            self.backbone = create_feature_extractor(backbone, return_nodes=return_nodes)
        else:
            self.backbone = backbone
        
        ### 获取隐藏特征的尺寸
        rand_in = torch.randn(1, 3, img_size, img_size)
        outs = self.backbone(rand_in)

        ### 如果仅使用原始主干网络
        if not use_fpn and (not use_selection and not use_combiner):
            for name in outs:
                fs_size = outs[name].size()
                if len(fs_size) == 3:
                    out_size = fs_size.size(-1)
                elif len(fs_size) == 4:
                    out_size = fs_size.size(1)
                else:
                    raise ValueError("The size of output dimension of previous must be 3 or 4.")
            self.classifier = nn.Linear(out_size, num_classes)

        ### = = = = = FPN = = = = =
        self.use_fpn = use_fpn
        if self.use_fpn:
            self.fpn_down = FPN(outs, fpn_size, proj_type, upsample_type)
            self.build_fpn_classifier_down(outs, fpn_size, num_classes)
            self.fpn_up = FPN_UP(outs, fpn_size)
            self.build_fpn_classifier_up(outs, fpn_size, num_classes)

        self.fpn_size = fpn_size

        ### = = = = = Selector = = = = =
        self.use_selection = use_selection
        if self.use_selection:
            w_fpn_size = self.fpn_size if self.use_fpn else None
            self.selector = WeaklySelector(outs, num_classes, num_selects, w_fpn_size)

        ### = = = = = Combiner = = = = =
        self.use_combiner = use_combiner
        if self.use_combiner:
            assert self.use_selection, "Please use selection module before combiner"
            if self.use_fpn:
                gcn_inputs, gcn_proj_size = None, None
            else:
                gcn_inputs, gcn_proj_size = outs, comb_proj_size
            total_num_selects = sum([num_selects[name] for name in num_selects])
            self.combiner = GCNCombiner(total_num_selects, num_classes, gcn_inputs, gcn_proj_size, self.fpn_size)

    def build_fpn_classifier_up(self, inputs: dict, fpn_size: int, num_classes: int):
        """
        构建 FPN 上采样部分的分类器。

        Args:
            inputs: 输入特征。
            fpn_size: FPN 的投影维度。
            num_classes: 类别数量。
        """
        for name in inputs:
            m = nn.Sequential(
                    nn.Conv1d(fpn_size, fpn_size, 1),
                    nn.BatchNorm1d(fpn_size),
                    nn.ReLU(),
                    nn.Conv1d(fpn_size, num_classes, 1)
                )
            self.add_module("fpn_classifier_up_"+name, m)

    def build_fpn_classifier_down(self, inputs: dict, fpn_size: int, num_classes: int):
        """
        构建 FPN 下采样部分的分类器。

        Args:
            inputs: 输入特征。
            fpn_size: FPN 的投影维度。
            num_classes: 类别数量。
        """
        for name in inputs:
            m = nn.Sequential(
                    nn.Conv1d(fpn_size, fpn_size, 1),
                    nn.BatchNorm1d(fpn_size),
                    nn.ReLU(),
                    nn.Conv1d(fpn_size, num_classes, 1)
                )
            self.add_module("fpn_classifier_down_" + name, m)

    def forward_backbone(self, x):
        """
        主干网络的前向传播。

        Args:
            x: 输入数据。

        Returns:
            主干网络的输出特征。
        """
        return self.backbone(x)

    def fpn_predict_down(self, x: dict, logits: dict):
        """
        FPN 下采样部分的特征预测。

        Args:
            x: 输入特征。
            logits: 存储预测结果的字典。
        """
        for name in x:
            if "FPN1_" not in name:
                continue 
            ### 预测每个特征点的结果
            if len(x[name].size()) == 4:
                B, C, H, W = x[name].size()
                logit = x[name].view(B, C, H*W)
            elif len(x[name].size()) == 3:
                logit = x[name].transpose(1, 2).contiguous()
            model_name = name.replace("FPN1_", "")
            logits[name] = getattr(self, "fpn_classifier_down_" + model_name)(logit)
            logits[name] = logits[name].transpose(1, 2).contiguous()

    def fpn_predict_up(self, x: dict, logits: dict):
        """
        FPN 上采样部分的特征预测。

        Args:
            x: 输入特征。
            logits: 存储预测结果的字典。
        """
        for name in x:
            if "FPN1_" in name:
                continue
            ### 预测每个特征点的结果
            if len(x[name].size()) == 4:
                B, C, H, W = x[name].size()
                logit = x[name].view(B, C, H*W)
            elif len(x[name].size()) == 3:
                logit = x[name].transpose(1, 2).contiguous()
            model_name = name.replace("FPN1_", "")
            logits[name] = getattr(self, "fpn_classifier_up_" + model_name)(logit)
            logits[name] = logits[name].transpose(1, 2).contiguous()

    def forward(self, x: torch.Tensor):
        """
        模型的前向传播。

        Args:
            x: 输入数据。

        Returns:
            若使用组合器,则返回组合器的输出;若使用选择器或 FPN,则返回预测结果;否则返回分类器的输出。
        """
        logits = {}

        x = self.forward_backbone(x)

        if self.use_fpn:
            x = self.fpn_down(x)
            self.fpn_predict_down(x, logits)
            x = self.fpn_up(x)
            self.fpn_predict_up(x, logits)

        if self.use_selection:
            selects = self.selector(x, logits)

        if self.use_combiner:
            comb_outs = self.combiner(selects)
            logits['comb_outs'] = comb_outs
            return logits
        
        if self.use_selection or self.fpn:
            return logits

        ### 原始主干网络(仅预测最终选择层)
        for name in x:
            hs = x[name]

        if len(hs.size()) == 4:
            hs = F.adaptive_avg_pool2d(hs, (1, 1))
            hs = hs.flatten(1)
        else:
            hs = hs.mean(1)
        out = self.classifier(hs)
        logits['ori_out'] = logits

        return logits
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/602672
推荐阅读
  

闽ICP备14008679号