赞
踩
论文《Distilling the Knowledge in a Neural Network》
* 源码以Github为准
Github链接:GitHub - yeqiwang/KnowledgeDistilling: tensorflow2_knowledge_distilling_example
本文使用fashion_mnist数据集,输入图像大小为28*28,共分为10类。
通过tensoflow加载数据,并对label进行one hot编码。
- import tensorflow as tf
- from tensorflow import keras
- import numpy as np
-
- fashion_mnist = tf.keras.datasets.fashion_mnist
- (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
- train_images = train_images/255
- test_images = test_images/255
- train_labels = tf.one_hot(train_labels, depth=10)
- test_labels = tf.one_hot(test_labels, depth=10)
本文中使用一个4层MLP来作为教师模型。
训练过程中,模型最后使用softmax层来计算损失值。
训练结束后,更改最后的softmax层,以便生成软标签,其中T=2。同时,为了防止误操作,将教师模型冻结。
需要注意的是,虽然更改后教师模型不再进行训练,但仍需要使用compile函数进行配置,否则无法调用predict函数。
- # 构建并训练教师模型
- inputs = keras.layers.Input(shape=(28,28))
- x = keras.layers.Flatten()(inputs)
- x = keras.layers.Dense(128, activation='relu')(x)
- x = keras.layers.Dense(128, activation='relu')(x)
- x = keras.layers.Dense(128, activation='relu')(x)
- x = keras.layers.Dense(10)(x)
- outputs = keras.layers.Softmax()(x)
-
- t_model = keras.Model(inputs, outputs)
- t_model.summary()
-
- callback = [keras.callbacks.EarlyStopping(patience=10 ,restore_best_weights=True)]
- t_model.compile(optimizer='adam',
- loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
- metrics=['accuracy'])
-
- t_model.fit(train_images, train_labels, epochs=500, validation_data=(test_images, test_labels),callbacks=callback)
-
- # 更改教师模型以便后续生成软标签
- x = t_model.get_layer(index=-2).output
- outputs = keras.layers.Softmax()(x/3)
- Teacher_model = keras.Model(t_model.input, outputs)
- Teacher_model.summary()
- Teacher_model.trainable = False
-
- Teacher_model.compile(optimizer='adam',
- loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
- metrics=['accuracy'])

本文使用一个2层MLP作为学生模型。
学生模型构建完成后不进行训练,在后续的蒸馏过程中进行训练。
需要注意的是,学生模型最后一层不加Softmax层。
- inputs = keras.layers.Input(shape=(28,28))
- x = keras.layers.Flatten()(inputs)
- x = keras.layers.Dense(128, activation='relu')(x)
- outputs = keras.layers.Dense(10)(x)
-
- Student_model = keras.Model(inputs, outputs)
- Student_model.summary()
学生模型进行蒸馏时,损失函数包括两部分:
Loss1:学生模型softmax输出值与真实标签的之间的损失(交叉熵);
Loss2:学生模型软化后的softmax输出值(T=2)与教师模型生成的软标签之间的损失(KL散度)。
则,Loss = 0.1*Loss1 + 0.9*Loss2。
本文通过重写Model类来实现。
- class Distilling(keras.Model):
- def __init__(self, student_model, teacher_model, T, alpha):
- super(Distilling, self).__init__()
- self.student_model = student_model
- self.teacher_model = teacher_model
- self.T = T
- self.alpha = alpha
-
- def train_step(self, data):
- x, y = data
- softmax = keras.layers.Softmax()
- kld = keras.losses.KLDivergence()
- with tf.GradientTape() as tape:
- logits = self.student_model(x)
- soft_labels = self.teacher_model(x)
- loss_value1 = self.compiled_loss(y, softmax(logits))
- loss_value2 = kld(soft_labels, softmax(logits/self.T))
- loss_value = self.alpha* loss_value2 + (1-self.alpha) * loss_value1
- grads = tape.gradient(loss_value, self.student_model.trainable_weights)
- self.optimizer.apply_gradients(zip(grads, self.student_model.trainable_weights))
- self.compiled_metrics.update_state(y, softmax(logits))
- return {'sum_loss':loss_value, 'loss1': loss_value1, 'loss2':loss_value2, }
-
- def test_step(self, data):
- x, y = data
- softmax = keras.layers.Softmax()
- logits = self.student_model(x)
- loss_value = self.compiled_loss(y, softmax(logits))
- return {'loss':loss_value}
-
- def call(self, inputs):
- return self.student_model(inputs)
-

蒸馏过程加入早停止机制,监视val_loss。
- distill = Distilling(Student_model, Teacher_model, 2, 0.9)
- distill.compile(optimizer='adam',
- loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False))
-
- callback = [keras.callbacks.EarlyStopping(patience=20, restore_best_weights=True)]
-
- distill.fit(train_images, train_labels, epochs=500, validation_data=(test_images, test_labels), callbacks=callback)
为了验证结果,本文独立训练学生模型(加入Softmax层),与使用知识蒸馏训练的学生模型进行对比。
实验结果如下:
这表明,知识蒸馏方法确实有效。
赞
踩
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。