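Target-aware Transformer (TaT) is a modified Transformer component for target-aware modeling. Using TaT to distill knowledge from the intermediate layers of a convolutional neural network, the authors report state-of-the-art results for knowledge distillation in computer vision at the time of publication. The method was first proposed in the CVPR 2022 paper "Knowledge Distillation via the Target-aware Transformer". This post walks through the steps for reproducing that paper.
Code: https://github.com/sihaoevery/TaT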
Notation
$F^T$: the teacher network's feature map
$F^S$: the student network's feature map
$\Gamma(\cdot)$: a function that reshapes a feature map along the channel dimension into a 2-D matrix
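To make the reshape concrete, here is a minimal pure-Python sketch of Γ on a toy feature map. It assumes the output layout is (H·W) × C, one row per spatial position, which is consistent with the `(x y) d` ordering used in the loss code later in this post; the function name `gamma` and the toy values are illustrative only.

```python
def gamma(feature_map):
    """Flatten a [C][H][W] nested list into an (H*W) x C matrix.

    Assumed layout: one row per spatial position (row-major over H, W),
    one column per channel.
    """
    channels = len(feature_map)
    height = len(feature_map[0])
    width = len(feature_map[0][0])
    return [
        [feature_map[c][y][x] for c in range(channels)]
        for y in range(height)
        for x in range(width)
    ]


# A toy feature map with C=2 channels on a 2x2 spatial grid.
toy = [
    [[1, 2], [3, 4]],      # channel 0
    [[10, 20], [30, 40]],  # channel 1
]
matrix = gamma(toy)  # 4 rows (H*W), 2 columns (C)
```

In the actual code the same operation is performed on batched tensors with `einops.rearrange`.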
#! /bin/bash
#NOTE: the argument 'adjust_lr' is set to False
NUM_GPUS=8 # number of GPUs
python -m torch.distributed.launch \
--nproc_per_node=${NUM_GPUS} \
--use_env examples/image_classification.py \
--world_size ${NUM_GPUS} \
--log ./result/ilsvrc2012/tat/resnet18_from_resnet34.txt \
-adjust_lr \
--config configs/sample/ilsvrc2012/single_stage/tat/resnet18_from_resnet34_attn.yaml
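This file is a shell script for Linux. From the directory that contains it, run ./train_local.sh to start training (make it executable first with chmod +x train_local.sh if needed). Set NUM_GPUS to match your own hardware.
Knowledge distillation itself is implemented mainly in the loss module: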
import torch
from torch import nn, einsum
from einops import rearrange
# extract_feature_map is a helper from the TaT/torchdistill codebase.

class MaskedFM(nn.Module):
    """
    Compute the L2 loss.
    Don't put any learnable parameters in this py file
    """
    def __init__(self, feature_pairs, heads, **kwargs):
        super().__init__()
        self.feature_pairs = feature_pairs
        self.heads=heads

    def forward(self, student_io_dict, teacher_io_dict, *args, **kwargs):
        chsim_loss = 0
        for pair_name, pair_config in self.feature_pairs.items():
            teacher_outputs = extract_feature_map(teacher_io_dict, pair_config['teacher'])
            student_outputs = extract_feature_map(student_io_dict, pair_config['student'])
            factor = pair_config.get('factor', 1)
            loss = self.batch_loss(student_outputs, teacher_outputs)
            chsim_loss += factor * loss
        return chsim_loss

    def batch_loss(self, f_s, f_t):
        s, q, v = f_s  # student feature, query, value
        t, k = f_t  # teacher feature, key
        heads = self.heads  # multi head, heads*c = d
        b,c,h,w = v.shape
        q = rearrange(q,'b (h d) x y -> b h (x y) d', h=heads)  #(b,heads,hw,d)
        k = rearrange(k,'b (h d) x y -> b h (x y) d', h=heads)  #(b,heads,hw,d)
        v = rearrange(v,'b (h d) x y -> b h (x y) d', h=heads)  #(b,heads,hw,d)
        sim = einsum('b h i d, b h j d -> b h i j', q,k)  #(b,heads,hw,hw)
        sim = sim.softmax(dim=-1)
        out = einsum('b h i j, b h j d -> b h i d', sim,v)  #(b,heads,hw,d)
        out = rearrange(out, 'b h (x y) d ->b (h d) x y',x=h,y=w)  #(b,c,h,w)
        q = rearrange(q, 'b h (x y) d ->b (h d) x y',x=h,y=w)  #(b,c,h,w)
        k = rearrange(k, 'b h (x y) d ->b (h d) x y',x=h,y=w)  #(b,c,h,w)

        #--------------1. Feature Matching ---------------------------#
        loss = nn.MSELoss()(out,t)

        #--------------2. Channel-wise Mean, Min and Max-------------#
        # Experimental attempt, not used by this work
        # print('max shape,', torch.max(out,1)[0].size())
        # loss = 1/3*(nn.MSELoss()(torch.max(out,1)[0], torch.max(q,1)[0])+\
        #             nn.MSELoss()(torch.min(out,1)[0], torch.min(q,1)[0])+\
        #             nn.MSELoss()(out.mean(1),q.mean(1)))
        # print('loss value', loss)

        #--------------3. Channel-wise Mean; Global pooling-----------#
        # Experimental attempt, not used by this work
        # print('global size',q.mean(1).size())
        # loss=1/2*(nn.MSELoss()(out.mean(1),q.mean(1))+\
        #           nn.MSELoss()(self.globalpooling(out),self.globalpooling(q)))
        return loss
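This computes the TaT loss, the objective that distills the teacher's intermediate-layer knowledge in a one-to-all manner. It implements the three formulas that follow the attention formula.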
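Read directly off `batch_loss`, the computation for a single head can be written as below. The letters $Q$, $K$, $V$, $A$ and $\hat{F}^S$ are shorthand introduced here and may not match the paper's notation: $Q$ and $V$ are the student's query and value embeddings, $K$ is the teacher's key, each flattened to an $HW \times d$ matrix.

```latex
\begin{aligned}
A &= \operatorname{softmax}\!\left(Q K^{\top}\right) \in \mathbb{R}^{HW \times HW} \\
\hat{F}^S &= A\,V \in \mathbb{R}^{HW \times d} \\
\mathcal{L} &= \operatorname{MSE}\!\left(\hat{F}^S,\ \Gamma(F^T)\right)
\end{aligned}
```

The softmax is taken over each row of $Q K^{\top}$, so every student position attends to all teacher positions. Unlike standard scaled dot-product attention, the code shown applies no $1/\sqrt{d}$ scaling.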
The TaT module itself is implemented by the AttnEmbed class in /TaT-master/torchdistill/models/custom/special.py:
class AttnEmbed(SpecialModule):
    """
    Embedding functions.
    Put all your learnable parameters in this py file.
    """
    def __init__(self, embedings, device, device_ids, distributed, teacher_model=None, student_model=None, **kwargs):
        super().__init__()
        is_teacher = teacher_model is not None
        self.is_teacher = is_teacher
        if not is_teacher:
            student_model = wrap_if_distributed(student_model, device, device_ids, distributed)

        self.model = teacher_model if is_teacher else student_model
        io_path_pairs = list()
        self.embed_dict = nn.ModuleDict()
        for embed_key, embed_params in embedings.items():
            if is_teacher:
                logger.info("Using {}, compute the key of teacher".format(self.__class__.__name__))
                # embed = Embed(**embed_params)  # For ablation study
                # embed = wrap_if_distributed(embed, device, device_ids, distributed)
                embed = nn.Identity()  # no 3x3 conv
            else:
                logger.info("Using {}, compute the query and attention output of student".format(self.__class__.__name__))
                if 'query' in embed_key:
                    embed = Embed(**embed_params)
                    embed = wrap_if_distributed(embed, device, device_ids, distributed)
                    # embed = nn.Identity()  # no 3x3 conv, for ablation
                elif 'value' in embed_key:
                    embed = Embed(**embed_params)
                    embed = wrap_if_distributed(embed, device, device_ids, distributed)

            self.embed_dict[embed_key] = embed
            io_path_pairs.append((embed_key, embed_params['io'], embed_params['path']))
        self.io_path_pairs = io_path_pairs

    def forward(self, x):
        if self.is_teacher:
            with torch.no_grad():
                return self.model(x)
        else:
            return self.model(x)

    def post_forward(self, io_dict):
        for embed_key, io_type, module_path in self.io_path_pairs:
            self.embed_dict[embed_key](io_dict[module_path][io_type])
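The constructor first determines whether it is wrapping the teacher model or the student model.
If a teacher model is passed in (is_teacher=True), each embedding is an nn.Identity() module, so the teacher feature is used as the key unchanged.
If no teacher model is passed in (is_teacher=False), an Embed module is created for every key whose name (embed_key) contains 'query' or 'value'. Embed is a class defined earlier in the same file and is used here like a function: it embeds the input (described here as splitting it into patches) to prepare the feature map for the subsequent attention computation, much like the embedding layer in a Transformer. Judging by the inline comment "no 3x3 conv", it is built around a 3x3 convolution.
Finally, each iteration appends a tuple (embed_key, embed_params['io'], embed_params['path']) to the io_path_pairs list, which records the embedding's key, its input/output type, and the path of the module it attaches to. The embedding modules themselves are stored in an nn.ModuleDict, keyed by embed_key.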
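The bookkeeping described above can be sketched without PyTorch. This is a simplified illustration, not the repository's code: `build_embeddings`, `post_forward`, `identity` and `make_fake_embed` are hypothetical names, plain functions stand in for `nn.Identity` and `Embed`, and the sketch returns the embedded values, whereas the real `post_forward` only calls the modules.

```python
def identity(x):
    """Stand-in for nn.Identity: return the input unchanged."""
    return x


def make_fake_embed(name):
    """Stand-in for Embed: tag the input so the call is visible."""
    def embed(x):
        return f"{name}({x})"
    return embed


def build_embeddings(embed_configs, is_teacher):
    """Build the embedding table and the (key, io, path) lookup list."""
    embed_dict = {}
    io_path_pairs = []
    for embed_key, params in embed_configs.items():
        embed_dict[embed_key] = identity if is_teacher else make_fake_embed(embed_key)
        io_path_pairs.append((embed_key, params["io"], params["path"]))
    return embed_dict, io_path_pairs


def post_forward(embed_dict, io_path_pairs, io_dict):
    """Apply each embedding to the cached feature it was registered for."""
    return {
        key: embed_dict[key](io_dict[path][io_type])
        for key, io_type, path in io_path_pairs
    }


student_cfg = {
    "query": {"io": "output", "path": "model.layer4"},
    "value": {"io": "output", "path": "model.layer4"},
}
io_dict = {"model.layer4": {"output": "feat"}}  # cached layer output

embed_dict, pairs = build_embeddings(student_cfg, is_teacher=False)
student_out = post_forward(embed_dict, pairs, io_dict)

teacher_dict, teacher_pairs = build_embeddings(
    {"key": {"io": "output", "path": "model.layer4"}}, is_teacher=True
)
teacher_out = post_forward(teacher_dict, teacher_pairs, io_dict)
```

Keeping the lookup list separate from the module table lets the same loop serve any number of embeddings attached to any layer, driven entirely by the config.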
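The configs folder contains YAML configuration files for many models. When working with the code, find the file that matches your experiment and edit its parameters there, for example the dataset settings and the training and test settings.
The default configuration path in train_local.sh is: ./configs/sample/ilsvrc2012/single_stage/tat/resnet18_from_resnet34_attn.yaml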
datasets:
  ilsvrc2012:
    name: &dataset_name ''
    type: 'ImageFolder'
    root: &root_dir '/path/to/the/imagenet/folder'
    splits:
      train:
        dataset_id: &imagenet_train !join [*dataset_name, '/train']
        params:
          root: !join [*root_dir, '/train']
          transform_params:
            - type: 'RandomResizedCrop'
              params:
                size: &input_size [224, 224]
            - type: 'RandomHorizontalFlip'
              params:
                p: 0.5
            - &totensor
              type: 'ToTensor'
              params:
            - &normalize
              type: 'Normalize'
              params:
                mean: [0.485, 0.456, 0.406]
                std: [0.229, 0.224, 0.225]
      val:
        dataset_id: &imagenet_val !join [*dataset_name, '/val']
        params:
          root: !join [*root_dir, '/val']
          transform_params:
            - type: 'Resize'
              params:
                size: 256
            - type: 'CenterCrop'
              params:
                size: *input_size
            - *totensor
            - *normalize

models:
  teacher_model:
    name: &teacher_model_name 'resnet34'
    params:
      num_classes: 1000
      pretrained: True
    experiment: &teacher_experiment !join ['imagenet', '-', *teacher_model_name]
    ckpt: !join ['./resource/ckpt/ilsvrc2012/teacher/', *teacher_experiment, '.pt']
  student_model:
    name: &student_model_name 'resnet18'
    params:
      num_classes: 1000
      pretrained: True
    experiment: &student_experiment !join ['imagenet', '-', *student_model_name, '_from_', *teacher_model_name]
    ckpt: !join ['./model/', *student_experiment, '.pt']

train:
  log_freq: 200
  num_epochs: 100
  train_data_loader:
    dataset_id: *imagenet_train
    random_sample: True
    batch_size: 256 #per gpu
    num_workers: 10
    cache_output:
  val_data_loader:
    dataset_id: *imagenet_val
    random_sample: False
    batch_size: 32
    num_workers: 2
  teacher:
    sequential: []
    special:
      type: 'AttnEmbed'
      params:
        embedings:
          key:
            io: 'output'
            path: 'model.layer4'
            in_channels: 512
            out_channels: 512
    forward_hook:
      input: []
      output: ['model.layer4','embed_dict.key']
    wrapper: #'DistributedDataParallel'
    requires_grad: True
    frozen_modules: ['model'] # The ResNet
  student:
    adaptations:
    sequential: []
    special:
      type: 'AttnEmbed'
      params:
        embedings:
          query:
            io: 'output'
            path: 'model.layer4'
            in_channels: 512
            out_channels: 512
          value:
            io: 'output'
            path: 'model.layer4'
            in_channels: 512
            out_channels: 512
    forward_hook:
      input: []
      output: ['model.layer4', 'embed_dict.query', 'embed_dict.value']
    wrapper:
    requires_grad: True
    frozen_modules: []
  apex:
    requires: False
    opt_level: '01'
  optimizer:
    type: 'AdamW'
    params:
      lr: 0.0016
      #momentum: 0.9
      weight_decay: 0.0005
  scheduler:
    type: 'MultiStepLR'
    params:
      milestones: [30, 60, 90]
      gamma: 0.1
  criterion:
    type: 'GeneralizedCustomLoss'
    org_term:
      criterion:
        type: 'KDLoss'
        params:
          temperature: 1.0
          alpha: 0.1
          beta: 0.0
          reduction: 'batchmean'
      factor: 1.0
    sub_terms:
      cs:
        criterion:
          type: 'MaskedFM'
          params:
            feature_pairs:
              pair1:
                teacher:
                  io: 'output'
                  path: ['model.layer4', 'embed_dict.key']
                student:
                  io: 'output'
                  path: ['model.layer4', 'embed_dict.query','embed_dict.value']
                factor: 1
            heads: 1
        factor: 0.2

test:
  test_data_loader:
    dataset_id: *imagenet_val
    random_sample: False
    batch_size: 1
    num_workers: 16
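The config leans heavily on YAML anchors (`&name`), aliases (`*name`) and a custom `!join` tag. As a rough illustration of how the paths resolve, the sketch below assumes that `!join` simply concatenates its list items as strings; the helper `join` is a stand-in written for this post, not the library's loader.

```python
def join(parts):
    """Assumed behaviour of the !join tag: concatenate items as strings."""
    return "".join(str(p) for p in parts)


# Values bound by anchors in the config above.
root_dir = "/path/to/the/imagenet/folder"
teacher_model_name = "resnet34"
student_model_name = "resnet18"

# Where the training images are read from.
train_root = join([root_dir, "/train"])

# Where the teacher checkpoint is expected and where the student is saved.
teacher_experiment = join(["imagenet", "-", teacher_model_name])
teacher_ckpt = join(
    ["./resource/ckpt/ilsvrc2012/teacher/", teacher_experiment, ".pt"]
)
student_experiment = join(
    ["imagenet", "-", student_model_name, "_from_", teacher_model_name]
)
student_ckpt = join(["./model/", student_experiment, ".pt"])

print(train_root)
print(teacher_ckpt)
print(student_ckpt)
```

Under that assumption, changing `root_dir` or a model name in one place updates every derived path consistently.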
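To reproduce the results, you only need to change one line in the datasets section of the YAML file: in root: &root_dir '/path/to/the/imagenet/folder', replace the path with the location of your ImageNet dataset. The dataset will then load and training can start.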
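If you prefer to script that change, a minimal sketch is to treat the config as plain text and swap the placeholder. Plain-text replacement is used here because the file contains the custom `!join` tag, which a standard YAML loader rejects unless a constructor for it is registered. The function name `set_imagenet_root` and the path `/data/imagenet` are made up for illustration.

```python
PLACEHOLDER = "/path/to/the/imagenet/folder"


def set_imagenet_root(config_text, imagenet_root):
    """Return the config text with the placeholder dataset root replaced."""
    if PLACEHOLDER not in config_text:
        raise ValueError("placeholder dataset root not found in config")
    return config_text.replace(PLACEHOLDER, imagenet_root)


# Demonstrated on the single relevant line; in practice read the whole
# YAML file into a string, call the function, and write the result back.
sample = "    root: &root_dir '/path/to/the/imagenet/folder'\n"
updated = set_imagenet_root(sample, "/data/imagenet")
print(updated, end="")
```

Because every other dataset path is derived from `root_dir` through the anchor, this single replacement is enough.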
import os
import sys
sys.path.append("/absolute/path/to/torchdistill")  # replace with the absolute path to your torchdistill directory