赞
踩
知乎-mmrazor-模型蒸馏
知识蒸馏算法往往分为 reponse-based基于响应、feature-based基于特征 和 relation-based基于关系三类。
也可为 data-free KD、online KD、self KD(可视为一种特殊的 online KD)和比较经典的 offline KD
mmrazor可以使用不同架构的student和teacher模型
mmseg的模型,在deconde_head.conv_seg后拿到的特征图为logist的特征图,其内的元素都为小数,而非预测的0/1
cwd算法是一种什么样的蒸馏算法呢是data-free的吗,还是online的呢,还是offline的呢
适用于mmseg的cwd模型蒸馏配置文件(default_hook中的Student_CheckpointHook是自定义的hook,继承自mmegin中的CheekpointHook)
cwd算法:首先使用softmax归一化方法将每个通道的feature map转换成一个分布,然后最小化两个网络对应通道的Kullback Leibler (KL)散度。通过这样做,我们的方法着重于模拟网络间通道的软分布。特别的是,KL的差异使学习能够更多地关注通道图中最突出的区域,大概对应于语义分割最有用的信号
_base_ = [ 'mmseg::_base_/datasets/pascal_voc12.py', 'mmseg::_base_/schedules/schedule_160k.py', 'mmseg::_base_/default_runtime.py' ] # 模型的optim_wrapper,学习率和学习策略将来自于继承的schedule_160k,如果不改的话 # wandb的可视化设置在mmseg的default_runtime,也继承自mmseg # schedule_160k.py中的自动保存权重的部分 default_hooks = dict(_delete_=True, timer=dict(type='IterTimerHook'), logger=dict(type='LoggerHook', interval=100, log_metric_by_epoch=False), param_scheduler=dict(type='ParamSchedulerHook'), # 使用了自定义的Student_CheckpointHook checkpoint=dict(type='Student_CheckpointHook', by_epoch=False, interval=-1, max_keep_ckpts=2, save_best=['mDice', 'mIoU']), # checkpoint中,interval=-1则不会保存least.pth sampler_seed=dict(type='DistSamplerSeedHook'), visualization=dict(type='SegVisualizationHook')) teacher_ckpt = '/root/autodl-tmp/all_workdir/mmseg_work_dir/baseline-convnext-tiny-upernet-rotate/best_mDice_iter_6800.pth' # noqa: E501 teacher_cfg_path = 'mmseg::all_changed/baseline-convnext-tiny_upernet-rotate.py' # noqa: E501 student_cfg_path = 'mmseg::all_changed/pspnet_r18-d8_b16-160k_voc-material-512x512.py' # noqa: E501 model = dict( _scope_='mmrazor', type='SingleTeacherDistill', architecture=dict(cfg_path=student_cfg_path, pretrained=False), teacher=dict(cfg_path=teacher_cfg_path, pretrained=False), teacher_ckpt=teacher_ckpt, distiller=dict( type='ConfigurableDistiller', distill_losses=dict( loss_cwd=dict(type='ChannelWiseDivergence', tau=1, loss_weight=5)), student_recorders=dict( logits=dict(type='ModuleOutputs', source='decode_head.conv_seg')), teacher_recorders=dict( logits=dict(type='ModuleOutputs', source='decode_head.conv_seg')), connectors=dict( loss_conv_stu=dict(type='ConvModuleConncetor', in_channel=2, out_channel=2, kernel_size=1, stride=1, padding=0, norm_cfg=dict(type='BN')), loss_conv_tea=dict(type='ConvModuleConncetor', in_channel=2, out_channel=2, kernel_size=3, stride=2, padding=1, padding_mode='circular', norm_cfg=dict(type='BN'))), loss_forward_mappings=dict( loss_cwd=dict( preds_S=dict(from_student=True, recorder='logits', connector='loss_conv_stu'), # 含义:从student_recorders(from_student=True)中读取名为logits的数据 # 加上connnecor字段后,表示从student_recorders中读取名为logists的数据,而后将数据通过名为loss_conv_stu的连接器 preds_T=dict(from_student=False, recorder='logits', connector='loss_conv_tea'))))) # 从teacher_recorders中读取名为logits的数据,而后将数据通过名为loss_conv_tea的连接器 # 而无论是loss_cwd、logits、loss_conv_stu、loss_conv_tea都是自定的名称 find_unused_parameters = True train_cfg = dict( type='IterBasedTrainLoop', max_iters=160000, val_interval=200) val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop') train_dataloader = dict(batch_size=16) # 更改batch_size,否则会继承到pascal_voc12.py中的设置 # 这个16会作为teacher模型的推理batch和student模型的训练batch work_dir = '/root/autodl-tmp/all_workdir/mmrazor_wokdir/distill/convnext-tiny-upernet_to_pspnet-r18'
总结:在语义分割中的cwd算法,可以看作是基于响应的KD,也可以看作是基于特征的KD,因为在传统的cwd算法中,使用的是在通过softmax之前的位置作为蒸馏位点,输出对应的特征图,去计算损失。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。