赞
踩
与标准训练的不同之处在于 loss 部分: 除了由传统标签计算的损失之外, 额外添加了与教师模型输出计算的蒸馏损失, 见代码中的 KD_loss。本文采用了 Distilling the Knowledge in a Neural Network (Hinton et al., 2015) 中提出的蒸馏损失。
"""Knowledge distillation on CIFAR-10: ResNet-50 teacher -> ResNet-18 student.

Loss = CrossEntropy(student, labels) + T^2 * KL(student_T || teacher_T)
(the classic Hinton et al. distillation loss).
"""
from torchvision.models.resnet import resnet18, resnet50
import torch
from torchvision.transforms import transforms
import torchvision.datasets as dst
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import torch.nn as nn

resnet18_pretrain_weight = "./weights/resnet18-5c106cde.pth"
resnet50_pretrain_weight = "./weights/resnet50_cifar10.pth"
img_dir = "/data/cifar10/"


def create_data(img_dir):
    """Build CIFAR-10 train/test DataLoaders.

    Train uses reflect-pad + random-crop + horizontal-flip augmentation;
    test only center-crops and normalizes. Both normalize with the
    standard CIFAR-10 channel statistics.
    """
    dataset = dst.CIFAR10
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)
    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    train_loader = torch.utils.data.DataLoader(
        dataset(root=img_dir, transform=train_transform, train=True, download=True),
        batch_size=512, shuffle=True, num_workers=4, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        dataset(root=img_dir, transform=test_transform, train=False, download=True),
        batch_size=512, shuffle=False, num_workers=4, pin_memory=True)
    return train_loader, test_loader


def load_checkpoint(net, pth_file, exclude_fc=False):
    """Load a state dict from *pth_file* into *net*.

    With exclude_fc=True, every key containing 'fc' is skipped and the
    model's own (randomly initialized) fc weights are kept — used here to
    load ImageNet ResNet-18 weights into a 10-class head.
    """
    pretrain_dict = torch.load(pth_file)
    if exclude_fc:
        model_dict = net.state_dict()
        new_dict = {k: v for k, v in pretrain_dict.items() if 'fc' not in k}
        model_dict.update(new_dict)
        net.load_state_dict(model_dict, strict=True)
    else:
        net.load_state_dict(pretrain_dict, strict=True)


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Returns a list of scalar tensors, one per k, each the percentage of
    samples whose true label is among the top-k predictions.
    """
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # FIX: reshape(-1) instead of view(-1) — correct[:k] is a slice of a
        # transposed (non-contiguous) tensor, on which .view() raises.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class KD_loss(nn.Module):
    """Hinton distillation loss: T^2 * KL(softmax(s/T) || softmax(t/T)).

    The T^2 factor keeps the soft-target gradient magnitude comparable
    to the hard-label loss across temperatures.
    """

    def __init__(self, T):
        super(KD_loss, self).__init__()
        self.T = T

    def forward(self, out_s, out_t):
        loss = F.kl_div(F.log_softmax(out_s / self.T, dim=1),
                        F.softmax(out_t / self.T, dim=1),
                        reduction='batchmean') * self.T * self.T
        return loss


def test(net, test_loader):
    """Evaluate *net* on the test set and print mean top-1/top-5 accuracy.

    NOTE: leaves the network in eval() mode; callers that keep training
    must switch back to train() themselves (train() below does).
    """
    prec1_sum = 0
    prec5_sum = 0
    num_batches = 0
    net.eval()
    for num_batches, (img, target) in enumerate(test_loader, start=1):
        img = img.cuda()
        target = target.cuda()
        with torch.no_grad():
            out = net(img)
        prec1, prec5 = accuracy(out, target, topk=(1, 5))
        prec1_sum += prec1
        prec5_sum += prec5
    # FIX: divide by the batch count itself — enumerate started at 1, so the
    # original /(num_batches + 1) under-reported accuracy by one batch's worth.
    # (Per-batch averaging slightly mis-weights a smaller final batch; accepted
    # here for simplicity.)
    print(f"Acc1:{prec1_sum / num_batches}, Acc5: {prec5_sum / num_batches}")


def train(net_s, net_t, train_loader, test_loader):
    """Train the student against hard labels + teacher soft targets.

    The teacher stays frozen in eval() mode; only the student's
    parameters are optimized. Saves the student weights every epoch.
    """
    opt = Adam(net_s.parameters(), lr=0.0001)
    # Hoisted out of the step loop: both criteria are stateless modules,
    # re-creating them every iteration was pure overhead.
    ce_loss = CrossEntropyLoss()
    kd_loss = KD_loss(T=4)
    net_t.eval()
    for epoch in range(100):
        # FIX: test() switches the student to eval() at the end of each epoch;
        # without this, epochs >= 2 trained with frozen BatchNorm statistics.
        net_s.train()
        for step, batch in enumerate(train_loader):
            opt.zero_grad()
            image, target = batch
            image = image.cuda()
            target = target.cuda()
            out_s = net_s(image)
            with torch.no_grad():
                # Teacher forward needs no graph — saves memory and time.
                out_t = net_t(image)
            loss_init = ce_loss(out_s, target)
            loss_kd = kd_loss(out_s, out_t)
            loss = loss_init + loss_kd
            loss.backward()
            opt.step()
        print(f"epoch:{epoch}, loss_init: {loss_init.item()}, "
              f"loss_kd: {loss_kd.item()}, loss_all:{loss.item()}")
        test(net_s, test_loader)
        torch.save(net_s.state_dict(), './resnet18_cifar10_kd.pth')


def main():
    net_t = resnet50(num_classes=10)
    net_s = resnet18(num_classes=10)
    net_t = net_t.cuda()
    net_s = net_s.cuda()
    # Teacher: CIFAR-10 checkpoint loads fully; student: ImageNet weights
    # minus the (shape-mismatched) fc head.
    load_checkpoint(net_t, resnet50_pretrain_weight, exclude_fc=False)
    load_checkpoint(net_s, resnet18_pretrain_weight, exclude_fc=True)
    train_loader, test_loader = create_data(img_dir)
    train(net_s, net_t, train_loader, test_loader)


if __name__ == "__main__":
    main()
teacher model | student model | cifar10 |
---|---|---|
- | resnet18 | 80.34/94.24 |
- | resnet50 | 83.20/94.51 |
resnet50 | resnet18 | 82.25/94.44 |
精度收敛趋势:
通过实验可以发现, 通过蒸馏的方式, resnet18的精度得到了明显的提升。
注: 本文旨在验证知识蒸馏的效果, 因此模型没有采用各种trick以及精细调优, 精度不是SOTA。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。