当前位置:   article > 正文

【Bert】(十二)简易问答系统--源码解析(bert后处理模型+损失函数)_基于bert的问答系统代码

基于bert的问答系统代码

论文:https://arxiv.org/pdf/1810.04805.pdf

官方代码:GitHub - google-research/bert: TensorFlow code and pre-trained models for BERT

bert后处理模型

在run_squad.py中的create_model函数中,“bert后处理模型”代码为:

  1. final_hidden = model.get_sequence_output()
  2. final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3)
  3. batch_size = final_hidden_shape[0]
  4. seq_length = final_hidden_shape[1]
  5. hidden_size = final_hidden_shape[2]
  6. output_weights = tf.get_variable(
  7. "cls/squad/output_weights", [2, hidden_size],
  8. initializer=tf.truncated_normal_initializer(stddev=0.02))
  9. output_bias = tf.get_variable(
  10. "cls/squad/output_bias", [2], initializer=tf.zeros_initializer())
  11. final_hidden_matrix = tf.reshape(final_hidden,
  12. [batch_size * seq_length, hidden_size])
  13. logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True)
  14. logits = tf.nn.bias_add(logits, output_bias)
  15. logits = tf.reshape(logits, [batch_size, seq_length, 2])
  16. logits = tf.transpose(logits, [2, 0, 1])
  17. unstacked_logits = tf.unstack(logits, axis=0)
  18. (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1])
  19. return (start_logits, end_logits)

 

最终得到的start_logits, end_logits,他们的形状都为【batchsize, seq_length】。

这种处理方式只适合一问一答的情况。

损失函数

  1. def compute_loss(logits, positions):
  2. one_hot_positions = tf.one_hot(
  3. positions, depth=seq_length, dtype=tf.float32)
  4. log_probs = tf.nn.log_softmax(logits, axis=-1)
  5. loss = -tf.reduce_mean(
  6. tf.reduce_sum(one_hot_positions * log_probs, axis=-1))
  7. return loss
  8. start_positions = features["start_positions"]
  9. end_positions = features["end_positions"]
  10. start_loss = compute_loss(start_logits, start_positions)
  11. end_loss = compute_loss(end_logits, end_positions)
  12. total_loss = (start_loss + end_loss) / 2.0

在start_logits中概率最大的认为是起始位置,end_logits中概率最大的认为是终止位置。

根据这样的理念结合交叉熵损失,就能得到上述代码描述的情况。

本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/801987
推荐阅读
相关标签
  

闽ICP备14008679号