当前位置:   article > 正文

【pytorch】Metrics的工作原理_pytorch metric

pytorch metric

Metrics是torchmetrics库里的度量类基类,本篇大体介绍一下它是如何工作的

它也是一个Model

由类的定义可以看到,它继承与两个类,一个是我们熟悉的Module,另外一个是ABC,所以它从行为上来说,跟Module一样

class Metric(Module, ABC):
  • 1

第一步 __call__

它的行为同Model,所以通过__call__调用。

所以,第一步是Model的__call__

__call__ : Callable[..., Any] = _call_impl
  • 1

__call__实际是直接调用_call_impl,这里在727行,直接调用self.forward
调用forward

第二步 forward

pytorch里的module子类一样,重载forward方法。
Metrics的forward函数,内部定义了update函数和compute函数,所以自定义的Metrics需要重载update和compute

forward

这里有个参数compute_on_step,默认是True。默认情况下,update会在上面一行192行调用一次;然后在204行调用一次。 所以在默认情况下会调用两次

compute方法仅仅在compute_on_step为True时调用,且在此时才有返回值

第三步 update 和 compute

每一个Metrics的子类都需要重载这两个函数 (默认compute_on_step=True的情况)
下面以一个自定义的Metrics子类为例

内部用于计算的变量通过add_state注册,然后在update里更新,最后在compute里运算出结果

class MyAccuracy(Metric):
    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # update metric states
        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        # compute final result
        return self.correct.float() / self.total


metrics = MyAccuracy()

preds  = torch.tensor([0, 1, 0])
target = torch.tensor([1, 1, 0])

t = metrics(preds, target)
print(t)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

结果为tensor(0.6667)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/寸_铁/article/detail/974277
推荐阅读
相关标签
  

闽ICP备14008679号