当前位置:   article > 正文

【Tensorflow】多分类问题的Precision、Recall和F1计算及Tensorflow实现_机器学习多分类求f1代码

机器学习多分类求f1代码

一、二分类问题的Precision、Recall、F1

网络上关于Precision、Recall和F1的介绍有很多,因此这里只作简单回顾。
在二分类问题中,根据真实类别和预测类别的组合可以分为四中情况,分别是TP(True Positive)、FP(False Positive)、TN(True Negative)、FN(False Negative)。如下图:
在这里插入图片描述
那么Precision表示所有被预测为Positive样本中,真正的Positive比例,即 P r e c i s i o n = T P T P + F P Precision=\frac{TP}{TP+FP} Precision=TP+FPTP

而Recall表示所有Positive样本中,被正确预测出来的比例,即 R e c a l l = T P T P + F N Recall=\frac{TP}{TP+FN} Recall=TP+FNTP

F1则是Precision和Recall的调和平均数,即 F 1 = 2 P R P + R F1 = \frac{2PR}{P+R} F1=P+R2PR

二、多分类问题的Precision、Recall和F1

显然,在多分类中没有了TP、FN、FP、TN的定义了。那么如何计算Precision、Recall和F1呢?下面介绍两种方式。


2.1 方法一:micro average

该方法的核心思想是将多个类别分组,真正关心的类别分为Positive组,其他分为Negative组。例如在命名实体识别任务中,将所有实体的类别分为Positive组,而其他标签’O’单独分为Negative组。下面以多分类的混淆矩阵(confuse matrix)为基础计算Precision、Recall和F1值。
下图是多分类问题的混淆矩阵:
在这里插入图片描述
其中,假设类别1和类别2是Positive组,类别3和类别4是Negative组。
那么真正例(TP)可以定义为Positive组预测正确的样本数,那么下图红色框矩阵的对角线和(黄色单元格的和)即为真正例(TP)
在这里插入图片描述
TP=3+2=5

TP+FP表示所有被预测为Positive的样本数,显然下图红框中所有数的和就是TP+FP
在这里插入图片描述
TP+FP=3+3+1+0+4+2+3+2=18

TP+FN表示所有真实为Positive的样本数,显然下图蓝色框中所有数的和就是TP+FN
在这里插入图片描述
TP+FN=3+4+5+1+3+2+4+0=22

因此

P r e c i s i o n = T P T P + F P = 5 18 Precision = \frac{TP}{TP+FP} = \frac{5}{18} Precision=TP+FPTP=185

R e c a l l = T P T P + F N = 5 22 Recall = \frac{TP}{TP+FN} = \frac{5}{22} Recall=TP+FNTP=225

F 1 = 2 P R P + R = 1 4 F1 = \frac{2PR}{P+R} = \frac{1}{4} F1=P+R2PR=41


2.2 方法二:macro average

该方法的核心思想是,单独计算各个正例的Precison、Recall和F1,然后对所有正例的Precison、Recall和F1求平均值。
这里我们仍然将类别1和类别2作为Positive组,然后按上面的mincro分别计算类别1和类别2的Precison、Recall和F1。

2.2.1 计算类别1的Precision、Recall和F1

TP = 3

TP + FP = 3 + 3 + 1 + 0 = 7

TP + FN = 3 + 4 + 5 + 1 = 13

P r e c i s i o n 1 = T P T P + F P = 3 7 Precision_1 = \frac{TP}{TP+FP} = \frac{3}{7} Precision1=TP+FPTP=73

R e c a l l 1 = T P T P + F N = 3 13 Recall_1 = \frac{TP}{TP+FN} = \frac{3}{13} Recall1=TP+FNTP=133

F 1 1 = 2 P R P + R = 3 10 F1_1 = \frac{2PR}{P+R} = \frac{3}{10} F11=P+R2PR=103

2.2.2 计算类别2的Precision、Recall和F1

TP = 2

TP + FP = 4 + 2 + 3 + 2 = 11

TP + FN = 3 + 2 + 4 + 0 = 9

P r e c i s i o n 2 = T P T P + F P = 2 11 Precision_2 = \frac{TP}{TP+FP} = \frac{2}{11} Precision2=TP+FPTP=112

R e c a l l 2 = T P T P + F N = 2 9 Recall_2 = \frac{TP}{TP+FN} = \frac{2}{9} Recall2=TP+FNTP=92

F 1 2 = 2 P R P + R = 1 5 F1_2 = \frac{2PR}{P+R} = \frac{1}{5} F12=P+R2PR=51

2.2.3 计算平均值

P r e c i s i o n = P r e c i s i o n 1 + P r e c i s i o n 2 2 = 3 / 7 + 2 / 11 2 = 0.305 Precision = \frac{Precision_1+Precision_2}{2} = \frac{3/7 + 2/11}{2} = 0.305 Precision=2Precision1+Precision2=23/7+2/11=0.305

R e c a l l = R e c a l l 1 + R e c a l l 2 2 = 3 / 13 + 2 / 9 2 = 0.226 Recall = \frac{Recall_1+Recall_2}{2} = \frac{3/13 + 2/9}{2} = 0.226 Recall=2Recall1+Recall2=23/13+2/9=0.226

F 1 = F 1 1 + F 1 2 2 = 3 / 10 + 1 / 5 2 = 0.25 F1 = \frac{F1_1+F1_2}{2} = \frac{3/10 + 1/5}{2} = 0.25 F1=2F11+F12=23/10+1/5=0.25

三、Tensorflow实现

import tensorflow as tf
import numpy as np
from tensorflow.python.ops.metrics_impl import _streaming_confusion_matrix
  • 1
  • 2
  • 3

1、计算混淆矩阵(confuse matrix)

labels = tf.convert_to_tensor(np.array([0,0,1,1,2,2]))
predictions = tf.convert_to_tensor(np.array([0,1,2,0,0,1]))
num_classes = 3  # 3个类别0,1,2
# cm是上一个batch的混淆矩阵,op是更新完当前batch的
cm,op = _streaming_confusion_matrix(labels,predictions,num_classes)
  • 1
  • 2
  • 3
  • 4
  • 5
with tf.Session() as sess:
    sess.run(tf.local_variables_initializer()) # 初始化所有的local variables
    print(sess.run(cm))
    print(sess.run(op))
cm = op # 令cm是更新后的混淆矩阵
  • 1
  • 2
  • 3
  • 4
  • 5
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
[[1. 1. 0.]
 [1. 0. 1.]
 [1. 1. 0.]]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

2、计算micro average的precision、recall和f1

def safe_div(numerator, denominator):
    """安全除,分母为0时返回0"""
    numerator, denominator = tf.cast(numerator,dtype=tf.float64), tf.cast(denominator,dtype=tf.float64)
    zeros = tf.zeros_like(numerator, dtype=numerator.dtype) # 创建全0Tensor
    denominator_is_zero = tf.equal(denominator, zeros) # 判断denominator是否为零
    return tf.where(denominator_is_zero, zeros, numerator / denominator) # 如果分母为0,则返回零
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
def pr_re_f1(cm, pos_indices):
    num_classes = cm.shape[0]
    neg_indices = [i for i in range(num_classes) if i not in pos_indices]
    cm_mask = np.ones([num_classes, num_classes])
    cm_mask[neg_indices, neg_indices] = 0 # 将负样本预测正确的位置清零零
    diag_sum = tf.reduce_sum(tf.diag_part(cm * cm_mask)) # 正样本预测正确的数量

    cm_mask = np.ones([num_classes, num_classes])
    cm_mask[:, neg_indices] = 0 # 将负样本对应的列清零
    tot_pred = tf.reduce_sum(cm * cm_mask) # 所有被预测为正的样本数量

    cm_mask = np.ones([num_classes, num_classes])
    cm_mask[neg_indices, :] = 0 # 将负样本对应的行清零
    tot_gold = tf.reduce_sum(cm * cm_mask) # 所有正样本的数量

    pr = safe_div(diag_sum, tot_pred)
    re = safe_div(diag_sum, tot_gold)
    f1 = safe_div(2. * pr * re, pr + re)
    
    return pr, re, f1
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
pr,re,f1 = pr_re_f1(cm,[0,1])
with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    print(sess.run(pr))
    print(sess.run(re))
    print(sess.run(f1))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
0.2
0.25
0.22222222222222224
  • 1
  • 2
  • 3

3、计算macro average的precision、recall和f1

precisions, recalls, f1s, n_golds = [], [], [], []
pos_indices = [0,1] # 正例
# 计算每个正例的precison,recall和f1
for idx in pos_indices:
    pr, re, f1 = pr_re_f1(cm, [idx])
    precisions.append(pr)
    recalls.append(re)
    f1s.append(f1)
    cm_mask = np.zeros([num_classes, num_classes])
    cm_mask[idx, :] = 1
    n_golds.append(tf.to_float(tf.reduce_sum(cm * cm_mask)))
pr = tf.reduce_mean(precisions)
re = tf.reduce_mean(recalls)
f1 = tf.reduce_mean(f1s)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    print(sess.run(pr))
    print(sess.run(re))
    print(sess.run(f1))
  • 1
  • 2
  • 3
  • 4
  • 5
0.16666666666666666
0.25
0.2
  • 1
  • 2
  • 3
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/142895
推荐阅读
相关标签
  

闽ICP备14008679号