赞
踩
由于bert模型参数很大,在用到生产环境中推理效率难以满足要求,因此经常需要将模型进行压缩。常用的模型压缩的方法有剪枝、蒸馏和量化等方法。比较容易实现的方法为知识蒸馏,下面便介绍如何将bert模型进行蒸馏。
一、知识蒸馏原理
模型蒸馏的目的是用一个小模型去学习大模型的知识,让小模型的效果接近大模型的效果,小模型被称为student,大模型被称为teacher。
知识蒸馏的实现可以根据teacher和student的网络结构的不同设计不同的蒸馏步骤,基本结构如下所示:
损失函数需要计算两个部分,cross entropy loss和mse loss,计算的时候需要注意有soft target和hard target。有两个参数需要定义,通过这两个参数对student和teacher进行拟合。其中一个是温度(T),对logits进行缩放。另一个是权重,用来计算加权损失。hard target就是原始的标注标签。soft target计算公式如下:

$$q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$
加权损失计算如下:

$$L = \alpha \cdot L_{soft} + (1 - \alpha) \cdot L_{hard}$$
二、将simBert模型蒸馏到siamese孪生网络上
蒸馏的步骤示意图可以参考下图:
核心代码如下:
class Distill_model(tf.keras.Model):
    """Distill a frozen simBERT teacher into a siamese DSSM student.

    Builds a functional Keras model.  In training mode the graph takes both
    the student inputs (two token-id sequences) and the teacher inputs (two
    BERT-style [word_ids, mask, type_ids] triplets) and exposes the student's
    temperature-softened logits alongside the teacher's, so a distillation
    loss (MSE on soft labels + CE on hard labels) can be computed outside.
    In inference mode only the student sub-graph is wired in.

    Args:
        config: dict of hyper-parameters; keys read here: 't' (temperature),
            'use_word2vec', 'embedding_size', 'neg_threshold', 'is_training'.
        teacher_network: a trained Keras model returning a dict with a
            'logits' entry; its layers are frozen before use.
        vocab_size: size of the student vocabulary.
        word_vectors: pre-trained embedding matrix used when
            config['use_word2vec'] is true.
        **kwargs: forwarded to tf.keras.Model.__init__.

    NOTE(review): relies on `softmax_t` and `shared_lstm_layer` defined
    elsewhere in the project.
    """

    def __init__(self, config, teacher_network, vocab_size, word_vectors,
                 **kwargs):
        # Freeze the teacher so that only the student receives gradients.
        for layer in teacher_network.layers:
            layer.trainable = False

        # ----- student inputs: two token-id sequences to compare -----
        query = tf.keras.layers.Input(shape=(None,), dtype=tf.int64,
                                      name='input_x_ids')
        sim_query = tf.keras.layers.Input(shape=(None,), dtype=tf.int64,
                                          name='input_y_ids')

        # ----- teacher inputs: BERT-style triplet for each sentence -----
        word_ids_a = tf.keras.layers.Input(shape=(None,), dtype=tf.int32,
                                           name='input_word_ids_a')
        mask_a = tf.keras.layers.Input(shape=(None,), dtype=tf.int32,
                                       name='input_mask_a')
        type_ids_a = tf.keras.layers.Input(shape=(None,), dtype=tf.int32,
                                           name='input_type_ids_a')
        word_ids_b = tf.keras.layers.Input(shape=(None,), dtype=tf.int32,
                                           name='input_word_ids_b')
        mask_b = tf.keras.layers.Input(shape=(None,), dtype=tf.int32,
                                       name='input_mask_b')
        type_ids_b = tf.keras.layers.Input(shape=(None,), dtype=tf.int32,
                                           name='input_type_ids_b')
        teacher_input = [[word_ids_a, mask_a, type_ids_a],
                         [word_ids_b, mask_b, type_ids_b]]

        # Teacher soft labels: temperature-scaled softmax over its logits.
        teacher_output = teacher_network(teacher_input)
        teacher_soft_label = softmax_t(config['t'], teacher_output['logits'])

        class GatherLayer(tf.keras.layers.Layer):
            """Embedding lookup: [batch, seq_len] ids -> [batch, seq_len,
            embedding_size] vectors, from either a randomly initialised
            matrix or pre-trained word2vec vectors."""

            def __init__(self, config, vocab_size, word_vectors):
                super(GatherLayer, self).__init__()
                self.config = config
                self.vocab_size = vocab_size
                self.word_vectors = word_vectors

            def build(self, input_shape):
                with tf.name_scope('embedding'):
                    if not self.config['use_word2vec']:
                        self.embedding_w = tf.Variable(
                            tf.keras.initializers.glorot_normal()(
                                shape=[self.vocab_size,
                                       self.config['embedding_size']],
                                dtype=tf.float32),
                            trainable=True, name='embedding_w')
                    else:
                        self.embedding_w = tf.Variable(
                            tf.cast(self.word_vectors, tf.float32),
                            trainable=True, name='embedding_w')
                # BUG FIX: the original wrote `self.build = True`, which
                # replaced this method with a bool; Keras tracks build state
                # in the `built` flag.
                self.built = True

            def call(self, inputs, **kwargs):
                return tf.gather(self.embedding_w, inputs,
                                 name='embedded_words')

            def get_config(self):
                return super(GatherLayer, self).get_config()

        # Shared (siamese) encoder: both sentences go through the same
        # embedding + LSTM stack so their representations are comparable.
        shared_net = tf.keras.Sequential(
            [GatherLayer(config, vocab_size, word_vectors),
             shared_lstm_layer(config)])

        # BUG FIX: call the model directly to build the symbolic graph;
        # `predict_step` is an eager inference hook, not a graph builder.
        query_embedding_output = shared_net(query)
        sim_query_embedding_output = shared_net(sim_query)

        # Cosine similarity between the two sentence embeddings,
        # shape [batch_size].
        query_norm = tf.sqrt(
            tf.reduce_sum(tf.square(query_embedding_output), axis=-1),
            name='query_norm')
        sim_query_norm = tf.sqrt(
            tf.reduce_sum(tf.square(sim_query_embedding_output), axis=-1),
            name='sim_query_norm')
        dot = tf.reduce_sum(
            tf.multiply(query_embedding_output, sim_query_embedding_output),
            axis=-1)
        cos_similarity = tf.divide(dot, query_norm * sim_query_norm,
                                   name='cos_similarity')

        # Pseudo-probabilities for the two classes, squared to sharpen.
        cond = cos_similarity > config['neg_threshold']
        pos = tf.where(cond, tf.square(cos_similarity),
                       1 - tf.square(cos_similarity))
        neg = tf.where(cond, 1 - tf.square(cos_similarity),
                       tf.square(cos_similarity))
        # IMPROVED: tf.stack builds the same [neg_i, pos_i] pairs as the
        # original per-sample list comprehension but works symbolically,
        # without requiring a static config['batch_size'].
        predictions = tf.stack([neg, pos], axis=1)

        # Student labels: soften the similarity with the same temperature.
        student_soft_label = softmax_t(config['t'], cos_similarity)
        student_hard_label = cos_similarity

        if config['is_training']:
            # Training graph: expose everything the distillation loss needs.
            outputs = dict(student_soft_label=student_soft_label,
                           student_hard_label=student_hard_label,
                           teacher_soft_label=teacher_soft_label,
                           predictions=predictions)
            super(Distill_model, self).__init__(
                inputs=[query, sim_query, teacher_input],
                outputs=outputs, **kwargs)
        else:
            # Inference graph: only the student is wired in.
            outputs = dict(predictions=predictions)
            super(Distill_model, self).__init__(
                inputs=[query, sim_query], outputs=outputs, **kwargs)

        # BUG FIX: attribute assignment must happen AFTER super().__init__();
        # tf.keras.Model raises if attributes are set on an uninitialised
        # subclass instance.
        self.config = config
        self.vocab_size = vocab_size
        self.word_vectors = word_vectors
        self.similarity = cos_similarity
        self.logits = cos_similarity
其中比较重要的步骤就是先冻结teacher模型的参数使其不参与训练:
#冻结teacher network的参数 for layer in teacher_network.layers: layer.trainable = False
然后在预测阶段只加载student模型:
#预测时候只加载学生模型 outputs = dict(predictions=predictions) super(Distill_model, self).__init__(inputs=[query, sim_query], outputs=outputs, **kwargs)
然后是loss的计算:
- # Soft-target loss: MSE between the teacher's and the student's
- # temperature-softened outputs (the distillation signal).
- y = tf.reshape(labels, (-1,))
- student_soft_label = model_outputs['student_soft_label']
- teacher_soft_label = model_outputs['teacher_soft_label']
- mse_loss = tf.keras.losses.mean_squared_error(teacher_soft_label, student_soft_label)
-
- # Hard-target loss: contrastive-style CE against the ground-truth labels y.
- # Negatives only contribute when their similarity exceeds the threshold
- # (cond is False there, so tf.where keeps the squared similarity).
- similarity = model_outputs['student_hard_label']
- cond = (similarity < self.config["neg_threshold"])
- zeros = tf.zeros_like(similarity, dtype=tf.float32)
- ones = tf.ones_like(similarity, dtype=tf.float32)
- squre_similarity = tf.square(similarity)
- neg_similarity = tf.where(cond, squre_similarity, zeros)
-
- # Positive pairs are penalised for low similarity; the /4 rescales
- # (1 - sim)^2 into [0, 1] since sim ranges over [-1, 1].
- pos_loss = y * (tf.square(ones - similarity) / 4)
- neg_loss = (ones - y) * neg_similarity
- ce_loss = pos_loss+neg_loss
- # Weighted sum of the two parts: alpha balances distillation (soft)
- # against supervision (hard), then reduce to a scalar batch mean.
- losses = self.config['alpha']*mse_loss + (1-self.config['alpha'])*ce_loss
- loss = tf.reduce_mean(losses)
三、总结
知识蒸馏作为一个模型压缩的方法,优点还是很多的,实现起来方便,也可以在样本数量少的情况下使用。
参考文章:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。