当前位置:   article > 正文

Target-aware Transformer 知识蒸馏代码复现_crcd蒸馏代码复现

crcd蒸馏代码复现


前言

Target-aware Transformer (TaT) 是一种改进的 Transformer 模型,专注于目标感知的序列到序列(sequence-to-sequence)建模。通过 TaT 算法对卷积神经网络的中间层进行知识蒸馏取得了目前加算计视觉领域知识蒸馏的 SOTA,这一只是蒸馏方法首先是在2022 CVPR 上的一篇题为《Knowledge Distillation via the Target-aware Transformer》的论文中提出的,本文主要讲解一些这篇论文的复现流程。
代码链接:https://github.com/sihaoevery/TaT


数学符号
F T F^T FT: 教师网络的特征图
F S F^S FS: 学生网络的特征图
Γ ( ⋅ ) \Gamma (\cdot) Γ(): 按通道方向将特征图reshape成2维矩阵的函数

一、TaT知识蒸馏算法回顾

  • 首先,对于教师网络的中间特征层,按照通道方向reshape形状成2维矩阵,学生网络的中间层特征图初始化成和教师网络特征图大小相同的初始化矩阵:
    在这里插入图片描述
    公式表示如下:
    f t = Γ ( F T ) ∈ R N × C , f s = Γ ( F S ) ∈ R N × C .
    ft=Γ(FT)RN×C,fs=Γ(FS)RN×C.
    ftfs=Γ(FT)RN×C,=Γ(FS)RN×C.

    其中, N = H × W N = H\times W N=H×W, H H H, W W W 分别表示特征图的高和宽, C C C 表示特征图的个数,也就是通道数。
    将上述二维特征矩阵转置可以表示如下:
    f t ⊤ = [ f 1 t , f 2 t , f 3 t , … , f N t ] , f s ⊤ = [ f 1 s , f 2 s , f 3 s , … , f N s ] .
    ft=[f1t,f2t,f3t,,fNt],fs=[f1s,f2s,f3s,,fNs].
    ftfs=[f1t,f2t,f3t,,fNt],=[f1s,f2s,f3s,,fNs].

    此时,教师网络 f t ⊤ f^{t^{\top}} ft 中学习到的特征,需要指导学生网络 f s ⊤ f^{s \top} fs 中的每一个特征像素点。
  • 接下来,需要教师网络的特征图去指导学生网络重构特征图,使得教师网络的特征图中的知识,可以蒸馏到学生网络的特征图中。在这里TaT的思想是从transformer借鉴过来的,注意力机制公式为:
    Attention ⁡ ( Q , K , V ) = softmax ⁡ ( Q K T d k ) V \operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V Attention(Q,K,V)=softmax(dk QKT)V
    其中的 Q Q Q, K K K, V V V是矩阵, d k d_k dk 是通道个数,这里除以 d k \sqrt{d_k} dk 可以起到降维的作用。可以想象 Q Q Q, K K K, V V V 就是一些 feature map。注意力机制引入了权重分配的概念,即教师特征图在指导每个学生特征图位置上的元素时的关注程度。这些权重告诉 f t f^t ft 哪些部分对于当前的 f s f^s fs 更为重要。那么权重系数可以通过内积计算,然后使用 softmax函数做归一化:
    W i = σ ( ⟨ f 1 s , f i t ⟩ , ⟨ f 2 s , f i t ⟩ , … , ⟨ f N s , f i t ⟩ ) = [ w 1 i , w 2 i , … , w N i ]
    Wi=σ(f1s,fit,f2s,fit,,fNs,fit)=[w1i,w2i,,wNi]
    Wi=σ(f1s,fit,f2s,fit,,fNs,fit)=[w1i,w2i,,wNi]

    注意力计算就是权重系数再乘以学生特征图矩阵的列向量:
    f i s ′ = w 1 i × f 1 s + w 2 i × f 2 s + ⋯ + w N i × f N s f_i^{s^{\prime}}=w_1^i \times f_1^s+w_2^i \times f_2^s+\cdots+w_N^i \times f_N^s fis=w1i×f1s+w2i×f2s++wNi×fNs
    此时,得到的学生特征图 f s ′ f^{s^{'}} fs 再与原来的 教师特征图做损失得到TaT算法的损失函数:
    L T a T = ∥ f s ′ − f t ∥ 2 . \mathcal{L}_{\mathrm{TaT}}=\left\|f^{s^{\prime}}-f^t\right\|_2 . LTaT= fsft 2.
    在这里插入图片描述

二、代码复现

1. train_local.sh文件

#! /bin/bash
#NOTE: the argument 'adjust_lr' is set to False

NUM_GPUS=8   #Gpu个数
python -m torch.distributed.launch \
	--nproc_per_node=${NUM_GPUS} \
	--use_env examples/image_classification.py \
	--world_size ${NUM_GPUS} \
	--log ./result/ilsvrc2012/tat/resnet18_from_resnet34.txt \
	-adjust_lr \
	--config configs/sample/ilsvrc2012/single_stage/tat/resnet18_from_resnet34_attn.yaml
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

这个文件是在Linux系统下的一个脚本文件,在当前路径下只需要输入./train_local.sh就可以运行代码了。GPU个数可以根据自己的硬件配置修改。

2. 算法对应的代码块-以图像分类为例

知识蒸馏主要体现在代码模块的Loss

class MaskedFM(nn.Module):
    """ 
        Compute the L2 loss.
        Don't put any learnable parameters in this py file
    """
    def __init__(self, feature_pairs, heads, **kwargs):
        super().__init__()
        self.feature_pairs = feature_pairs
        self.heads=heads

    def forward(self, student_io_dict, teacher_io_dict, *args, **kwargs):
        chsim_loss = 0
        for pair_name, pair_config in self.feature_pairs.items():
            teacher_outputs = extract_feature_map(teacher_io_dict, pair_config['teacher'])
            student_outputs = extract_feature_map(student_io_dict, pair_config['student'])
            factor = pair_config.get('factor', 1)
            loss = self.batch_loss(student_outputs, teacher_outputs)
            chsim_loss += factor * loss
        return chsim_loss

    def batch_loss(self, f_s, f_t):

        s, q, v = f_s # student feature, query, value
        t, k    = f_t # teacher feature, key
        heads = self.heads # multi head, heads*c = d
        b,c,h,w = v.shape
        
        q = rearrange(q,'b (h d) x y -> b h (x y) d', h=heads) #(b,heads,hw,d)
        k = rearrange(k,'b (h d) x y -> b h (x y) d', h=heads) #(b,heads,hw,d)
        v = rearrange(v,'b (h d) x y -> b h (x y) d', h=heads) #(b,heads,hw,d)

        sim = einsum('b h i d, b h j d -> b h i j', q,k) #(b,heads,hw,hw)
        sim = sim.softmax(dim=-1)
        out = einsum('b h i j, b h j d -> b h i d', sim,v)#(b,heads,hw,d)

        out = rearrange(out, 'b h (x y) d ->b (h d) x y',x=h,y=w) #(b,c,h,w)
        q   = rearrange(q,   'b h (x y) d ->b (h d) x y',x=h,y=w) #(b,c,h,w)
        k   = rearrange(k,   'b h (x y) d ->b (h d) x y',x=h,y=w) #(b,c,h,w)

        #--------------1. Feature Matching ---------------------------#
        loss = nn.MSELoss()(out,t) 

        #--------------2. Channel-wise Mean, Min and Max-------------#
        # Experimental attempt, not used by this work
        # print('max shape,', torch.max(out,1)[0].size())
        # loss = 1/3*(nn.MSELoss()(torch.max(out,1)[0], torch.max(q,1)[0])+\
        #             nn.MSELoss()(torch.min(out,1)[0], torch.min(q,1)[0])+\
        #             nn.MSELoss()(out.mean(1),q.mean(1)))
        # print('loss value', loss)

        #--------------3. Channel-wise Mean; Global pooling-----------#
        # Experimental attempt, not used by this work
        # print('global size',q.mean(1).size())
        # loss=1/2*(nn.MSELoss()(out.mean(1),q.mean(1))+\
        #           nn.MSELoss()(self.globalpooling(out),self.globalpooling(q)))

        return loss
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57

这里计算的是TaT的Loss函数,也是使用一对多的方式提取蒸馏教师网络中间层知识的目标函数。是对Attention公式下面的三个公式的实现。
而TaT算法的实现,是在/TaT-master/torchdistill/models/custom/special.py中的AttenEmbed类中具体实现的,代码如下:

class AttnEmbed(SpecialModule):
    """
        Embedding functions.
        Put all your learnable parameters in this py file.
    """
    def __init__(self, embedings, device, device_ids, distributed, 
                 teacher_model=None, student_model=None, **kwargs):
        super().__init__()
        is_teacher = teacher_model is not None
        self.is_teacher = is_teacher
        if not is_teacher:
            student_model = wrap_if_distributed(student_model, device, device_ids, distributed)

        self.model = teacher_model if is_teacher else student_model
        
        io_path_pairs = list()
        self.embed_dict = nn.ModuleDict()
        for embed_key, embed_params in embedings.items():
            if is_teacher:
                logger.info("Using {}, compute the key of teacher".format(self.__class__.__name__))
                # embed = Embed(**embed_params) # For ablation study
                # embed = wrap_if_distributed(embed, device, device_ids, distributed)

                embed = nn.Identity() # no 3x3 conv
            else:
                logger.info("Using {}, compute the query and attention output of student".format(self.__class__.__name__))
                if 'query' in embed_key:
                    embed = Embed(**embed_params)
                    embed = wrap_if_distributed(embed, device, device_ids, distributed)
                    
                    # embed = nn.Identity() # no 3x3 conv, for ablation
                elif 'value' in embed_key:
                    embed = Embed(**embed_params)
                    embed = wrap_if_distributed(embed, device, device_ids, distributed)
            self.embed_dict[embed_key] = embed
            io_path_pairs.append((embed_key, embed_params['io'], embed_params['path']))
        self.io_path_pairs = io_path_pairs

    def forward(self, x):
        if self.is_teacher:
            with torch.no_grad():
                return self.model(x)
        else:
            return self.model(x)

    def post_forward(self, io_dict):
        for embed_key, io_type, module_path in self.io_path_pairs:
            self.embed_dict[embed_key](io_dict[module_path][io_type])

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49

首先判断模型是教师模型还是学生模型,
如果存在教师模型 (is_teacher=True),则创建一个 nn.Identity() 模块。
如果不存在教师模型 (is_teacher=False),则根据键名(embed_key)中是否包含 ‘query’ 或 ‘value’,创建一个patch 块。这里的Embed()是将文件中前面定义的类按照函数功能使用,作用是将输入分割成patch块,为特征图后续的attention的计算做准备,类似于Transformer中的Embedding层。
最后,创建了一个 io_path_pairs 列表,其中包含了元组 (embed_key, embed_params[‘io’], embed_params[‘path’])。这个列表存储了与嵌入模块相关的键、输入/输出类型以及模块路径的信息。同时,使用 ModuleDict 存储了嵌入模块,键为 embed_key,值为对应的模块对象。

3. 代码文件说明

在这里插入图片描述
configs文件夹中包含很多模型的配置文件,都是yaml配置文件,我们在调代码时需要在这个文件夹下找到对应的配置文件,然后在配置文件中更改参数,例如:数据集相关参数、模型训练和测试相关参数等等。
train_local.sh脚本中默认的配置文件路径是:./configs/sample/ilsvrc2012/single_stage/tat/resnet18_from_resnet34_attn.yaml

datasets:
  ilsvrc2012:
    name: &dataset_name ''
    type: 'ImageFolder'
    root: &root_dir '/path/to/the/imagenet/folder'
    splits:
      train:
        dataset_id: &imagenet_train !join [*dataset_name, '/train']
        params:
          root: !join [*root_dir, '/train']
          transform_params:
            - type: 'RandomResizedCrop'
              params:
                size: &input_size [224, 224]
            - type: 'RandomHorizontalFlip'
              params:
                p: 0.5
            - &totensor
              type: 'ToTensor'
              params:
            - &normalize
              type: 'Normalize'
              params:
                mean: [0.485, 0.456, 0.406]
                std: [0.229, 0.224, 0.225]
      val:
        dataset_id: &imagenet_val !join [*dataset_name, '/val']
        params:
          root: !join [*root_dir, '/val']
          transform_params:
            - type: 'Resize'
              params:
                size: 256
            - type: 'CenterCrop'
              params:
                size: *input_size
            - *totensor
            - *normalize

models:
  teacher_model:
    name: &teacher_model_name 'resnet34'
    params:
      num_classes: 1000
      pretrained: True
    experiment: &teacher_experiment !join ['imagenet', '-', *teacher_model_name]
    ckpt: !join ['./resource/ckpt/ilsvrc2012/teacher/', *teacher_experiment, '.pt']
  student_model:
    name: &student_model_name 'resnet18'
    params:
      num_classes: 1000
      pretrained: True
    experiment: &student_experiment !join ['imagenet', '-', *student_model_name, '_from_', *teacher_model_name]
    ckpt: !join ['./model/', *student_experiment, '.pt']

train:
  log_freq: 200
  num_epochs: 100
  train_data_loader:
    dataset_id: *imagenet_train
    random_sample: True
    batch_size: 256 #per gpu
    num_workers: 10
    cache_output:
  val_data_loader:
    dataset_id: *imagenet_val
    random_sample: False
    batch_size: 32
    num_workers: 2
  teacher:
    sequential: []
    special:
      type: 'AttnEmbed'
      params:
        embedings:
          key:
            io: 'output'
            path: 'model.layer4'
            in_channels: 512
            out_channels: 512
    forward_hook:
      input: []
      output: ['model.layer4','embed_dict.key']
    wrapper: #'DistributedDataParallel'
    requires_grad: True
    frozen_modules: ['model'] # The ResNet
  student:
    adaptations:
    sequential: []
    special:
      type: 'AttnEmbed'
      params:
        embedings:
          query:
            io: 'output'
            path: 'model.layer4'
            in_channels: 512
            out_channels: 512
          value:
            io: 'output'
            path: 'model.layer4'
            in_channels: 512
            out_channels: 512
    forward_hook:
      input: []
      output: ['model.layer4', 'embed_dict.query', 'embed_dict.value']
    wrapper: 
    requires_grad: True
    frozen_modules: []
  apex:
    requires: False
    opt_level: '01'
  optimizer:
    type: 'AdamW'
    params:
      lr: 0.0016
      #momentum: 0.9
      weight_decay: 0.0005
  scheduler:
    type: 'MultiStepLR'
    params:
      milestones: [30, 60, 90]
      gamma: 0.1
  criterion:
    type: 'GeneralizedCustomLoss'
    org_term:
      criterion:
        type: 'KDLoss'
        params:
          temperature: 1.0
          alpha: 0.1
          beta: 0.0
          reduction: 'batchmean'
      factor: 1.0
    sub_terms:
      cs:
        criterion:
          type: 'MaskedFM'
          params:
            feature_pairs:
              pair1:
                teacher:
                  io: 'output'
                  path: ['model.layer4', 'embed_dict.key']
                student:
                  io: 'output'
                  path: ['model.layer4', 'embed_dict.query','embed_dict.value']
                factor: 1
            heads: 1
        factor: 0.2
test:
  test_data_loader:
    dataset_id: *imagenet_val
    random_sample: False
    batch_size: 1
    num_workers: 16
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156

复现这里的代码,只需要将yaml文件中dataset 参数配置中的这一行: root: &root_dir ‘/path/to/the/imagenet/folder’ 中的路径地址改为你的ImageNet数据集的路径地址就可以加载数据集训练了。

总结

  • 代码调试在pycharm中,首先找到主函数,然后逐层分析代码
  • 我遇到的问题是在添加torchdistill文件路径时,使用相对路径会报错,这样你可以尝试添加它的绝对路径
import os
import sys
sys.path.append("torchdistill的绝对路径")
  • 1
  • 2
  • 3
  • 这个代码的体量还是非常庞大的,而且这是大团队开发的,所以代码也是经过很多人之手,所以大家在阅读和复现时,注意先从大体去读,如果出现看不懂的情况就看这个类或者函数的return.
本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号