
Knowledge Distillation: A Worked Example

1 Experiment Overview

  1. First, resnet18 and resnet50 are evaluated separately on CIFAR-10 as baselines. Starting from the torchvision pretrained weights for resnet18 and resnet50, the final fc layer is replaced and each model is finetuned on CIFAR-10 (see the sketch after this list).
  2. With all other settings unchanged, the finetuned resnet50 is then used as a teacher to train resnet18, and the student's accuracy is measured.
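The article only lists the distillation script; for step 1, here is a minimal sketch of what the baseline finetuning might look like, assuming the create_data() loader defined in the full script below (the ImageNet checkpoint filename is illustrative):

import torch
import torch.nn as nn
from torch.optim import Adam
from torchvision.models.resnet import resnet50

# Build a 10-class resnet50 and load ImageNet weights, skipping the
# incompatible 1000-class fc layer (checkpoint filename is illustrative).
net = resnet50(num_classes=10)
state = torch.load("./weights/resnet50-19c8e357.pth")
net.load_state_dict({k: v for k, v in state.items() if 'fc' not in k}, strict=False)
net = net.cuda()

opt = Adam(net.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()
train_loader, _ = create_data("/data/cifar10/")  # create_data is defined in the full script below

net.train()
for epoch in range(100):
    for image, target in train_loader:
        image, target = image.cuda(), target.cuda()
        opt.zero_grad()
        loss = criterion(net(image), target)
        loss.backward()
        opt.step()

# The result is the teacher checkpoint used by the distillation run.
torch.save(net.state_dict(), "./weights/resnet50_cifar10.pth")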

2 Code Implementation

The only part that differs from standard training is the loss. Besides the usual loss computed from the ground-truth labels, an extra term computed against the teacher model's outputs is added; see KD_loss in the code. This article uses the distillation loss from Distilling the Knowledge in a Neural Network (Hinton et al.).
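Concretely, with student logits $z_s$, teacher logits $z_t$, ground-truth label $y$, and temperature $T$, the total loss implemented below is

$$
\mathcal{L} = \mathrm{CE}\big(\mathrm{softmax}(z_s),\, y\big) + T^2 \cdot \mathrm{KL}\big(\mathrm{softmax}(z_t/T) \,\|\, \mathrm{softmax}(z_s/T)\big),
$$

where the first term is loss_init and the second is loss_kd. The $T^2$ factor keeps the gradient magnitude of the soft term comparable to the hard-label term as $T$ changes, as recommended in the original paper.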

from torchvision.models.resnet import resnet18, resnet50
import torch
from torchvision.transforms import transforms
import torchvision.datasets as dst
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import torch.nn as nn

resnet18_pretrain_weight = "./weights/resnet18-5c106cde.pth"  # torchvision ImageNet weights (student init)
resnet50_pretrain_weight = "./weights/resnet50_cifar10.pth"   # resnet50 already finetuned on CIFAR-10 (teacher)
img_dir = "/data/cifar10/"


def create_data(img_dir):
    dataset = dst.CIFAR10
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)
    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    # define data loader
    train_loader = torch.utils.data.DataLoader(
        dataset(root=img_dir,
                transform=train_transform,
                train=True,
                download=True),
        batch_size=512, shuffle=True, num_workers=4, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(
        dataset(root=img_dir,
                transform=test_transform,
                train=False,
                download=True),
        batch_size=512, shuffle=False, num_workers=4, pin_memory=True)
    return train_loader, test_loader


def load_checkpoint(net, pth_file, exclude_fc=False):
    """Load weights from pth_file; with exclude_fc=True, skip the fc layer
    (needed when the checkpoint has a different number of classes)."""
    if exclude_fc:
        model_dict = net.state_dict()
        pretrain_dict = torch.load(pth_file)
        # keep only the weights whose names do not contain 'fc'
        new_dict = {k: v for k, v in pretrain_dict.items() if 'fc' not in k}
        model_dict.update(new_dict)
        net.load_state_dict(model_dict, strict=True)
    else:
        pretrain_dict = torch.load(pth_file)
        net.load_state_dict(pretrain_dict, strict=True)


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)  # reshape: the transposed slice may be non-contiguous
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class KD_loss(nn.Module):
    """Distillation loss of Hinton et al.: KL divergence between the
    temperature-softened teacher and student distributions, scaled by T^2."""
    def __init__(self, T):
        super(KD_loss, self).__init__()
        self.T = T

    def forward(self, out_s, out_t):
        loss = F.kl_div(F.log_softmax(out_s / self.T, dim=1),
                        F.softmax(out_t / self.T, dim=1),
                        reduction='batchmean') * self.T * self.T

        return loss


def test(net, test_loader):
    prec1_sum = 0
    prec5_sum = 0
    net.eval()
    for i, (img, target) in enumerate(test_loader, start=1):
        img = img.cuda()
        target = target.cuda()

        with torch.no_grad():
            out = net(img)
        prec1, prec5 = accuracy(out, target, topk=(1, 5))
        prec1_sum += prec1
        prec5_sum += prec5
    # enumerate started at 1, so i equals the number of batches
    print(f"Acc1: {prec1_sum / i}, Acc5: {prec5_sum / i}")


def train(net_s, net_t, train_loader, test_loader):
    opt = Adam(net_s.parameters(), lr=0.0001)
    ce_loss = CrossEntropyLoss()
    kd_loss = KD_loss(T=4)
    net_s.train()
    net_t.eval()
    for epoch in range(100):
        for step, batch in enumerate(train_loader):
            opt.zero_grad()
            image, target = batch
            image = image.cuda()
            target = target.cuda()
            out_s = net_s(image)
            with torch.no_grad():  # the teacher only supplies soft targets; no gradient needed
                out_t = net_t(image)
            loss_init = ce_loss(out_s, target)  # hard-label cross-entropy
            loss_kd = kd_loss(out_s, out_t)     # soft-target distillation loss
            loss = loss_init + loss_kd
            loss.backward()
            opt.step()
        print(f"epoch:{epoch}, loss_init: {loss_init.item()}, loss_kd: {loss_kd.item()}, loss_all:{loss.item()}")
        test(net_s, test_loader)

    torch.save(net_s.state_dict(), './resnet18_cifar10_kd.pth')


def main():
    net_t = resnet50(num_classes=10)
    net_s = resnet18(num_classes=10)
    net_t = net_t.cuda()
    net_s = net_s.cuda()
    # teacher: already finetuned on CIFAR-10, so load all weights including fc
    load_checkpoint(net_t, resnet50_pretrain_weight, exclude_fc=False)
    # student: ImageNet weights; skip fc (1000 ImageNet classes vs. 10 here)
    load_checkpoint(net_s, resnet18_pretrain_weight, exclude_fc=True)

    train_loader, test_loader = create_data(img_dir)
    train(net_s, net_t, train_loader, test_loader)


if __name__ == "__main__":
    main()
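Note that the loss used above simply sums the two terms with equal weight (loss = loss_init + loss_kd). Hinton et al. instead recommend a weighted average with a relatively low weight on the hard-label term; a minimal sketch of that variant, where alpha and its default value are hyperparameters introduced here for illustration:

def kd_total_loss(out_s, out_t, target, T=4.0, alpha=0.9):
    # Weighted combination of the hard-label loss and the T^2-scaled
    # soft-target loss; alpha and its default value are illustrative.
    hard = F.cross_entropy(out_s, target)
    soft = F.kl_div(F.log_softmax(out_s / T, dim=1),
                    F.softmax(out_t / T, dim=1),
                    reduction='batchmean') * T * T
    return (1 - alpha) * hard + alpha * soft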

3 Experimental Results

| teacher model | student model | CIFAR-10 Acc1/Acc5 (%) |
| ------------- | ------------- | ---------------------- |
| -             | resnet18      | 80.34 / 94.24          |
| -             | resnet50      | 83.20 / 94.51          |
| resnet50      | resnet18      | 82.25 / 94.44          |

Accuracy convergence trend: [figure omitted]
The experiments show that distillation clearly improves resnet18's accuracy: top-1 rises from 80.34 to 82.25.

Note: this article only aims to verify the effect of knowledge distillation, so the models use no training tricks or careful tuning, and the accuracies are not SOTA.
