赞
踩
在用 torch.nn.Upsample 给分割 label 上采样时报错:RuntimeError: "upsample_nearest2d_out_frame" not implemented for 'Long'
。
参考 [1-3],用 [3] 给出的实现。稍微扩展一下,支持 h、w 用不同的 scale factor,并测试其与 PyTorch 的几个 upsample 类的异同,验证 [3] 的实现用 nearest 插值。
linear
要 3D 输入、trilinear
要 5D 输入,故此两种插值法没比。import torch import torch.nn as nn class UpsampleDeterministic(nn.Module): """deterministic upsample with `nearest` interpolation""" def __init__(self, scale_factor=2): """ Input: scale_factor: int or (int, int), ratio to scale (along heigth & width) """ super(UpsampleDeterministic, self).__init__() if isinstance(scale_factor, (tuple, list)): assert len(scale_factor) == 2 self.scale_h, self.scale_w = scale_factor else: self.scale_h = self.scale_w = scale_factor assert isinstance(self.scale_h, int) and isinstance(self.scale_w, int) def forward(self, x): """ Input: x: [n, c, h, w], torch.Tensor Output: upsampled x': [n, c, h * scale_h, w * scale_w] """ return x[:, :, :, None, :, None].expand( -1, -1, -1, self.scale_h, -1, self.scale_w).reshape( x.size(0), x.size(1), x.size(2) * self.scale_h, x.size(3) * self.scale_w) # 随机数据 x = torch.rand(2, 3, 4, 4) # [n, c, h, w] # [3] 的实现 us_det = UpsampleDeterministic((2, 3)) # pytorch 自带的几种实现 us_list = {mode: nn.Upsample(scale_factor=(2, 3), mode=mode) for mode in ('nearest', 'bilinear', 'bicubic')} # linear: 3D # trilinear: 5D y_det = us_det(x) print(y_det.size()) for us_name, us in us_list.items(): y = us(x) print(us_name, y.size(), (y_det != y).sum())
输出:
torch.Size([2, 3, 8, 12])
nearest torch.Size([2, 3, 8, 12]) tensor(0)
bilinear torch.Size([2, 3, 8, 12]) tensor(507)
bicubic torch.Size([2, 3, 8, 12]) tensor(576)
可见 [3] 的实现与 nearest 结果一致。
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
赞
踩
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。