赞
踩
弄清楚需要评估哪些指标(metrics
)是深度学习的关键。有各种指标,我们就可以评估ML算法的性能。
一般来说,指标(metrics
)的目的是监控和量化训练过程。在一些技术中,如学习率调度learning-rate scheduling
或提前停止early stopping
,指标是用来调度和控制的关键。虽然也可以在这里使用损失loss
,但指标是首选,因为它们能更好地代表训练目标。
与损失相反,指标不需要是可微的(事实上很多都不是),但其中一些是可微的。如果指标本身是可微的,并且它是基于纯PyTorch
实现,那么它也跟损失一样可以用来进行反向传播。
TorchMetrics
对80
多个PyTorch
指标进行了代码实现,且其提供了一个易于使用的API
来创建自定义指标。对于这些已实现的指标,如准确率Accuracy
、召回率Recall
、精确度Precision
、AUROC
、RMSE
、R²
等,可以开箱即用;对于尚未实现的指标,也可以轻松创建自定义指标。主要特点有:
batch
之间自动累积使用pip
:
pip install torchmetrics |
conda
:
1 | conda install -c conda-forge torchmetrics |
与torch.nn
类似,大多数指标都有一个基于类的版本和一个基于函数的版本。
函数版本的指标实现了计算每个度量所需的基本操作。它们是简单的python
函数,接收torch.tensors
作为输入,然后返回torch.tensor
类型的相对应的指标。
一个简单的示例如下:
1 | import torch |
几乎所有的函数版本的指标都有一个相应的基于类的版本,该版本在实际代码中调用对应的函数版本。基于类的指标的特点是具有一个或多个内部状态(类似于PyTorch
模块的参数),使其能够提供额外的功能:
一个示例如下:
1 | import torch |
epoch
之间被重置,并且不应该在训练、验证和测试之间混淆。因此,强烈建议按不同的模式重新初始化指标,如下例所示:
1 | from torchmetrics.classification import Accuracy |
如果想使用一个尚不支持的指标,可以使用TorchMetrics
的API
来实现自定义指标,只需将torchmetrics.Metric
子类化并实现以下方法:
__init__
方法,在这里为每一个指标计算所需的内部状态调用self.add_state
;update
方法,在这里进行更新指标状态所需的逻辑;compute
方法,在这里进行最终的指标计算。以均方根误差(RMSE, Root mean squared error)为例,来看怎样自定义指标。
均方根误差的计算公式为:
为了正确计算RMSE
,我们需要两个指标状态:sum_squared_error
来跟踪目标y^和预测y之间的平方误差;n_observations
来统计我们进行了多少次观测。
1 | from torchmetrics.metric import Metric |
关于实现自定义指标的实际例子和更多信息,看这个页面。
TorchMetrics
支持大多数Python
内置的算术、逻辑和位操作的运算符。
比如:
1 | first_metric = MyFirstMetric() |
a
是指标,
b
可以是指标、张量、整数或浮点数):
a
+
b
)a
&
b
)a
==
b
)floor division
(a
//
b
)a
>=
b
)a
>
b
)a
<=
b
)a
<
b
)a
@
b
)Modulo
,即取余)(a
%
b
)a
*
b
)a
!=
b
)a
|
b
)a
**
b
)a
-
b
)a
/
b
)a
^
b
)abs(a)
)~a
)neg(a)
)pos(a)
)a[0]
)在很多情况下,用多个指标来评估模型的输出是很有好处的。在这种情况下,MetricCollection
类可能会派上用场。它接受一连串的指标,并将这些指标包装成一个可调用的指标类,其接口与任一单一指标相同。
比如:
1 | from torchmetrics import MetricCollection, Accuracy, Precision, Recall |
1 | {'Accuracy': tensor(0.1250), |
MetricCollection
对象的另一个好处是,它将自动尝试通过寻找共享相同基础指标状态的指标组来减少所需的计算。如果找到了这样的指标组,实际上只有其中一个指标被更新,而更新的状态将被广播给组内的其他指标。在上面的例子中,与禁用该功能相比,这将导致计算成本降低2-3倍。然而,这种速度的提高伴随着前期的固定成本,即在第一次更新后必须确定状态组。这个开销可能会大大高于在很低的步数(大约100步)下获得的速度提升,但仍然会导致超过这个步数的整体速度提升。如果事先知道分组,也可以手动设置,以避免动态搜索的这种额外成本。关于这个主题的更多信息,请看该类文档中的
compute_groups
参数。
如果在指标计算中涉及的所有计算都是可微的,那么该指标就支持反向传播。所有的类形式的指标都有一个属性is_differentiable
,它指明该指标是否是可微的。
然而,请注意,一旦缓存的状态从计算图中分离出来,它就不能被反向传播。如果不分离的话就意味着每次更新调用都要存储计算图,这可能会导致内存不足的错误。具体到实际操作时,意味着:
1 | MyMetric.is_differentiable # returns True if metric is differentiable |
如果想直接优化一个指标,它需要支持反向传播(见上节)。然而,如果只是想对使用的指标进行超参数调整,此时如果不确定该指标应该被最大化还是最小化,那么可以参考指标类的higher_is_better
属性:
1 | # returns True because accuracy is optimal when it is maximized |
均方误差MSE
,即mean squared error
,计算公式为:
其中,y是目标值的张量,而y^是预测值的张量。
示例代码:
1 | import torch |
1 | tensor(0.2500) |
均方对数误差MSLE
,即mean squared logarithmic error
,计算公式为:
示例代码:
1 | from torchmetrics import MeanSquaredLogError |
1 | tensor(0.0397) |
平均绝对误差MAE
,即Mean Absolute Error
,计算公式为:
示例代码:
1 | import torch |
1 | tensor(0.5000) |
平均绝对百分比误差MAPE
,即Mean Absolute Percentage Error
,计算公式为:
示例代码:
1 | from torchmetrics import MeanAbsolutePercentageError |
加权平均绝对百分比误差WMAPE
,即Weighted Mean Absolute Percentage Error
,计算公式为:
其与MAPE
的区别可以参考这篇文章。
示例代码:
1 | from torchmetrics import WeightedMeanAbsolutePercentageError |
对称平均绝对百分比误差SMAPE
,即symmetric mean absolute percentage error
,计算公式为:
示例代码:
1 | from torchmetrics import SymmetricMeanAbsolutePercentageError |
余弦相似度,即Cosine Similarity
,其含义可以参考其维基百科:
余弦相似性通过测量两个向量的夹角的余弦值来度量它们之间的相似性。0度角的余弦值是1,而其他任何角度的余弦值都不大于1;并且其最小值是-1。从而两个向量之间的角度的余弦值确定两个向量是否大致指向相同的方向。两个向量有相同的指向时,余弦相似度的值为1;两个向量夹角为90°时,余弦相似度的值为0;两个向量指向完全相反的方向时,余弦相似度的值为-1。这结果是与向量的长度无关的,仅仅与向量的指向方向相关。余弦相似度通常用于正空间,因此给出的值为0到1之间。
注意这上下界对任何维度的向量空间中都适用,而且余弦相似性最常用于高维正空间。例如在信息检索中,每个词项被赋予不同的维度,而一个文档由一个向量表示,其各个维度上的值对应于该词项在文档中出现的频率。余弦相似度因此可以给出两篇文档在其主题方面的相似度。
另外,它通常用于文本挖掘中的文件比较。此外,在数据挖掘领域中,会用到它来度量集群内部的凝聚力。
计算公式为:
cossim(x,y)=x⋅y||x||⋅||y||=∑i=1nxiyi∑i=1nxi2∑i=1nyi2具体计算过程可以参考该文章。
示例代码:
1 | from torchmetrics import CosineSimilarity |
可解释方差,即explained variance
,解释可参考维基百科,计算公式为:
示例代码为:
1 | from torchmetrics import ExplainedVariance |
KL散度,即KL divergence
,解释可见这里,计算公式为:
示例代码为:
1 | from torchmetrics import KLDivergence |
Tweedie偏差分数,即Tweedie Deviance Score
,可参考这里的解释,计算公式为:
示例代码为:
1 | from torchmetrics import TweedieDevianceScore |
Pearson相关性系数,即Pearson Correlation Coefficient
,用于度量两组数据的变量X和Y之间的线性相关的程度,具体解释见这里,计算公式为:
示例代码为:
1 | from torchmetrics import PearsonCorrCoef |
Spearman相关性系数,即Spearman's rank correlation coefficient
,斯皮尔曼相关系数被定义成等级变量之间的皮尔逊相关系数,具体解释见这里,计算公式为:
示例代码为:
1 | from torchmetrics import SpearmanCorrCoef |
决定系数,即R2、Coefficient of determination
,在统计学中用于度量因变量的变异中可由自变量解释部分所占的比例,以此来判断回归模型的解释力。具体解释见这里。计算公式为:
假设一数据集包括y1,…,yn共n个观察值,相对应的模型预测值分别为f1,…,fn。定义残差ei=yi−fi,平均观察值为:
y―=1n∑i=1nyi于是得到总平方和为:
SStot=∑i(yi−y¯)2残差平方和为:
SSres=∑i(yi−y^i)2示例代码为:
1 | from torchmetrics import R2Score |
查看用于分类问题的各种指标之前,先看一下分类问题中指标计算时所需要的输入(包括预测值predictions
和目标值targets
)的形状和数据类型,其中N
是批处理大小,C
是类别数目。
一些背景资料:
Logit
What does the logit value actually mean?
[原创] 用人话解释机器学习中的Logistic Regression(逻辑回归)
【机器学习】逻辑回归(非常详细)
二分类、多分类、多标签分类的基础、原理、算法和工具
多分类模型Accuracy, Precision, Recall和F1-score的超级无敌深入探讨
Type | preds shape | preds dtype | target shape | target dtype |
---|---|---|---|---|
二分类 | (N,) | float | (N,) | 二值,即0或1 |
多分类 | (N,) | int | (N,) | int |
带概率p 或对数几率logit (logit=lnp1−p)的多分类 | (N,C) | float | (N,) | int |
多标签 | (N,…) | float | (N,…) | 二值 |
多维多分类 | (N,…) | int | (N,…) | int |
带概率p 或对数几率logit 的多维多分类 | (N,C,…) | float | (N,…) | int |
以下是一些例子:
1 | # Binary inputs |
multiclass
参数。
StatScores
指标为例看一下怎样使用这个参数。
1 | from torchmetrics.functional import stat_scores |
1 | stat_scores(preds, target, reduce='macro', num_classes=2) |
multiclass=False
来使其处理成二分类问题。
1 | stat_scores(preds, target, reduce='macro', num_classes=1, multiclass=False) |
float
的效果是相同的:
1 | stat_scores(preds.float(), target, reduce='macro', num_classes=1) |
float
),但实际想处理成当前只有2类的多分类问题:
1 | preds = torch.tensor([0.2, 0.7, 0.3]) |
multiclass=True
来实现正确的效果:
1 | stat_scores(preds, target, reduce='macro', num_classes=1) |
混淆矩阵,即Confusion Matrix
,矩阵的每一列代表一个类的实例预测,而每一行表示一个实际的类的实例。之所以如此命名,是因为通过这个矩阵可以方便地看出机器是否将两个不同的类混淆了(比如说把一个类错当成了另一个)。 具体解释见维基百科(中文版示例有些小错误)。
示例代码有:
(1)二分类:
1 | from torchmetrics import ConfusionMatrix |
1 | target = torch.tensor([2, 1, 0, 0]) |
1 | target = torch.tensor([[0, 1, 0], [1, 0, 1]]) |
准确率,即Accuracy
,表明了分类正确的概率,计算公式为:
如果用混淆矩阵中的数据来表达就是:
Accuracy=TP+TNTP+TN+FP+FN对于带概率或对数几率的多分类和多维多分类数据,参数top_k
可以将该指标泛化为为Top-K
准确度指标:对于每个样本,考虑前K个概率或对数几率最高的类别来判断是否找到了正确的标签。
对于多标签和多维多分类输入,该指标默认计算 “全局 “准确度,即单独计算所有标签或子样本。这可以通过设置subset_accuracy=True
来改变为子集准确性(这需要样本中的所有标签或子样本都被正确预测)。
示例代码有:
(1)二分类:
1 | import torch |
1 | target = torch.tensor([0, 1, 2]) |
精度,即Precision
,计算公式为:
示例代码为:
1 | from torchmetrics import Precision |
AUC
,即Area Under the Curve (AUC)
,torchmetrics提供了使用梯形公式trapezoidal rule
计算某条曲线下的面积的方法,计算公式为:
注意,在离散的点形成的向量上进行梯形公式计算,其实际是每两点之间就计算一次,详见torch.trapezoid
函数。
示例代码为:
1 | from torchmetrics.functional import auc |
ROC
,即Receiver Operating Characteristic
,接收者操作特征曲线,是一种坐标图式的分析工具,具体解释可参见维基百科。ROC
空间将伪阳性率(FPR
、在所有实际为阴性的样本中,被错误地判断为阳性之比率)定义为 X
轴,真阳性率(TPR
、在所有实际为阳性的样本中,被正确地判断为阳性之比率)定义为 Y
轴。
将同一模型每个阈值 的 (FPR, TPR) 坐标都画在ROC空间里,就成为特定模型的ROC曲线。
示例代码有:
(1)二分类:
1 | from torchmetrics import ROC |
1 | pred = torch.tensor([[0.75, 0.05, 0.05, 0.05], |
1 | pred = torch.tensor([[0.8191, 0.3680, 0.1138], |
AUC ROC
,即Area under the Curve of ROC
,即ROC
曲线下方的面积,具体解释可参见维基百科,简单说:AUC
值越大的分类器,正确率越高。
示例代码有:
(1)二分类:
1 | from torchmetrics import AUROC |
1 | preds = torch.tensor([[0.90, 0.05, 0.05], |
召回率,即Recall
,计算公式为:
示例代码有:
1 | from torchmetrics import Recall |
F1分数,即F1 score
,兼顾了分类模型的精确率和召回率,它是精确率和召回率的调和平均数:
示例代码为:
1 | import torch |
对于缺陷检测这一特定场景,对上述指标做一个总结。
缺陷检测,关注点就在于“缺陷”,即缺陷是positive
。
(1)召回率
即在一堆有缺陷的样品中,到底有多少样品能被成功检测出缺陷,分母就是“实际有缺陷的样品”,分子就是TP
,即:
(2)漏检率
即在一堆有缺陷的样品中,有多少样品被判定为没有缺陷,分母是“实际有缺陷的样品”,分子就是FN
,即:
即,
漏检率漏检率=1−Recall(3)误检率
即在一堆没有缺陷的样品中,有多少样品被判定为有缺陷,分母是“实际没有缺陷的样品”,分子就是FP
,即:
误检率,又被称为过杀率。
(4)精度
即在一堆被诊断为有缺陷的样品中,到底有多少样品是真的有缺陷,分母就是“所有被诊断为有缺陷的样品”,分子就是TP
,即:
Dice
系数,是一种集合相似度度量函数,通常用于计算两个样本的相似度,具体解释可以参见这里和这里。
基于分类问题,计算公式为:
示例代码为:
1 | import torch |
mAP
,即mean Average Precision
,可翻译为“全类平均精度”,是将所有类别检测的平均正确率(AP
)进行综合加权平均而得到的。而AP
是PR
曲线(精度-召回率曲线)下面积。具体解释可见这里、这里和那里。
示例代码:
1 | import torch |
</div>
<div class="reward-container">
<div style="display: inline-block;">
<img src="/images/wechat_reward.png" alt="Xin-Bo Qi(亓欣波) WeChat Pay">
<p>WeChat Pay</p>
</div>
<footer class="post-footer"> <div class="post-tags"> <a href="/tags/PyTorch/" rel="tag"># PyTorch</a> </div> <div class="post-nav"> <div class="post-nav-item"> <a href="/2022/08/11/neovim_nvchad/" rel="prev" title="Neovim预配置库NvChad探索"> <i class="fa fa-chevron-left"></i> Neovim预配置库NvChad探索 </a></div> <div class="post-nav-item"> <a href="/2022/12/17/dtale/" rel="next" title="Pandas可视化数据分析工具D-Tale详解"> Pandas可视化数据分析工具D-Tale详解 <i class="fa fa-chevron-right"></i> </a></div> </div> </footer>
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。