当前位置:   article > 正文

pytorch nearest upsample整数型tensor

pytorch nearest upsample整数型tensor

在用 torch.nn.Upsample 给分割 label 上采样时报错:RuntimeError: "upsample_nearest2d_out_frame" not implemented for 'Long'

参考 [1-3],用 [3] 给出的实现。稍微扩展一下,支持 h、w 用不同的 scale factor,并测试其与 PyTorch 的几个 upsample 类的异同,验证 [3] 的实现用 nearest 插值。

Code

  • linear 要 3D 输入、trilinear 要 5D 输入,故此两种插值法没比。
import torch
import torch.nn as nn


class UpsampleDeterministic(nn.Module):
    """deterministic upsample with `nearest` interpolation"""

    def __init__(self, scale_factor=2):
        """
        Input:
            scale_factor: int or (int, int), ratio to scale (along heigth & width)
        """
        super(UpsampleDeterministic, self).__init__()
        if isinstance(scale_factor, (tuple, list)):
            assert len(scale_factor) == 2
            self.scale_h, self.scale_w = scale_factor
        else:
            self.scale_h = self.scale_w = scale_factor
        assert isinstance(self.scale_h, int) and isinstance(self.scale_w, int)

    def forward(self, x):
        """
        Input:
            x: [n, c, h, w], torch.Tensor
        Output:
            upsampled x': [n, c, h * scale_h, w * scale_w]
        """
        return x[:, :, :, None, :, None].expand(
            -1, -1, -1, self.scale_h, -1, self.scale_w).reshape(
                x.size(0), x.size(1), x.size(2) * self.scale_h, x.size(3) * self.scale_w)


# 随机数据
x = torch.rand(2, 3, 4, 4) # [n, c, h, w]
# [3] 的实现
us_det = UpsampleDeterministic((2, 3))
# pytorch 自带的几种实现
us_list = {mode: nn.Upsample(scale_factor=(2, 3), mode=mode)
           for mode in ('nearest', 'bilinear', 'bicubic')}
# linear: 3D
# trilinear: 5D

y_det = us_det(x)
print(y_det.size())
for us_name, us in us_list.items():
    y = us(x)
    print(us_name, y.size(), (y_det != y).sum())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47

输出:

torch.Size([2, 3, 8, 12])
nearest torch.Size([2, 3, 8, 12]) tensor(0)
bilinear torch.Size([2, 3, 8, 12]) tensor(507)
bicubic torch.Size([2, 3, 8, 12]) tensor(576)
  • 1
  • 2
  • 3
  • 4

可见 [3] 的实现与 nearest 结果一致。

References

  1. 请慎用torch.nn.Upsample
  2. PyTorch中模型的可复现性
  3. Non Deterministic Behaviour even after cudnn.deterministic = True and cudnn.benchmark=False #12207
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/blog/article/detail/55535
推荐阅读
相关标签
  

闽ICP备14008679号