赞
踩
class Res_com(nn.Module):
    """Stack of residual blocks plus a 1x1 conv -- the FAULTY version.

    NOTE(review): this is the counter-example analysed in the text that
    follows; do not copy it. Known defects:
      * ``super().__init__()`` is never called, so nn.Module's internal state
        is never initialised; calling an instance is expected to fail with
        AttributeError.
      * Every ``Res_block`` and the final ``nn.Conv2d`` are constructed inside
        ``forward`` and bound only to local names, so they are never
        registered as submodules: ``parameters()`` does not contain them and
        they are re-created with fresh weights on every call.
      * ``x[1]`` indexes element 1 along dim 0 of ``x`` (a tensor slice), not
        a channel count; presumably ``x.shape[1]`` was intended -- confirm.
      * The bare ``raise`` executes outside any ``except`` block, so Python
        raises ``RuntimeError`` ("No active exception to reraise").
    """

    def __init__(self, n_com=3, b_com=6, d_com=2, com_disable=False):
        def make(n, b, d):
            # f1: widths 2**b, ..., 2**(b+n-1), followed by the same
            # sequence mirrored back down (2n entries in total).
            f1 = [2 ** (b + i) for i in range(n)] + [2 ** (b + n - 1 - i) for i in range(n)]
            # f2: f1 scaled by 2**d ...
            f2 = [i * (2 ** d) for i in f1]
            # ... with the middle entry removed ...
            del f2[len(f2) // 2]
            # ... and a final width of 32 (16 when f1[0] <= 32) appended,
            # so len(f2) == len(f1) == 2n again.
            f2_last = 32 if f1[0] > 32 else 16
            f2.append(f2_last)
            return f1, f2
        self.f1_com, self.f2_com = make(n_com, b_com, d_com)
        self.n_com = n_com
        self.com_disable = com_disable

    def forward(self,x):
        if self.com_disable:
            print("No implementation for com_disable")
            # Bare raise with no active exception -> RuntimeError.
            raise
        else:
            # NOTE(review): x[1] is a slice of x along dim 0, not its channel
            # count.
            print("in_channels for class Res_com is :{}".format(x[1]))
            for i in range(self.n_com * 2):
                # BUG (the subject of the article): a new, unregistered
                # Res_block with fresh weights is built on every call.
                # Res_block is defined elsewhere; assumed signature
                # Res_block(in_channels, [c1, c2, c3]).
                x = Res_block(x[1], [self.f1_com[i], self.f1_com[i], self.f2_com[i]])(x)
            # Same problem: this conv layer is recreated on each forward pass.
            out = nn.Conv2d(in_channels=x[1],out_channels=12,kernel_size=1,stride=1,padding=0)(x)
            return out
出错分析
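网络中带有可学习参数的层（如 nn.Conv2d、nn.BatchNorm2d 等）都是在 forward 中临时创建的局部变量，没有作为子模块注册到网络上。因此 model.parameters() 中看不到这些权重，程序会认为该网络没有可以更新的权重参数从而出错（例如优化器报错 “optimizer got an empty parameter list”）；即使不报错，这些层也会在每次前向传播时被重新随机初始化，网络根本无法训练。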
解决方案
正确的做法是：在 __init__ 中先调用 super().__init__()，再将所有带可学习参数的层都以成员变量的形式（如 self.multi_res_block、self.conv1）定义在 __init__ 方法中；forward 中只负责调用这些已经定义好的层。
class Res_com(nn.Module):
    """Stack of ``2 * n_com`` residual blocks followed by a 1x1 convolution.

    Every layer with learnable parameters is created once in ``__init__`` and
    stored as an attribute, so nn.Module registers it as a submodule and
    ``parameters()``, ``state_dict()`` and ``.to(device)`` all see it.

    Args:
        n_com: Number of rising stages; the stack contains ``2 * n_com``
            blocks. Must be >= 1.
        b_com: Base exponent; the ``f1`` widths are powers of two starting
            at ``2 ** b_com``.
        d_com: The ``f2`` widths are the ``f1`` widths scaled by
            ``2 ** d_com`` (see ``_make_filters`` for the exact rule).
        com_disable: Not implemented; passing ``True`` raises.
        in_channels: Channel count of the input tensor.
        out_channels: Channels produced by the final 1x1 convolution
            (defaults to 12, the value previously hard-coded).

    Raises:
        NotImplementedError: If ``com_disable`` is true.
        ValueError: If ``n_com < 1``.
    """

    def __init__(self, n_com=3, b_com=6, d_com=2, com_disable=False,
                 in_channels=3, out_channels=12):
        # Must run before any submodule is assigned: nn.Module.__setattr__
        # relies on the registries created here.
        super().__init__()
        if com_disable:
            # A bare `raise` outside an except block only yields RuntimeError.
            raise NotImplementedError("No implementation for com_disable")

        self.f1_com, self.f2_com = self._make_filters(n_com, b_com, d_com)
        self.n_com = n_com
        self.com_disable = com_disable

        # Chain the blocks: block i consumes the previous block's output width,
        # assumed to be the last entry of its filter list (f2_com[i - 1]).
        # Res_block is defined elsewhere; assumed signature
        # Res_block(in_channels, [c1, c2, c3]).
        layers = []
        prev_channels = in_channels
        for f1, f2 in zip(self.f1_com, self.f2_com):
            layers.append(Res_block(prev_channels, [f1, f1, f2]))
            prev_channels = f2
        self.multi_res_block = nn.Sequential(*layers)

        # 1x1 projection whose input width is the last block's output width
        # (the original referenced an undefined `x` here and called the layer).
        self.conv1 = nn.Conv2d(in_channels=prev_channels,
                               out_channels=out_channels,
                               kernel_size=1, stride=1, padding=0)

    @staticmethod
    def _make_filters(n, b, d):
        """Return the ``(f1, f2)`` filter-width lists for ``2 * n`` blocks.

        ``f1`` rises as ``2**b, ..., 2**(b+n-1)`` and then mirrors back down.
        ``f2`` is ``f1`` scaled by ``2**d`` with its middle entry removed and a
        final width appended (32 if ``f1[0] > 32`` else 16), so both lists
        have length ``2 * n``.

        >>> Res_com._make_filters(3, 6, 2)
        ([64, 128, 256, 256, 128, 64], [256, 512, 1024, 512, 256, 32])

        Raises:
            ValueError: If ``n < 1`` (the lists would be empty).
        """
        if n < 1:
            raise ValueError("n_com must be >= 1, got {}".format(n))
        rising = [2 ** (b + i) for i in range(n)]
        f1 = rising + rising[::-1]
        f2 = [width * 2 ** d for width in f1]
        del f2[len(f2) // 2]
        f2.append(32 if f1[0] > 32 else 16)
        return f1, f2

    def forward(self, x):
        """Run ``x`` through the residual stack and the 1x1 projection.

        Args:
            x: Input tensor, assumed NCHW with ``x.shape[1] == in_channels``
                (TODO confirm against callers).
        """
        # Debug output kept from the original; x.shape[1] is the channel count,
        # whereas the original printed x[1], a slice of the batch.
        print("in_channels for class Res_com is :{}".format(x.shape[1]))
        out = self.multi_res_block(x)
        return self.conv1(out)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。