TorchMetrics can be installed with either pip or conda:

pip install torchmetrics
conda install -c conda-forge torchmetrics
Welcome to TorchMetrics — PyTorch-Metrics 1.3.1 documentation
https://lightning.ai/docs/torchmetrics/stable/pages/quickstart.html
Training is done in mini-batches, and TorchMetrics works the same way: after each batch's forward pass, we hand the targets Y and the predictions Y_PRED to the torchmetrics metric object, which computes the metric for that batch and stores it internally (this stored accumulation is called its state).
When all batches are done (i.e., one training epoch is complete), we can read the final result from the metric object, computed over all the batches. Every metric object inherits from the Metric class, which has four key methods:
forward(preds, target): updates the metric state and returns the metric computed on the current batch.
update(preds, target): the same as forward, but returns no result; the outcome is only stored in the state. If you do not need the metric computed on the current batch, prefer this method: it is faster because it skips computing the final result.
compute(): returns the final result computed over all batches.
In other words, forward is effectively update plus compute for the current batch.
reset(): resets the state so the metric is ready for the next validation phase.
Note: during the current training batch, once you have the model's output, call forward or update (update is recommended). When all batches are done, call compute to get the final result. Finally, call reset to clear the metric state when a validation epoch ends or a new training epoch begins.
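A quick sanity check of the "forward = update + compute" claim above, as a minimal sketch with made-up random tensors (the Accuracy setup mirrors the example that follows):

import torch
import torchmetrics

metric = torchmetrics.classification.Accuracy(task="multiclass", num_classes=5)
preds = torch.randn(8, 5).softmax(dim=-1)
target = torch.randint(5, (8,))

batch_acc = metric(preds, target)  # forward: returns the metric for this batch and updates the state
metric.reset()
metric.update(preds, target)       # update: accumulates the state but returns nothing
assert torch.allclose(batch_acc, metric.compute())  # over a single batch, forward's value equals compute()
metric.reset()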
import torch
# import our library
import torchmetrics

# initialize metric
metric = torchmetrics.classification.Accuracy(task="multiclass", num_classes=5)

n_batches = 10
for i in range(n_batches):
    # simulate a classification problem
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))
    # metric on current batch
    acc = metric(preds, target)
    print(f"Accuracy on batch {i}: {acc}")

# metric on all batches using custom accumulation
acc = metric.compute()
print(f"Accuracy on all data: {acc}")

# Resetting internal state such that metric is ready for new data
metric.reset()
import torch
import torchmetrics

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = YourModel().to(device)
# the metric must live on the same device as the predictions and targets
metric = torchmetrics.classification.Accuracy(task="multiclass", num_classes=10).to(device)

for batch_idx, (data, target) in enumerate(val_dataloader):
    data, target = data.to(device), target.to(device)
    preds = model(data)
    # accumulate into the metric state; update() returns nothing
    metric.update(preds, target)

# metric on all batches using custom accumulation
val_acc = metric.compute()
print(f"Accuracy on all data: {val_acc}")

# Resetting internal state such that metric is ready for new data
metric.reset()
TorchMetrics provides MetricCollection, which wraps multiple metrics into a single callable class whose interface is the same as the basic usage above.
import torch
from torchmetrics import MetricCollection, Accuracy, Precision, Recall

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = YourModel().to(device)
# collection of all validation metrics
metric_collection = MetricCollection({
    'acc': Accuracy(task="multiclass", num_classes=10),
    'prec': Precision(task="multiclass", num_classes=10, average='macro'),
    'rec': Recall(task="multiclass", num_classes=10, average='macro')
}).to(device)

for batch_idx, (data, target) in enumerate(val_dataloader):
    data, target = data.to(device), target.to(device)
    preds = model(data)
    batch_metrics = metric_collection.forward(preds, target)
    print(f"Metrics on batch {batch_idx}: {batch_metrics}")

val_metrics = metric_collection.compute()
print(f"Metrics on all data: {val_metrics}")
metric_collection.reset()
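As a side note, MetricCollection can be cloned with a prefix to reuse the same set of metrics for different phases; a small sketch (the 'train_'/'val_' names are illustrative):

from torchmetrics import MetricCollection, Accuracy, Precision

base_metrics = MetricCollection({
    'acc': Accuracy(task="multiclass", num_classes=10),
    'prec': Precision(task="multiclass", num_classes=10, average='macro'),
})
# keys in the returned dicts become 'train_acc'/'train_prec' and 'val_acc'/'val_prec'
train_metrics = base_metrics.clone(prefix='train_')
val_metrics = base_metrics.clone(prefix='val_')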
To implement a custom metric, we only need to subclass Metric and implement the update and compute methods; in addition, each piece of internal state must be registered in __init__ with self.add_state(state_name, default).
import torch
from torchmetrics import Metric

class MyAccuracy(Metric):
    def __init__(self):
        super().__init__()
        # to count the correct predictions
        self.add_state('correct', default=torch.tensor(0), dist_reduce_fx="sum")
        # to count the total predictions
        self.add_state('total', default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds, target):
        # update correct predictions count
        self.correct += torch.sum(preds == target)
        # update total count; numel() returns the total number of elements
        self.total += target.numel()

    def compute(self):
        # final computation
        return self.correct / self.total
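A quick usage sketch with made-up tensors. Note that this metric compares class labels element-wise, so if the model outputs logits you would take an argmax first (an assumption about how MyAccuracy is meant to be fed):

metric = MyAccuracy()

# if the model produces logits, convert them to label predictions first
logits = torch.randn(4, 10)
preds = logits.argmax(dim=-1)
target = torch.randint(10, (4,))

metric.update(preds, target)
print(metric.compute())  # fraction of positions where preds == target
metric.reset()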
Code walkthrough (the official explanation is clearer): https://lightning.ai/docs/torchmetrics/stable/pages/implement.html#implement
The dist_reduce_fx argument to add_state is used to specify how the metric states should be reduced between batches in distributed settings. In this case we use "sum" to sum the metric states across batches. A couple of built-in options are available: "sum", "mean", "cat", "min" or "max", but a custom reduction is also supported.
In update we do not return anything but instead update the metric states in-place.
In compute, when running in distributed mode, the states would have been synced before the compute method is called. Thus self.correct and self.total will contain the sum of the metric states across all processes.
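To illustrate the "cat" option mentioned above, here is a sketch of a hypothetical metric whose state is a list of tensors, so the per-process states are concatenated rather than summed (the MedianAbsoluteError name and logic are assumptions for the example):

import torch
from torchmetrics import Metric
from torchmetrics.utilities import dim_zero_cat

class MedianAbsoluteError(Metric):
    def __init__(self):
        super().__init__()
        # a list state: each update appends a tensor, and "cat"
        # concatenates the states across processes in distributed mode
        self.add_state('errors', default=[], dist_reduce_fx='cat')

    def update(self, preds, target):
        self.errors.append(torch.abs(preds - target))

    def compute(self):
        # dim_zero_cat handles both the local list of tensors and the
        # already-concatenated tensor produced after a distributed sync
        return dim_zero_cat(self.errors).median()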