当前位置:   article > 正文

【多模态】19、RegionCLIP | 基于 Region 来实现视觉语言模型预训练

regionclip

在这里插入图片描述


论文: RegionCLIP: Region-based Language-Image Pretraining

代码:https://github.com/microsoft/RegionCLIP

出处:CVPR2022 Oral | 微软 | 张鹏川

一、背景

近期,视觉-语言模型取得了很大的突破,如 CLIP 和 ALIGN,这些模型使用了极大的图文对儿来学习图像和文本的匹配,并且在很多无手工标签的情况下也取得了很好的效果。

为了探索这种思路能否在 region-caption 的情况下起作用,作者基于预训练好的 CLIP 模型构建了一个 R-CNN 形式的目标检测器。

效果和现状:

  • paperwithcode 开放词汇目标检测 COCO(Novel) 排名第 4 名
  • 代码相对较完善,容易入手

在这里插入图片描述

主要思路:

  • 先从输入图像中抠出候选区域
  • 然后使用 CLIP 模型将抠出的区域和 text embedding 进行匹配

在这里插入图片描述

  • 图 1 a-b 展示了在 LVIS 上的结果,当使用 proposal 作为输入时,CLIP 的得分无法代码定位的质量,可以看出不准的框得分为 65%,较准的框得分为 55%。
  • 图 1b 中对比了使用 gt 框作为输入,CLIP 在 LVIS 框上的分类准确率只有 19%
  • 所以,直接将预训练好的 CLIP 拿来用于对 region 的分类不太适合

作者想探索一下这种差别来源于哪里?

  • 首先可以想到,CLIP 模型的训练是使用整个 image 作为输入的,使用的是 image-level 的文本描述来训练的,所以,模型学习到的是整张图的特征
  • 所以这种模型无法将文本概念和图像中的区域联系起来

本文如何解决 image 和 region 之间的差距:

  • 作者通过使用 vision-language 预训练的模型来探索如何学习 region 的表达
  • 主要思想是在预训练过程中,将 image region 和 text token 进行对齐

面临的问题:

  • image-text pairs 中不包含 image region 和 text token 的对齐关系
  • 整张图的文本描述是不全的,也就是图中的有些目标是没有体现在文本描述中的

二、方法

如何预训练:

  • 由于 CLIP 缺少 region 层面的训练,所以 RegionCLIP 构建了一些 region 的伪标签来和 image-text 一起预训练
  • 伪标签如何构建:从网络数据中收集图像描述语句,然后使用 NLP parser 来提取出有效的目标词汇,构建词汇池,然后将词汇池的每个词都填入 prompt 模版(a photo of kite),并且对每个词汇对应的 prompt 模版使用 CLIP 的 text encoder 来得到语义特征,所有的 region concept 都能够使用 semantic embedding 来表示了
  • 构建的 region 伪标签如何和 region 进行对应:使用 CLIP 的 visual encoder V t V_t Vt 来提取每个 region 的 visual feature,计算其和词汇池中的向量的距离,得分最大的向量就被分配给该 region 了,就能得到每个 region 对应的伪标签了,之后用于 pretraining
  • 联合预训练:将 images-concept 和 region-concept 的数据联合进行预训练,训练的时候会同时使用对比学习 loss 和蒸馏 loss
    • region-text 对比学习 loss:计算学生模型学习的 region-text pairs 的相似度
    • 蒸馏 loss:计算教师模型和学生模型得到的 region-text 的 matching score
    • image-text 对比学习 loss:从网络得到的 image-text pairs

如何 zero-shot 推理:

  • 经过预训练的 visual encoder 可以直接用于 region reasoning,也就是对区域内容的识别(注意是没有定位能力的呦!)
  • 例如使用预训练好的 RPN 网络来提取目标,就可以预测目标的类别
  • 而且,作者也将 RPN objectness score 和 category confidence score 做了平均来作为最终得分,这样能够提升效果

如何迁移到目标检测:

  • 使用上面预训练好的 visual encoder 来初始化目标检测器的 visual backbone
  • 使用预训练好的 RPN 来进行目标的定位
  • 使用人工标注的目标检测数据集中的基础类别(base)来训练 visual backbone,我们可以看到这里 visual backbone 是会训练的,那么怎么保证它预训练时候学习到的权重不被忘记呢,这里借用了 focal scaling 的思想,也就是对参与训练的 base 类别设置一定的权重,缓解模型对之前学习到的 object concept 忘记的情况(尤其是 base 类别很少的时候,模型很容易过拟合,对新类的识别产生影响)
  • 然后通过使用预训练得到的 visual backbone 来提取 proposal 的视觉特征并且和目标的类别进行匹配

在这里插入图片描述

本文的目标是学习一个区域级别的视觉-语义空间,能够覆盖足够丰富的目标词汇且用于开放词汇目标检测

  • 假设文本描述 t 能够描述图像 I 中的区域 r
  • 在视觉-语义空间,从 r 中抽取到的 visual region representation 能够和 text representation 很好的匹配上

总体框架图如图 2:

  • V t V_t Vt :CLIP 的 visual encoder, L L L :CLIP 的 language encoder
  • V V V:本文需要训练的 visual encoder,使用 V t V_t Vt 进行初始化,
  • 我们的目标是训练一个 visual encoder V V V 类实现对 image region 的编码,并且将这些编码和 language encoder 输出的语言编码对齐
  • 为了克服缺少大规模 region 描述的问题,如图 2 底部,作者构建了一个目标词汇池,通过将词汇填入 prompt 来构建 region 的描述,并且借助 teacher encoder V t V_t Vt 来将这些描述和使用图像定位网络得到的图像区域进行对齐
  • 通过使用这些创建的 region-text pairs,visual encoder V V V 就需要通过对比学习和词汇整理来学习将这些 pairs 对齐

2.1 Region-based Language-Image Pretraining

1、Visual region representation

可以使用现有的目标定位器(如 RPN)或密集滑动窗口 来进行图像区域的生成

作者使用经过人工标注 bbox 训练过的 RPN 来生成,这里不对 bbox 的类别进行区分

  • 对于一个输入 batch,使用 RPN 产生 N 个 image regions
  • 使用 visual encoder V V V 进行视觉特征抽取,并使用 RoIAlign 来 pooling,且 V V V 的权重是使用 teacher V t V_t Vt 的来进行初始化的

2、Semantic region representation

一个单个的图像通常会包含丰富的语义信息,多个不同类别的目标,且人工标注这么大规模的数据也不太可行

所以,作者首先构建了一个大的词汇池,来尽可能的覆盖所有区域词汇,如图 2 所示,而且建立的词汇池是从文本语料库中解析得来的

有了词汇池后,按照如下的方式来构建每个区域的语义表达:

  • 第一步,将 concept 填入 prompt 模版(a photo of a kite)
  • 第二步,使用预训练的 language encoder L 来得到语义特征表达
  • 第三步,使用语义编码就能表达每个区域词汇的特征表达 { l j } j = 1 , . . . , C \{l_j\}_{j=1,...,C} {lj}j=1,...,C

3、visual-semantic alignment for regions

① 如何对齐 region-text pairs:使用 CLIP 来构建伪标签,即使用 teacher model CLIP 预测的得分最大的 concept 作为该区域的描述

  • 作者借用 teacher visual encoder 来建立 region-text 之间的关系,这里的 text 表示语义编码,区域 r i r_i ri 的 visual representation v i t v_i^t vit 是从 teacher visual encoder V t V_t Vt 中抽取的

  • 然后,计算 v i t v_i^t vit { l j } \{l_j\} {lj} 的匹配得分,得分最高的就和区域进行关联起来,然后就能得到每个区域的伪标签: { v i , l m } \{v_i, l_m\} {vi,lm}

    在这里插入图片描述

② 如何预训练:

  • 同时使用来自网络的 region-text pairs 和 image-text pairs

  • region-text pairs 就是通过 ① 的方法来创建的

  • 拿到上述 region-text pairs { v i , l m } \{v_i, l_m\} {vi,lm},使用对比学习 loss 和蒸馏 loss 来训练 visual decoder,总共包含 3 部分

    在这里插入图片描述

    • region-text 的对比学习 loss 如下, τ \tau τ 是预定义的温度参数, N r i N_{ri} Nri 是 region r i r_i ri 的 negative textual samples,也就是在一个 batch 中和 region r i r_i ri 不匹配但和其他区域匹配的

      在这里插入图片描述

      在这里插入图片描述

    • 除了对比学习 loss 以外,还有考虑每个图像区域的知识蒸馏,蒸馏 loss 如下, q i t q_i^t qit 是从 teacher model 得到的 soft target, q i q_i qi 是 student model 得到的预测

      在这里插入图片描述

    • image-text 的对比学习 loss L c n t r s t − i m g L_{cntrst-img} Lcntrstimg 可以从 region level 扩展而来,也就是特殊情况,即 ① 一个 box 覆盖了整张图,② 文本描述来源于网络,③ negative samples 是从其他图像而来的文本描述

③ 零样本推理

预训练之后,训练得到的 visual encoder 可以直接用于 region reasoning 任务,比如从 RPN 获得区域,从训练的 visual encoder 得到该区域的视觉表达,然后和文本词汇表达进行匹配,得到相似度最高的文本

实验证明使用 RPN score 能够提升 zero-shot 推理的效果,所以作者也使用了 RPN objectness score + category confidence score 的均值来作为最终的得分,用于匹配。

2.2 目标检测的迁移学习

预训练中,本文的 visual encoder 是从 teacher model 提供的 region-text alignment 中学习的,不需要人为一些操作,所以也会有一个噪声,当引入更强的监督信号(如人为标注 label)时,可以进一步 fine-tuning visual encoder,如图 2

如何将预训练网络迁移到目标检测器呢,作者通过初始化目标检测器的 visual backbone 来实现,先使用现有的 RPN 网络来进行目标区域的定位,然后将区域和文本匹配

开放词汇目标检测:

  • 对基础类别,使用类似于 focal loss 的加权权重 ( 1 − p b ) γ (1-p^b) \gamma (1pb)γ p b p^b pb 是预测的概率, γ \gamma γ 是超参数,该加权权重能缓解模型对预训练中的知识的遗忘,尤其是当数据集中有很少的基础类时(如 coco),作者猜测如果基础类别很少,模型可能会对基础类别过拟合,对新类的泛化能力会降低
  • 对背景类别,作者使用固定的 all-zero 编码方式,并且使用预定义的权重

三、效果

3.1 数据集

预训练时,作者使用:

  • 来自于 Conceptual Caption dataset (CC3M) 的 image-text pairs,包括 300 万来自网络的 pairs
  • COCO Caption(COCO Cap),包含 118k images,每个 images 约有 5 个人工标注的 captions
  • 作者从 COCO/CC3M 中抽取了目标词汇,过滤掉了出现频次小于 100 的词汇,得到了 4764/6790 个词汇

为了开放词汇目标检测的迁移学习,作者使用 COCO 数据集和 LVIS 数据集的基础类来训练。

  • COCO:48 个基础类,17 个新类
  • LVIS:866 个基础类,337 个新类

作者使用目标检测标准测评:AP 和 AP50

  • COCO:使用 AP50 测评新类、基础类、所有类
  • LVIS:rare 类也就是 novel 类,即测评新类的 AP (APr)、基础类的 AP (APc/APf)、所有类的 AP (mAP)

3.2 实现细节

1、预训练

  • teacher model 和 student model :都是预训练的 CLIP(ResNet50)
  • RPN:使用 LVIS 的基础类别训练
  • 默认模型:使用 CC3M 数据集,使用从 COCO Cap 解析出来的词汇
  • 优化器: SGD、batch = 96、learning rate = 0.002, maximum iteration = 600k、 100 regions per image.

2、目标检测迁移

  • 使用 detectron2 基于 Faster RCNN [42] with ResNet50-C4 结构作为检测器
  • RPN:使用目标数据集的基础类别来训练
  • SGD:batch=16,initial learning = 0.002,1x schedule
  • focal scaling: γ = 0.5 \gamma=0.5 γ=0.5

3、目标检测零样本推理

  • RPN:使用 LVIS 的基础类别训练得到的 RPN
  • NMS:threshold=0.9
  • τ = 0.01 \tau=0.01 τ=0.01

3.3 结果

RegionCLIP 在开放词汇检测的 novel 类上的效果达到了 39.3AP

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

四、代码

环境安装:

https://github.com/microsoft/RegionCLIP/blob/zero-shot/docs/INSTALL.md

数据准备:

https://github.com/microsoft/RegionCLIP/blob/zero-shot/datasets/README.md

1、zero-shot 测试(使用 gt 框,主要测试的是对 region 的识别能力)

# RN50x4, GT, COCO
python3 ./tools/train_net.py \
--eval-only  \
--num-gpus 1 \
--config-file ./configs/COCO-InstanceSegmentation/CLIP_fast_rcnn_R_50_C4_ovd_zsinf.yaml \
MODEL.WEIGHTS ./pretrained_ckpt/regionclip/regionclip_pretrained-cc_rn50x4.pth \
MODEL.CLIP.TEXT_EMB_PATH ./pretrained_ckpt/concept_emb/coco_65_cls_emb_rn50x4.pth \
MODEL.CLIP.CROP_REGION_TYPE GT \
MODEL.CLIP.MULTIPLY_RPN_SCORE False \
MODEL.CLIP.TEXT_EMB_DIM 640 \
MODEL.RESNETS.DEPTH 200 \
MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION 18 \
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
Evaluation results for bbox: 
|   AP   |  AP50  |  AP75  |  APs   |  APm   |  APl   |
|:------:|:------:|:------:|:------:|:------:|:------:|
| 64.946 | 65.451 | 64.903 | 52.922 | 71.390 | 70.545 |
Per-category bbox AP: 
| category   | AP     | category   | AP     | category     | AP     |
|:-----------|:-------|:-----------|:-------|:-------------|:-------|
| person     | 73.463 | bicycle    | 79.673 | car          | 84.925 |
| motorcycle | 66.633 | airplane   | 97.827 | bus          | 85.852 |
| train      | 89.867 | truck      | 51.154 | boat         | 82.183 |
| bench      | 42.759 | bird       | 73.442 | cat          | 76.972 |
| dog        | 73.316 | horse      | 83.049 | sheep        | 90.901 |
| cow        | 81.264 | elephant   | 93.201 | bear         | 84.405 |
| zebra      | 95.208 | giraffe    | 95.268 | backpack     | 46.059 |
| umbrella   | 71.930 | handbag    | 33.646 | tie          | 80.318 |
| suitcase   | 65.223 | frisbee    | 24.740 | skis         | 38.639 |
| snowboard  | 14.005 | kite       | 59.969 | skateboard   | 53.792 |
| surfboard  | 44.970 | bottle     | 76.109 | cup          | 64.295 |
| fork       | 53.513 | knife      | 19.605 | spoon        | 30.435 |
| bowl       | 52.905 | banana     | 81.173 | apple        | 67.057 |
| sandwich   | 72.380 | orange     | 68.320 | broccoli     | 91.791 |
| carrot     | 80.076 | pizza      | 87.340 | donut        | 78.442 |
| cake       | 73.971 | chair      | 68.974 | couch        | 57.402 |
| bed        | 56.792 | toilet     | 74.647 | tv           | 65.258 |
| laptop     | 66.412 | mouse      | 23.762 | remote       | 19.321 |
| keyboard   | 50.904 | microwave  | 66.310 | oven         | 52.184 |
| toaster    | 29.439 | sink       | 60.021 | refrigerator | 60.273 |
| book       | 84.199 | clock      | 91.148 | vase         | 62.869 |
| scissors   | 40.443 | toothbrush | 59.085 |              |        |
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

2、fine-tuning COCO 48 个基础类别的代码

数据进入模型的位置:RegionCLIP/detectron2/engine/train_loop.py line273

# data 
[{'file_name': 'datasets/coco/train2017/000000526362.jpg', 'height': 403, 'width': 640, 'image_id': 526362, 'image': tensor([[[67, 70, 74,  ..., 56, 56, 56],
         [68, 71, 74,  ..., 57, 57, 57],
         [70, 72, 74,  ..., 59, 59, 59],
         ...,
         [69, 71, 75,  ..., 88, 83, 80],
         [67, 69, 72,  ..., 82, 77, 74],
         [66, 68, 70,  ..., 79, 74, 71]],

        [[49, 50, 52,  ..., 56, 54, 53],
         [51, 52, 54,  ..., 56, 54, 52],
         [54, 55, 56,  ..., 57, 53, 50],
         ...,
         [66, 69, 74,  ..., 80, 74, 71],
         [63, 65, 69,  ..., 74, 68, 65],
         [61, 63, 66,  ..., 70, 64, 61]],

        [[47, 51, 58,  ..., 31, 34, 36],
         [48, 51, 57,  ..., 33, 37, 40],
         [50, 52, 55,  ..., 37, 43, 46],
         ...,
         [61, 64, 70,  ..., 81, 75, 72],
         [59, 61, 66,  ..., 75, 69, 66],
         [57, 59, 63,  ..., 71, 65, 62]]], dtype=torch.uint8), 'instances': Instances(num_instances=11, image_height=736, image_width=1169, fields=[gt_boxes: Boxes(tensor([[125.1378, 305.5587, 361.0018, 481.5047],
        [558.6907, 319.8769, 689.9657, 380.5649],
        [550.3981, 297.8700, 613.7433, 347.7828],
        [681.0156, 162.0844, 896.0568, 673.1569],
        [104.9726, 271.2607, 152.7920, 347.9289],
        [ 53.9933, 272.5574, 119.6947, 468.4102],
        [809.8613, 271.8634, 870.2109, 377.9351],
        [683.9929, 315.2929, 733.6754, 439.1162],
        [456.3666, 289.1585, 561.3574, 395.3945],
        [350.7730, 279.0408, 427.6165, 470.8756],
        [  6.0459, 486.8557, 146.3076, 728.2017]])), gt_classes: tensor([ 2,  2,  2,  0,  0,  0, 15, 15,  5,  0,  2])])}]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

模型:

CLIPFastRCNN(
  (offline_backbone): ResNet(
    (stem): BasicStem(
      (conv1): Conv2d(
        3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
      )
    )
    (res2): Sequential(
      (0): BottleneckBlock(
        (shortcut): Conv2d(
          64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)
        )
        (conv1): Conv2d(
          64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
        )
        (conv2): Conv2d(
          64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
        )
        (conv3): Conv2d(
          64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)
        )
      )
      (1): BottleneckBlock(
        (conv1): Conv2d(
          256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
        )
        (conv2): Conv2d(
          64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
        )
        (conv3): Conv2d(
          64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)
        )
      )
      (2): BottleneckBlock(
        (conv1): Conv2d(
          256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
        )
        (conv2): Conv2d(
          64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
        )
        (conv3): Conv2d(
          64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)
        )
      )
    )
    (res3): Sequential(
      (0): BottleneckBlock(
        (shortcut): Conv2d(
          256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False
          (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05)
        )
        (conv1): Conv2d(
          256, 128, kernel_size=(1, 1), stride=(2, 2), bias=False
          (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05)
        )
        (conv2): Conv2d(
          128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05)
        )
        (conv3): Conv2d(
          128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05)
        )
      )
      (1): BottleneckBlock(
        (conv1): Conv2d(
          512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05)
        )
        (conv2): Conv2d(
          128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05)
        )
        (conv3): Conv2d(
          128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05)
        )
      )
      (2): BottleneckBlock(
        (conv1): Conv2d(
          512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05)
        )
        (conv2): Conv2d(
          128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05)
        )
        (conv3): Conv2d(
          128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05)
        )
      )
      (3): BottleneckBlock(
        (conv1): Conv2d(
          512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05)
        )
        (conv2): Conv2d(
          128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=128, eps=1e-05)
        )
        (conv3): Conv2d(
          128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=512, eps=1e-05)
        )
      )
    )
    (res4): Sequential(
      (0): BottleneckBlock(
        (shortcut): Conv2d(
          512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False
          (norm): FrozenBatchNorm2d(num_features=1024, eps=1e-05)
        )
        (conv1): Conv2d(
          512, 256, kernel_size=(1, 1), stride=(2, 2), bias=False
          (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)
        )
        (conv2): Conv2d(
          256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)
        )
        (conv3): Conv2d(
          256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=1024, eps=1e-05)
        )
      )
      (1): BottleneckBlock(
        (conv1): Conv2d(
          1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)
        )
        (conv2): Conv2d(
          256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)
        )
        (conv3): Conv2d(
          256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=1024, eps=1e-05)
        )
      )
      (2): BottleneckBlock(
        (conv1): Conv2d(
          1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)
        )
        (conv2): Conv2d(
          256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)
        )
        (conv3): Conv2d(
          256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=1024, eps=1e-05)
        )
      )
      (3): BottleneckBlock(
        (conv1): Conv2d(
          1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)
        )
        (conv2): Conv2d(
          256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)
        )
        (conv3): Conv2d(
          256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=1024, eps=1e-05)
        )
      )
      (4): BottleneckBlock(
        (conv1): Conv2d(
          1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)
        )
        (conv2): Conv2d(
          256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)
        )
        (conv3): Conv2d(
          256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=1024, eps=1e-05)
        )
      )
      (5): BottleneckBlock(
        (conv1): Conv2d(
          1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)
        )
        (conv2): Conv2d(
          256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=256, eps=1e-05)
        )
        (conv3): Conv2d(
          256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False
          (norm): FrozenBatchNorm2d(num_features=1024, eps=1e-05)
        )
      )
    )
  )
  (backbone): ModifiedResNet(
    (conv1): Conv2d(3, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): FrozenBatchNorm2d(num_features=40, eps=1e-05)
    (conv2): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): FrozenBatchNorm2d(num_features=40, eps=1e-05)
    (conv3): Conv2d(40, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn3): FrozenBatchNorm2d(num_features=80, eps=1e-05)
    (avgpool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (relu): ReLU(inplace=True)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(80, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=80, eps=1e-05)
        (conv2): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=80, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(80, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (-1): AvgPool2d(kernel_size=1, stride=1, padding=0)
          (0): Conv2d(80, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        )
      )
      (1): Bottleneck(
        (conv1): Conv2d(320, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=80, eps=1e-05)
        (conv2): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=80, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(80, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
      (2): Bottleneck(
        (conv1): Conv2d(320, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=80, eps=1e-05)
        (conv2): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=80, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(80, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
      (3): Bottleneck(
        (conv1): Conv2d(320, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=80, eps=1e-05)
        (conv2): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=80, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(80, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
    )
    (layer2): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(320, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=160, eps=1e-05)
        (conv2): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=160, eps=1e-05)
        (avgpool): AvgPool2d(kernel_size=2, stride=2, padding=0)
        (conv3): Conv2d(160, 640, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=640, eps=1e-05)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (-1): AvgPool2d(kernel_size=2, stride=2, padding=0)
          (0): Conv2d(320, 640, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): FrozenBatchNorm2d(num_features=640, eps=1e-05)
        )
      )
      (1): Bottleneck(
        (conv1): Conv2d(640, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=160, eps=1e-05)
        (conv2): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=160, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(160, 640, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=640, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
      (2): Bottleneck(
        (conv1): Conv2d(640, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=160, eps=1e-05)
        (conv2): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=160, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(160, 640, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=640, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
      (3): Bottleneck(
        (conv1): Conv2d(640, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=160, eps=1e-05)
        (conv2): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=160, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(160, 640, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=640, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
      (4): Bottleneck(
        (conv1): Conv2d(640, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=160, eps=1e-05)
        (conv2): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=160, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(160, 640, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=640, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
      (5): Bottleneck(
        (conv1): Conv2d(640, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=160, eps=1e-05)
        (conv2): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=160, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(160, 640, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=640, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
    )
    (layer3): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(640, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (avgpool): AvgPool2d(kernel_size=2, stride=2, padding=0)
        (conv3): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=1280, eps=1e-05)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (-1): AvgPool2d(kernel_size=2, stride=2, padding=0)
          (0): Conv2d(640, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): FrozenBatchNorm2d(num_features=1280, eps=1e-05)
        )
      )
      (1): Bottleneck(
        (conv1): Conv2d(1280, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=1280, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
      (2): Bottleneck(
        (conv1): Conv2d(1280, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=1280, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
      (3): Bottleneck(
        (conv1): Conv2d(1280, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=1280, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
      (4): Bottleneck(
        (conv1): Conv2d(1280, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=1280, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
      (5): Bottleneck(
        (conv1): Conv2d(1280, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=1280, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
      (6): Bottleneck(
        (conv1): Conv2d(1280, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=1280, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
      (7): Bottleneck(
        (conv1): Conv2d(1280, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=1280, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
      (8): Bottleneck(
        (conv1): Conv2d(1280, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=1280, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
      (9): Bottleneck(
        (conv1): Conv2d(1280, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (conv2): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=320, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=1280, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
    )
    (layer4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(1280, 640, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=640, eps=1e-05)
        (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=640, eps=1e-05)
        (avgpool): AvgPool2d(kernel_size=2, stride=2, padding=0)
        (conv3): Conv2d(640, 2560, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=2560, eps=1e-05)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (-1): AvgPool2d(kernel_size=2, stride=2, padding=0)
          (0): Conv2d(1280, 2560, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): FrozenBatchNorm2d(num_features=2560, eps=1e-05)
        )
      )
      (1): Bottleneck(
        (conv1): Conv2d(2560, 640, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=640, eps=1e-05)
        (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=640, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(640, 2560, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=2560, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
      (2): Bottleneck(
        (conv1): Conv2d(2560, 640, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=640, eps=1e-05)
        (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=640, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(640, 2560, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=2560, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
      (3): Bottleneck(
        (conv1): Conv2d(2560, 640, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=640, eps=1e-05)
        (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=640, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(640, 2560, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=2560, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
      (4): Bottleneck(
        (conv1): Conv2d(2560, 640, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=640, eps=1e-05)
        (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=640, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(640, 2560, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=2560, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
      (5): Bottleneck(
        (conv1): Conv2d(2560, 640, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): FrozenBatchNorm2d(num_features=640, eps=1e-05)
        (conv2): Conv2d(640, 640, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): FrozenBatchNorm2d(num_features=640, eps=1e-05)
        (avgpool): Identity()
        (conv3): Conv2d(640, 2560, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): FrozenBatchNorm2d(num_features=2560, eps=1e-05)
        (relu): ReLU(inplace=True)
      )
    )
    (attnpool): AttentionPool2d(
      (k_proj): Linear(in_features=2560, out_features=2560, bias=True)
      (q_proj): Linear(in_features=2560, out_features=2560, bias=True)
      (v_proj): Linear(in_features=2560, out_features=2560, bias=True)
      (c_proj): Linear(in_features=2560, out_features=640, bias=True)
    )
  )
  (offline_proposal_generator): RPN(
    (rpn_head): StandardRPNHead(
      (conv): Conv2d(
        1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        (activation): ReLU()
      )
      (objectness_logits): Conv2d(1024, 15, kernel_size=(1, 1), stride=(1, 1))
      (anchor_deltas): Conv2d(1024, 60, kernel_size=(1, 1), stride=(1, 1))
    )
    (anchor_generator): DefaultAnchorGenerator(
      (cell_anchors): BufferList()
    )
  )
  (roi_heads): CLIPRes5ROIHeads(
    (pooler): ROIPooler(
      (level_poolers): ModuleList(
        (0): ROIAlign(output_size=(18, 18), spatial_scale=0.0625, sampling_ratio=0, aligned=True)
      )
    )
    (box_predictor): FastRCNNOutputLayers(
      (cls_score): Linear(in_features=640, out_features=48, bias=False)
      (cls_bg_score): Linear(in_features=640, out_features=1, bias=False)
      (test_cls_score): Linear(in_features=640, out_features=65, bias=False)
      (bbox_pred): Linear(in_features=640, out_features=4, bias=True)
    )
  )
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293
  • 294
  • 295
  • 296
  • 297
  • 298
  • 299
  • 300
  • 301
  • 302
  • 303
  • 304
  • 305
  • 306
  • 307
  • 308
  • 309
  • 310
  • 311
  • 312
  • 313
  • 314
  • 315
  • 316
  • 317
  • 318
  • 319
  • 320
  • 321
  • 322
  • 323
  • 324
  • 325
  • 326
  • 327
  • 328
  • 329
  • 330
  • 331
  • 332
  • 333
  • 334
  • 335
  • 336
  • 337
  • 338
  • 339
  • 340
  • 341
  • 342
  • 343
  • 344
  • 345
  • 346
  • 347
  • 348
  • 349
  • 350
  • 351
  • 352
  • 353
  • 354
  • 355
  • 356
  • 357
  • 358
  • 359
  • 360
  • 361
  • 362
  • 363
  • 364
  • 365
  • 366
  • 367
  • 368
  • 369
  • 370
  • 371
  • 372
  • 373
  • 374
  • 375
  • 376
  • 377
  • 378
  • 379
  • 380
  • 381
  • 382
  • 383
  • 384
  • 385
  • 386
  • 387
  • 388
  • 389
  • 390
  • 391
  • 392
  • 393
  • 394
  • 395
  • 396
  • 397
  • 398
  • 399
  • 400
  • 401
  • 402
  • 403
  • 404
  • 405
  • 406
  • 407
  • 408
  • 409
  • 410
  • 411
  • 412
  • 413
  • 414
  • 415
  • 416
  • 417
  • 418
  • 419
  • 420
  • 421
  • 422
  • 423
  • 424
  • 425
  • 426
  • 427
  • 428
  • 429
  • 430
  • 431
  • 432
  • 433
  • 434
  • 435
  • 436
  • 437
  • 438
  • 439
  • 440
  • 441
  • 442
  • 443
  • 444
  • 445
  • 446
  • 447
  • 448
  • 449
  • 450
  • 451
  • 452
  • 453
  • 454
  • 455
  • 456
  • 457
  • 458
  • 459
  • 460
  • 461
  • 462
  • 463
  • 464
  • 465
  • 466
  • 467
  • 468
  • 469
  • 470
  • 471
  • 472
  • 473
  • 474
  • 475
  • 476
  • 477
  • 478
  • 479
  • 480
  • 481
  • 482
  • 483
  • 484
  • 485
  • 486
  • 487
  • 488
  • 489
  • 490
  • 491
  • 492
  • 493
  • 494
  • 495
  • 496
  • 497
  • 498
  • 499
  • 500
  • 501
  • 502
  • 503
  • 504
  • 505
  • 506
  • 507
  • 508
  • 509
  • 510
  • 511
  • 512
  • 513
  • 514
  • 515
  • 516
  • 517
  • 518
  • 519
  • 520
  • 521
  • 522
  • 523
  • 524
  • 525
  • 526
  • 527
  • 528
  • 529
  • 530
  • 531
  • 532
  • 533
  • 534
  • 535
  • 536
  • 537
  • 538
  • 539
  • 540
  • 第一步:生成 proposals 1000 个

    • 如果使用 GT 的话就直接使用标注的框,如图使用 RPN 的话,就使用现有的 RPN 网络来生成,生成的话就是,先对图像进行归一化,然后使用 offline_backbone 来进行特征提取,不进行最后类别特征的提取,输出的特征为 [1,1024, 46, 74](一般情况提取类别的话通道数要变成类别的数量)
    • 对得到的特征使用现有的 RPN 网络来提取 proposals,先生成 anchors,然后使用 rpn_head 来得到 objectness 和 anchor_deltas,得到每个 proposal 的前景得分和相对 anchor 的偏移,且 RPN 不参与训练的情况下是不用反传 loss 的
  • 第二步:对输入图像使用 backbone (ResNet50)提取特征,得到 [1, 1280, 46, 73],这里的 backbone 是使用 CLIP visual encoder 的权重来初始化的,会参与训练,对输入的 text 使用 CLIP 的 text encoder 来编码,且不参与训练。

  • 第三步:对 RPN 预测的 proposal 分配标签,并剔除不满足 IoU 阈值的 proposals,分配的标准是依据 IoU 来将其分配到对应的 gt 上,gt 的类别就是该 proposal 的类别,分配标签后的 proposal 数量一般会比 1000 少,比如 512 个,然后会包含 box 坐标、objectness 得分、对应的类别 cls、gt box。假设有 48 类,那么背景类别的标签就是 48

  • 第四步:根据 proposal 的 box 坐标来提取特征图中对应位置的特征,并且使用 RoIAlign 的方法,将输出特征图大小统一为 18x18,得到特征 [512, 1280, 18, 18] 的特征,512 表示有 512 个 proposals,然后使用 backbone_res5 对 proposal 的特征进行处理,得到 [512, 2560, 9, 9] 特征

  • 第五步:对得到的 proposal 特征进行 attention,得到 [512, 640] 特征,每个 proposal 特征使用 640 维的向量表示(或者进行取平均来表示每个 proposal 的特征)

  • 第六步:使用 attention 后的特征来得到每个 proposal 的 score 和 regress delta,

    • 先对 [512, 640] 的特征进行 normalize,计算 CLIP embedding 和 图像 embedding 的矩阵相乘结果,作为 cls_score,得到结果 [512, 48],这里的 48 为 COCO 的 48 个基础类别,并且将 background score 设置为全零 [512, 1],cat 后得到最终的 score [512, 49],并且给 score 除以温度参数来作为最终的结果
    • 然后使用线性映射将 proposal_delta 的预测映射到 [512, 4]
  • 第七步:计算 loss,使用 focal loss 来计算分类 loss,即计算每个 proposal 的类别和真实 gt 类别的 loss,前景类别权重为 1 ,背景类别权重为 0.2,使用 L1 loss (或 GIoU loss )来计算回归 loss,回归 loss 只计算前景 proposal 的

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/377252
推荐阅读
相关标签
  

闽ICP备14008679号