
03 PyTorch validation metrics tool: torchmetrics


Installing torchmetrics

pip install torchmetrics
conda install -c conda-forge torchmetrics

Official documentation

Welcome to TorchMetrics — PyTorch-Metrics 1.3.1 documentation

https://lightning.ai/docs/torchmetrics/stable/pages/quickstart.html

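Basic functions

Basic workflow

Training is normally done in mini-batches, and TorchMetrics follows the same pattern: after the forward pass of a batch, pass the targets Y and the predictions Y_PRED to the torchmetrics metric object, which computes the batch-level metric and accumulates it in its internal state.

Once all batches are done (that is, one epoch is complete), the final result over all batches can be read from the metric object. Every metric object inherits from the Metric class, which provides four key methods: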

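metric.forward(preds, target)

Updates the metric state and returns the metric computed on the current batch.

metric.update(preds, target)

Same as forward, but does not return a result; it only accumulates the batch into the state. Prefer this method when the per-batch value is not needed, because skipping that computation makes it faster.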

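metric.compute()

Returns the final result computed over all batches seen so far.

In other words, forward is roughly update plus a compute restricted to the current batch: the accumulated state still covers every batch, but the returned value describes only the batch just passed in.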

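metric.reset()

Resets the state so the metric is ready for the next validation phase.

Note: within each batch, once the model output is available, call forward or update (update is recommended). After all batches are done, call compute to get the final result. Finally, call reset at the end of a validation epoch or before starting a new training epoch.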

Usage guide

Single metric

Calling the API

    import torch
    # import our library
    import torchmetrics

    # initialize metric
    metric = torchmetrics.classification.Accuracy(task="multiclass", num_classes=5)

    n_batches = 10
    for i in range(n_batches):
        # simulate a classification problem
        preds = torch.randn(10, 5).softmax(dim=-1)
        target = torch.randint(5, (10,))
        # metric on current batch
        acc = metric(preds, target)
        print(f"Accuracy on batch {i}: {acc}")

    # metric on all batches using custom accumulation
    acc = metric.compute()
    print(f"Accuracy on all data: {acc}")

    # Resetting internal state such that metric ready for new data
    metric.reset()

When training a deep learning model, the following workflow can serve as a reference:

    import torch
    import torchmetrics

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # YourModel and val_dataloader are placeholders for your own model and data
    model = YourModel().to(device)
    # recent torchmetrics versions require `task`; num_classes=10 is an assumed value
    metric = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(device)

    model.eval()
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(val_dataloader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            # forward() updates the state and returns the metric on the current batch
            batch_acc = metric(output, target)
            print(f"Accuracy on batch {batch_idx}: {batch_acc}")

    # metric on all batches using custom accumulation
    val_acc = metric.compute()
    print(f"Accuracy on all data: {val_acc}")

    # Resetting internal state such that metric is ready for new data
    metric.reset()

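Multiple metrics

TorchMetrics provides MetricCollection, which wraps several metrics into a single callable object whose interface is the same as in the basic usage above.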

    import torch
    from torchmetrics import MetricCollection, Accuracy, Precision, Recall

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # YourModel and val_dataloader are placeholders for your own model and data
    model = YourModel().to(device)

    # collection of all validation metrics (task="multiclass" with 10 classes assumed)
    metric_collection = MetricCollection({
        'acc': Accuracy(task='multiclass', num_classes=10),
        'prec': Precision(task='multiclass', num_classes=10, average='macro'),
        'rec': Recall(task='multiclass', num_classes=10, average='macro')
    }).to(device)

    for batch_idx, (data, target) in enumerate(val_dataloader):
        data, target = data.to(device), target.to(device)
        output = model(data)
        batch_metrics = metric_collection.forward(output, target)
        print(f"Metrics on batch {batch_idx}: {batch_metrics}")

    val_metrics = metric_collection.compute()
    print(f"Metrics on all data: {val_metrics}")
    metric_collection.reset()

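Creating a custom metric

All that is needed is to subclass Metric and implement the update and compute methods. In addition, the constructor must call self.add_state(state_name, default) to register each piece of internal state.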

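  1. __init__: call self.add_state for every internal state the metric computation needs
  2. update: all the logic needed to update the metric states
  3. compute: the final metric computation from the accumulated states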

    import torch
    from torchmetrics import Metric

    class MyAccuracy(Metric):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # to count the correct predictions
            self.add_state('correct', default=torch.tensor(0), dist_reduce_fx="sum")
            # to count the total predictions
            self.add_state('total', default=torch.tensor(0), dist_reduce_fx="sum")

        def update(self, preds, target):
            # preds and target are expected to be label tensors of the same shape
            # update correct predictions count
            self.correct += torch.sum(preds == target)
            # update total count, numel() returns the total number of elements
            self.total += target.numel()

        def compute(self):
            # final computation
            return self.correct.float() / self.total

Code walkthrough (the official explanation is clearer): https://lightning.ai/docs/torchmetrics/stable/pages/implement.html#implement

  • The dist_reduce_fx argument to add_state is used to specify how the metric states should be reduced between batches in distributed settings. In this case we use "sum" to sum the metric states across batches. A couple of built-in options are available: "sum", "mean", "cat", "min" or "max", but a custom reduction is also supported.

  • In update we do not return anything but instead update the metric states in-place.

  • In compute when running in distributed mode, the states would have been synced before the compute method is called. Thus self.correct and self.total will contain the sum of the metric states across all processes.

References

使用Torchmetrics快速进行验证指标的计算 - 知乎 (Zhihu article, "Quickly computing validation metrics with Torchmetrics")
