当前位置:   article > 正文

Pytorch用ConvTranspose2d替代Upsample

Pytorch用ConvTranspose2d替代Upsample

本文介绍了Pytorch如何用ConvTranspose2d算子等价替代Upsample算子。

背景介绍:

  • 某些AI加速卡上Upsample算子的性能不够高,是否能用别的算子临时替代呢
  • 可以手动推断出ConvTranspose2d 的权值,使其与Upsample等价算子
  • 也可以搭建一个模型,输入分别给到ConvTranspose2d和Upsample算子,使它们之间的L1Loss最小
  • 当网络收敛后,对ConvTranspose2d的权值做舍入处理
  • 最后用上面的权值初始化ConvTranspose2d

网络结构

import onnx
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import numpy as np

class UpsampleModel(torch.nn.Module):
    def __init__(self):
        super(UpsampleModel, self).__init__()
        self.up=nn.Upsample(scale_factor=2, mode='nearest')
        self.deconv1=nn.ConvTranspose2d(3,3,2,2,groups=1,bias=False)     
    def forward(self, x):
        out0=self.up(x)
        out1=self.deconv1(x)
        return out0,out1
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

训练ConvTranspose2d的权值

def train():
    input_shape = (1, 3, 224, 224)
    model = UpsampleModel()
    criterion = nn.L1Loss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(2100):
        running_loss = 0.0
        for i in range(100):
            input_data = torch.randn(input_shape)
            optimizer.zero_grad()
            out0,out1=model(input_data)
            loss = criterion(out0,out1)
            loss.backward() 
            optimizer.step() 
            running_loss += loss.item()        
        avg_loss=running_loss / 100
        print('[%d] loss: %f' % (epoch + 1,avg_loss ))
        running_loss = 0.0
        if avg_loss<1e-4:            
            w=model.deconv1.weight.detach().numpy()
            #print(w)       
            print(np.round(w))    
            break
train()            
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

结果

[[[[ 1.  1.]
   [ 1.  1.]]
  [[-0. -0.]
   [-0. -0.]]
  [[ 0. -0.]
   [-0. -0.]]]
 [[[ 0.  0.]
   [ 0. -0.]]
  [[ 1.  1.]
   [ 1.  1.]]
  [[ 0.  0.]
   [ 0. -0.]]]
 [[[ 0. -0.]
   [-0. -0.]]
  [[ 0. -0.]
   [ 0.  0.]]
  [[ 1.  1.]
   [ 1.  1.]]]]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

用上面生成的权值验证

def val():
    w=np.array(
        [[[[ 1. , 1.],
           [ 1. , 1.]],
          [[ 0. , 0.],
           [ 0. , 0.]],
          [[ 0. , 0.],
           [ 0. , 0.]]],
         [[[ 0. , 0.],
           [ 0. , 0.]],
          [[ 1. , 1.],
           [ 1. , 1.]],
          [[ 0. , 0.],
           [ 0. , 0.]]],
         [[[ 0. , 0.],
           [ 0. , 0.]],
          [[ 0. , 0.],
           [ 0. , 0.]],
          [[ 1. , 1.],
           [ 1. , 1.]]]]   
            )
    input_shape = (1, 3, 224, 224)
    model = UpsampleModel().eval()
    model.deconv1.weight=torch.nn.Parameter(torch.from_numpy(w.astype(np.float32))) #设置权值
    input_data = torch.randn(input_shape)
    out0,out1=model(input_data)
    out0=out0.detach().numpy().reshape(-1)
    out1=out1.detach().numpy().reshape(-1)
    ret=(out0==out1).all()
val()    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30

输出

True
  • 1
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/149420?site
推荐阅读
相关标签
  

闽ICP备14008679号