当前位置:   article > 正文

知识蒸馏 示例代码实现及下载_蒸馏学习代码

蒸馏学习代码

知识蒸馏 代码实现

论文《Distilling the Knowledge in a Neural Network》

* 源码以Github为准

Github链接:GitHub - yeqiwang/KnowledgeDistilling: tensorflow2_knowledge_distilling_example

1. 数据集

本文使用fashion_mnist数据集,输入图像大小为28*28,共分为10类。

通过tensoflow加载数据,并对label进行one hot编码。

  1. import tensorflow as tf
  2. from tensorflow import keras
  3. import numpy as np
  4. fashion_mnist = tf.keras.datasets.fashion_mnist
  5. (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
  6. train_images = train_images/255
  7. test_images = test_images/255
  8. train_labels = tf.one_hot(train_labels, depth=10)
  9. test_labels = tf.one_hot(test_labels, depth=10)

2. 教师模型

本文中使用一个4层MLP来作为教师模型。

训练过程中,模型最后使用softmax层来计算损失值。

训练结束后,更改最后的softmax层,以便生成软标签,其中T=2。同时,为了防止误操作,将教师模型冻结。

需要注意的是,虽然更改后教师模型不再进行训练,但仍需要使用compile函数进行配置,否则无法调用predict函数。

  1. # 构建并训练教师模型
  2. inputs = keras.layers.Input(shape=(28,28))
  3. x = keras.layers.Flatten()(inputs)
  4. x = keras.layers.Dense(128, activation='relu')(x)
  5. x = keras.layers.Dense(128, activation='relu')(x)
  6. x = keras.layers.Dense(128, activation='relu')(x)
  7. x = keras.layers.Dense(10)(x)
  8. outputs = keras.layers.Softmax()(x)
  9. t_model = keras.Model(inputs, outputs)
  10. t_model.summary()
  11. callback = [keras.callbacks.EarlyStopping(patience=10 ,restore_best_weights=True)]
  12. t_model.compile(optimizer='adam',
  13. loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
  14. metrics=['accuracy'])
  15. t_model.fit(train_images, train_labels, epochs=500, validation_data=(test_images, test_labels),callbacks=callback)
  16. # 更改教师模型以便后续生成软标签
  17. x = t_model.get_layer(index=-2).output
  18. outputs = keras.layers.Softmax()(x/3)
  19. Teacher_model = keras.Model(t_model.input, outputs)
  20. Teacher_model.summary()
  21. Teacher_model.trainable = False
  22. Teacher_model.compile(optimizer='adam',
  23. loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
  24. metrics=['accuracy'])

3. 学生模型

本文使用一个2层MLP作为学生模型。

学生模型构建完成后不进行训练,在后续的蒸馏过程中进行训练。

需要注意的是,学生模型最后一层不加Softmax层

  1. inputs = keras.layers.Input(shape=(28,28))
  2. x = keras.layers.Flatten()(inputs)
  3. x = keras.layers.Dense(128, activation='relu')(x)
  4. outputs = keras.layers.Dense(10)(x)
  5. Student_model = keras.Model(inputs, outputs)
  6. Student_model.summary()

4. 知识蒸馏过程

学生模型进行蒸馏时,损失函数包括两部分:

  • Loss1:学生模型softmax输出值与真实标签的之间的损失(交叉熵);

  • Loss2:学生模型软化后的softmax输出值(T=2)与教师模型生成的软标签之间的损失(KL散度)。

则,Loss = 0.1*Loss1 + 0.9*Loss2。

本文通过重写Model类来实现。

  1. class Distilling(keras.Model):
  2. def __init__(self, student_model, teacher_model, T, alpha):
  3. super(Distilling, self).__init__()
  4. self.student_model = student_model
  5. self.teacher_model = teacher_model
  6. self.T = T
  7. self.alpha = alpha
  8. def train_step(self, data):
  9. x, y = data
  10. softmax = keras.layers.Softmax()
  11. kld = keras.losses.KLDivergence()
  12. with tf.GradientTape() as tape:
  13. logits = self.student_model(x)
  14. soft_labels = self.teacher_model(x)
  15. loss_value1 = self.compiled_loss(y, softmax(logits))
  16. loss_value2 = kld(soft_labels, softmax(logits/self.T))
  17. loss_value = self.alpha* loss_value2 + (1-self.alpha) * loss_value1
  18. grads = tape.gradient(loss_value, self.student_model.trainable_weights)
  19. self.optimizer.apply_gradients(zip(grads, self.student_model.trainable_weights))
  20. self.compiled_metrics.update_state(y, softmax(logits))
  21. return {'sum_loss':loss_value, 'loss1': loss_value1, 'loss2':loss_value2, }
  22. def test_step(self, data):
  23. x, y = data
  24. softmax = keras.layers.Softmax()
  25. logits = self.student_model(x)
  26. loss_value = self.compiled_loss(y, softmax(logits))
  27. return {'loss':loss_value}
  28. def call(self, inputs):
  29. return self.student_model(inputs)

蒸馏过程加入早停止机制,监视val_loss。

  1. distill = Distilling(Student_model, Teacher_model, 2, 0.9)
  2. distill.compile(optimizer='adam',
  3. loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False))
  4. callback = [keras.callbacks.EarlyStopping(patience=20, restore_best_weights=True)]
  5. distill.fit(train_images, train_labels, epochs=500, validation_data=(test_images, test_labels), callbacks=callback)

5. 实验结果

为了验证结果,本文独立训练学生模型(加入Softmax层),与使用知识蒸馏训练的学生模型进行对比。

实验结果如下:

  • 教师模型准确度 0.8682
  • 学生模型准确度 0.8365 (知识蒸馏)
  • 学生模型准确度 0.8302 (独立训练)

这表明,知识蒸馏方法确实有效。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/731781
推荐阅读
相关标签
  

闽ICP备14008679号