Paper: https://arxiv.org/pdf/1810.04805.pdf
Official code: GitHub - google-research/bert: TensorFlow code and pre-trained models for BERT
In the create_model function of run_squad.py, the head that sits on top of BERT (the "post-BERT" part of the model) is:
final_hidden = model.get_sequence_output()  # [batch_size, seq_length, hidden_size]

final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3)
batch_size = final_hidden_shape[0]
seq_length = final_hidden_shape[1]
hidden_size = final_hidden_shape[2]

# One weight row per output: row 0 scores the start position, row 1 the end position.
output_weights = tf.get_variable(
    "cls/squad/output_weights", [2, hidden_size],
    initializer=tf.truncated_normal_initializer(stddev=0.02))

output_bias = tf.get_variable(
    "cls/squad/output_bias", [2], initializer=tf.zeros_initializer())

# Flatten all tokens, project every token vector to 2 scores, then restore the batch/seq axes.
final_hidden_matrix = tf.reshape(final_hidden,
                                 [batch_size * seq_length, hidden_size])
logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True)
logits = tf.nn.bias_add(logits, output_bias)

logits = tf.reshape(logits, [batch_size, seq_length, 2])
logits = tf.transpose(logits, [2, 0, 1])

# Split the leading axis of size 2 into the start and end score tensors.
unstacked_logits = tf.unstack(logits, axis=0)

(start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1])

return (start_logits, end_logits)
The resulting start_logits and end_logits each have shape [batch_size, seq_length].
This formulation only handles the case of one question with a single answer span.
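To make the shape bookkeeping concrete, here is a minimal NumPy sketch (my own illustration with toy dimensions, not code from the repository) that mirrors the reshape / matmul / transpose / unstack sequence above and ends with two [batch_size, seq_length] score tensors:

import numpy as np

batch_size, seq_length, hidden_size = 2, 5, 8                          # toy sizes for illustration
final_hidden = np.random.randn(batch_size, seq_length, hidden_size)    # stands in for BERT's sequence output
output_weights = np.random.randn(2, hidden_size)                       # row 0 scores starts, row 1 scores ends
output_bias = np.zeros(2)

final_hidden_matrix = final_hidden.reshape(batch_size * seq_length, hidden_size)
logits = final_hidden_matrix @ output_weights.T + output_bias           # [batch*seq, 2]
logits = logits.reshape(batch_size, seq_length, 2).transpose(2, 0, 1)   # [2, batch, seq]
start_logits, end_logits = logits[0], logits[1]

print(start_logits.shape, end_logits.shape)                             # (2, 5) and (2, 5), i.e. [batch_size, seq_length]

During training, run_squad.py turns these logits and the gold start/end positions into a cross-entropy loss: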
def compute_loss(logits, positions):
  one_hot_positions = tf.one_hot(
      positions, depth=seq_length, dtype=tf.float32)
  log_probs = tf.nn.log_softmax(logits, axis=-1)
  loss = -tf.reduce_mean(
      tf.reduce_sum(one_hot_positions * log_probs, axis=-1))
  return loss

start_positions = features["start_positions"]
end_positions = features["end_positions"]

start_loss = compute_loss(start_logits, start_positions)
end_loss = compute_loss(end_logits, end_positions)

total_loss = (start_loss + end_loss) / 2.0
At prediction time, the position with the largest score in start_logits is taken as the start of the answer, and the position with the largest score in end_logits as the end.
Combining this idea with a cross-entropy loss over positions yields exactly the code above.
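As a companion to that description, here is a minimal sketch of the greedy decoding idea (again my own illustration; the repository's actual prediction code in write_predictions does an n-best search over candidate start/end pairs rather than two independent argmaxes):

import numpy as np

def greedy_span(start_logits, end_logits):
    # Pick the highest-scoring start and end independently for each example.
    start_idx = np.argmax(start_logits, axis=-1)   # [batch_size]
    end_idx = np.argmax(end_logits, axis=-1)       # [batch_size]
    return start_idx, end_idx

start_logits = np.random.randn(2, 5)               # toy [batch_size, seq_length] scores
end_logits = np.random.randn(2, 5)
print(greedy_span(start_logits, end_logits))

One caveat of the greedy picks is that they can return an end index smaller than the start index; that is one reason the real prediction code scores (start, end) combinations jointly and filters out invalid spans.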