
Programming Quick Notes (27), PyTorch: fixing "'Res_rec' object has no attribute '_parameters'" when building a network on nn.Module


1. The error: 'Res_rec' object has no attribute '_parameters'


2. Code before the fix

class Res_com(nn.Module):
    def __init__(self, n_com=3,  b_com=6, d_com=2, com_disable=False):
        def make(n, b, d):
            f1 = [2 ** (b + i) for i in range(n)] + [2 ** (b + n - 1 - i) for i in range(n)]
            f2 = [i * (2 ** d) for i in f1]
            del f2[len(f2) // 2]
            f2_last = 32 if f1[0] > 32 else 16
            f2.append(f2_last)
            return f1, f2

        self.f1_com, self.f2_com = make(n_com, b_com, d_com)
        self.n_com = n_com
        self.com_disable = com_disable
    def forward(self,x):
        if self.com_disable:
            print("No implementation for com_disable")
            raise
        else:
            print("in_channels for class Res_com is :{}".format(x[1]))
            for i in range(self.n_com * 2):
                x = Res_block(x[1], [self.f1_com[i], self.f1_com[i], self.f2_com[i]])(x)
            out = nn.Conv2d(in_channels=x[1],out_channels=12,kernel_size=1,stride=1,padding=0)(x)
            return out

Error analysis

All the layers that carry learnable parameters, such as nn.Conv2d and nn.BatchNorm2d, were created inside forward. Nothing is ever registered on the module itself, so PyTorch sees a network with no weights to update. In addition, __init__ never calls super().__init__(), so nn.Module's internal bookkeeping dicts (including _parameters) are never created; the first access to them, for example via model.parameters(), raises the AttributeError shown above.
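The failure mode can be reproduced with a minimal sketch (the class name Broken and the toy layer are illustrative, not from the original code):

```python
import torch
import torch.nn as nn

# Minimal repro: __init__ neither calls super().__init__() nor registers
# any layers, so nn.Module's internal `_parameters` dict is never created.
class Broken(nn.Module):
    def __init__(self, n_com=3):
        # super().__init__() is missing on purpose
        self.n_com = n_com

    def forward(self, x):
        # fresh, unregistered weights are created on every call
        return nn.Conv2d(x.shape[1], 8, kernel_size=1)(x)

m = Broken()  # instantiation succeeds; plain attributes are allowed
try:
    list(m.parameters())
    msg = ""
except AttributeError as e:
    msg = str(e)

print(msg)  # e.g. 'Broken' object has no attribute '_parameters'
```

Even if super().__init__() were added here, parameters() would still return an empty list, because the Conv2d built inside forward is never assigned to an attribute and therefore never registered.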

Solution

The correct approach: call super().__init__() at the top of __init__, then define every layer whose parameters need updating as a member attribute of the class inside __init__.
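The pattern looks like this in miniature (Res_block is replaced by plain layers purely for illustration; the class name Registered is made up):

```python
import torch
import torch.nn as nn

# Sketch of the corrected pattern: super().__init__() first, then every
# parameter-bearing layer registered as an attribute in __init__.
class Registered(nn.Module):
    def __init__(self, in_channels=3):
        super().__init__()  # creates _parameters, _modules, _buffers, ...
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, 8, kernel_size=3, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(),
        )
        self.conv1 = nn.Conv2d(8, 12, kernel_size=1)

    def forward(self, x):
        return self.conv1(self.block(x))

m = Registered()
n_params = len(list(m.parameters()))
print(n_params)  # 6: weight+bias for each of Conv2d, BatchNorm2d, Conv2d
out = m(torch.randn(1, 3, 16, 16))
print(out.shape)  # torch.Size([1, 12, 16, 16])
```

Because the layers are assigned to self after super().__init__(), nn.Module's __setattr__ registers them in _modules, and parameters() can traverse them for the optimizer.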

3. Code after the fix

class Res_com(nn.Module):
    def __init__(self, n_com=3, b_com=6, d_com=2, com_disable=False, in_channels=3):
        super().__init__()  # initialize nn.Module internals (_parameters, _modules, ...)

        def make(n, b, d):
            f1 = [2 ** (b + i) for i in range(n)] + [2 ** (b + n - 1 - i) for i in range(n)]
            f2 = [i * (2 ** d) for i in f1]
            del f2[len(f2) // 2]
            f2_last = 32 if f1[0] > 32 else 16
            f2.append(f2_last)
            return f1, f2

        self.f1_com, self.f2_com = make(n_com, b_com, d_com)
        self.n_com = n_com
        self.com_disable = com_disable

        if self.com_disable:
            raise NotImplementedError("No implementation for com_disable")

        # all parameter-bearing layers are registered here, not in forward
        layers = []
        for i in range(self.n_com * 2):
            if i == 0:
                layers.append(Res_block(in_channels, [self.f1_com[i], self.f1_com[i], self.f2_com[i]]))
            else:
                layers.append(Res_block(self.f2_com[i - 1], [self.f1_com[i], self.f1_com[i], self.f2_com[i]]))
        self.multi_res_block = nn.Sequential(*layers)
        self.conv1 = nn.Conv2d(in_channels=self.f2_com[-1], out_channels=12,
                               kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        print("in_channels for class Res_com is: {}".format(x.size(1)))
        out = self.multi_res_block(x)
        out = self.conv1(out)
        return out