赞
踩
Metrics是torchmetrics库里的度量类基类,本篇大体介绍一下它是如何工作的
由类的定义可以看到,它继承与两个类,一个是我们熟悉的Module,另外一个是ABC,所以它从行为上来说,跟Module一样
class Metric(Module, ABC):
它的行为同Model,所以通过__call__调用。
所以,第一步是Model的__call__
__call__ : Callable[..., Any] = _call_impl
__call__实际是直接调用_call_impl,这里在727行,直接调用self.forward
同pytorch里的module子类一样,重载forward方法。
Metrics的forward函数,内部定义了update函数和compute函数,所以自定义的Metrics需要重载update和compute
这里有个参数compute_on_step,默认是True。默认情况下,update会在上面一行192行调用一次;然后在204行调用一次。 所以在默认情况下会调用两次。
compute方法仅仅在compute_on_step为True时调用,且在此时才有返回值
每一个Metrics的子类都需要重载这两个函数 (默认compute_on_step=True的情况)
下面以一个自定义的Metrics子类为例
内部用于计算的变量通过add_state注册,然后在update里更新,最后在compute里运算出结果
class MyAccuracy(Metric): def __init__(self, dist_sync_on_step=False): super().__init__(dist_sync_on_step=dist_sync_on_step) self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") def update(self, preds: torch.Tensor, target: torch.Tensor): # update metric states self.correct += torch.sum(preds == target) self.total += target.numel() def compute(self): # compute final result return self.correct.float() / self.total metrics = MyAccuracy() preds = torch.tensor([0, 1, 0]) target = torch.tensor([1, 1, 0]) t = metrics(preds, target) print(t)
结果为tensor(0.6667)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。