当前位置:   article > 正文

SENet 模型代码解读

senet代码

网络架构图:(原文此处为网络架构图片,抓取时已缺失)

准备模型:

  1. model_name = 'se_resnext101_32x4d'
  2. model = MODEL( num_classes= 500 , senet154_weight = WEIGHT_PATH, multi_scale = True, learn_region=True)
  3. model = torch.nn.DataParallel(model)
  4. vgg16 = model
  5. vgg16.load_state_dict(torch.load('./model/ISIAfood500.pth'))

 SENet 模型代码:

  1. """这段代码定义了一个名为senet154的函数,它使用SENet模型来进行图像分类。
  2. """
  3. def senet154(num_classes=1000, pretrained='imagenet'):
  4. model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16,
  5. dropout_p=0.2, num_classes=num_classes)
  6. if pretrained is not None:
  7. settings = pretrained_settings['senet154'][pretrained]
  8. initialize_pretrained_model(model, num_classes, settings)
  9. return model
  10. """SENet是一种卷积神经网络,它使用SEBottleneck块来增强特征表示。这个函数
  11. 使用了一个包含四个元素的列表来定义SENet的结构,其中每个元素表示一个阶段,
  12. 每个阶段包含多个SEBottleneck块。groups参数指定了SEBottleneck块中的卷积分组数,
  13. reduction参数指定了SE块中的通道缩减比例。如果pretrained参数不为None,则会使用
  14. 预训练的权重来初始化模型。预训练的权重存储在pretrainedsettings字典中,
  15. 可以通过指定pretrained参数来选择不同的预训练权重。最后,函数返回SENet模型。"""
  16. class SENet(nn.Module):
  17. def __init__(self, block, layers, groups, reduction, dropout_p=0.2,
  18. inplanes=128, input_3x3=True, downsample_kernel_size=3,
  19. super(SENet, self).__init__()
  20. self.inplanes = inplanes
  21. if input_3x3:
  22. layer0_modules = [
  23. ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1,
  24. bias=False)),
  25. ('bn1', nn.BatchNorm2d(64)),
  26. ('relu1', nn.ReLU(inplace=True)), # 从这 224 -> 112 stride =2
  27. ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1,
  28. bias=False)),
  29. ('bn2', nn.BatchNorm2d(64)),
  30. ('relu2', nn.ReLU(inplace=True)),
  31. ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1,
  32. bias=False)),
  33. ('bn3', nn.BatchNorm2d(inplanes)),
  34. ('relu3', nn.ReLU(inplace=True)), # 输出的是 128 * 112* 112
  35. ]
  36. else:
  37. layer0_modules = [
  38. ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2,
  39. padding=3, bias=False)),
  40. ('bn1', nn.BatchNorm2d(inplanes)),
  41. ('relu1', nn.ReLU(inplace=True)),
  42. ]
  43. # To preserve compatibility with Caffe weights `ceil_mode=True`
  44. # is used instead of `padding=1`.
  45. layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2,
  46. ceil_mode=True))) # 这个 就 变成了 112 -> 56
  47. self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) # output 128 * 56 * 56
  48. self.layer1 = self._make_layer(
  49. block,
  50. planes=64,
  51. blocks=layers[0],
  52. groups=groups,
  53. reduction=reduction,
  54. downsample_kernel_size=1,
  55. downsample_padding=0 # layer 1 不会降尺寸。 但是会改变通道。 所以输出是256 * 56 *56
  56. )
  57. self.layer2 = self._make_layer(
  58. block,
  59. planes=128,
  60. blocks=layers[1],
  61. stride=2,
  62. groups=groups,
  63. reduction=reduction,
  64. downsample_kernel_size=downsample_kernel_size,
  65. downsample_padding=downsample_padding # layer 2 降尺寸。 因为stride =2 要进行降采样。 输出就是 512 * 28 * 28
  66. )
  67. self.layer3 = self._make_layer(
  68. block,
  69. planes=256,
  70. blocks=layers[2],
  71. stride=2,
  72. groups=groups,
  73. reduction=reduction,
  74. downsample_kernel_size=downsample_kernel_size,
  75. downsample_padding=downsample_padding # layer 3 降尺寸。 因为stride =2 要进行降采样。 输出就是 1024 * 14 * 14
  76. )
  77. self.layer4 = self._make_layer(
  78. block,
  79. planes=512,
  80. blocks=layers[3],
  81. stride=2,
  82. groups=groups,
  83. reduction=reduction,
  84. downsample_kernel_size=downsample_kernel_size, # layer 4 降尺寸。 因为stride =2 要进行降采样。 输出就是 2048 * 7 * 7
  85. downsample_padding=downsample_padding
  86. )
  87. self.avg_pool = nn.AvgPool2d(7, stride=1)
  88. self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
  89. self.last_linear = nn.Linear(512 * block.expansion, num_classes)
  90. def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
  91. downsample_kernel_size=1, downsample_padding=0):
  92. downsample = None
  93. if stride != 1 or self.inplanes != planes * block.expansion:
  94. downsample = nn.Sequential(
  95. nn.Conv2d(self.inplanes, planes * block.expansion,
  96. kernel_size=downsample_kernel_size, stride=stride,
  97. padding=downsample_padding, bias=False),
  98. nn.BatchNorm2d(planes * block.expansion),
  99. )
  100. layers = []
  101. layers.append(block(self.inplanes, planes, groups, reduction, stride,
  102. downsample))
  103. self.inplanes = planes * block.expansion
  104. for i in range(1, blocks):
  105. layers.append(block(self.inplanes, planes, groups, reduction))
  106. return nn.Sequential(*layers)
  107. def features(self, x):
  108. x = self.layer0(x)
  109. x = self.layer1(x)
  110. x = self.layer2(x)
  111. x = self.layer3(x)
  112. x = self.layer4(x)
  113. return x
  114. def logits(self, x):
  115. x = self.avg_pool(x)
  116. if self.dropout is not None:
  117. x = self.dropout(x)
  118. x = x.view(x.size(0), -1)
  119. x = self.last_linear(x)
  120. return x
  121. def forward(self, x):
  122. x = self.features(x)
  123. x = self.logits(x)
  124. return x

模型构建代码:

  1. class ConvBlock(nn.Module):
  2. """基本卷积块。
  3. 卷积 + 批量归一化 + relu。
  4. Args:
  5. in_c (int): 输入通道数。
  6. out_c (int): 输出通道数。
  7. k (int or tuple): 卷积核大小。
  8. s (int or tuple): 步长。
  9. p (int or tuple): 填充。
  10. """
  11. def __init__(self, in_c, out_c, k, s=1, p=0):
  12. super(ConvBlock, self).__init__()
  13. self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
  14. self.bn = nn.BatchNorm2d(out_c)
  15. def forward(self, x):
  16. return F.relu(self.bn(self.conv(x)))
  17. # 定义InceptionA模块
  18. class InceptionA(nn.Module):
  19. def __init__(self, in_channels, out_channels):
  20. super(InceptionA, self).__init__()
  21. mid_channels = out_channels // 4
  22. # 第一个分支
  23. self.stream1 = nn.Sequential(
  24. ConvBlock(in_channels, mid_channels, 1), # 1x1卷积
  25. ConvBlock(mid_channels, mid_channels, 3, p=1), # 3x3卷积
  26. )
  27. # 第二个分支
  28. self.stream2 = nn.Sequential(
  29. ConvBlock(in_channels, mid_channels, 1), # 1x1卷积
  30. ConvBlock(mid_channels, mid_channels, 3, p=1), # 3x3卷积
  31. )
  32. # 第三个分支
  33. self.stream3 = nn.Sequential(
  34. ConvBlock(in_channels, mid_channels, 1), # 1x1卷积
  35. ConvBlock(mid_channels, mid_channels, 3, p=1), # 3x3卷积
  36. )
  37. # 第四个分支
  38. self.stream4 = nn.Sequential(
  39. nn.AvgPool2d(3, stride=1, padding=1), # 平均池化
  40. ConvBlock(in_channels, mid_channels, 1), # 1x1卷积
  41. )
  42. def forward(self, x):
  43. # 将四个分支的输出拼接在一起
  44. s1 = self.stream1(x)
  45. s2 = self.stream2(x)
  46. s3 = self.stream3(x)
  47. s4 = self.stream4(x)
  48. y = torch.cat([s1, s2, s3, s4], dim=1)
  49. return y
  50. # 定义InceptionB模块
  51. class InceptionB(nn.Module):
  52. def __init__(self, in_channels, out_channels):
  53. super(InceptionB, self).__init__()
  54. mid_channels = out_channels // 4
  55. # 第一个分支
  56. self.stream1 = nn.Sequential(
  57. ConvBlock(in_channels, mid_channels, 1), # 1x1卷积
  58. ConvBlock(mid_channels, mid_channels, 3, s=2, p=1), # 3x3卷积
  59. )
  60. # 第二个分支
  61. self.stream2 = nn.Sequential(
  62. ConvBlock(in_channels, mid_channels, 1), # 1x1卷积
  63. ConvBlock(mid_channels, mid_channels, 3, p=1), # 3x3卷积
  64. ConvBlock(mid_channels, mid_channels, 3, s=2, p=1), # 3x3卷积
  65. )
  66. # 第三个分支
  67. self.stream3 = nn.Sequential(
  68. nn.MaxPool2d(3, stride=2, padding=1), # 最大池化
  69. ConvBlock(in_channels, mid_channels*2, 1), # 1x1卷积
  70. )
  71. def forward(self, x):
  72. # 分别对三个分支进行计算
  73. s1 = self.stream1(x)
  74. s2 = self.stream2(x)
  75. s3 = self.stream3(x)
  76. # 将三个分支的结果进行拼接
  77. y = torch.cat([s1, s2, s3], dim=1)
  78. return y
  79. class SpatialAttn(nn.Module):
  80. """Spatial Attention (Sec. 3.1.I.1)"""
  81. def __init__(self):
  82. super(SpatialAttn, self).__init__()
  83. self.conv1 = ConvBlock(1, 1, 3, s=2, p=1)
  84. self.conv2 = ConvBlock(1, 1, 1)
  85. def forward(self, x):
  86. # global cross-channel averaging
  87. x = x.mean(1, keepdim=True) # 由hwc 变为 hw1
  88. # 3-by-3 conv
  89. h = x.size(2)
  90. x = self.conv1(x)
  91. # bilinear resizing
  92. x = F.upsample(x, (h,h), mode='bilinear', align_corners=True)
  93. # scaling conv
  94. x = self.conv2(x)
  95. return x
  96. ## 返回的是h*w*1 的 soft map
  97. class ChannelAttn(nn.Module):
  98. """通道注意力机制"""
  99. def __init__(self, in_channels, reduction_rate=16):
  100. super(ChannelAttn, self).__init__()
  101. assert in_channels%reduction_rate == 0
  102. self.conv1 = ConvBlock(in_channels, in_channels // reduction_rate, 1)
  103. self.conv2 = ConvBlock(in_channels // reduction_rate, in_channels, 1)
  104. def forward(self, x):
  105. # 压缩操作(全局平均池化)
  106. x = F.avg_pool2d(x, x.size()[2:])
  107. # 激励操作(2个卷积层)
  108. x = self.conv1(x)
  109. x = self.conv2(x)
  110. return x
  111. '''
  112. 空间和通道上的attention 融合
  113. 就是空间和通道上的attention做一个矩阵乘法
  114. '''
  115. class SoftAttn(nn.Module):
  116. """Soft Attention (Sec. 3.1.I)
  117. Aim: Spatial Attention + Channel Attention
  118. Output: attention maps with shape identical to input.
  119. """
  120. def __init__(self, in_channels):
  121. super(SoftAttn, self).__init__()
  122. self.spatial_attn = SpatialAttn()
  123. self.channel_attn = ChannelAttn(in_channels)
  124. self.conv = ConvBlock(in_channels, in_channels, 1)
  125. def forward(self, x):
  126. y_spatial = self.spatial_attn(x) # 空间注意力输出
  127. y_channel = self.channel_attn(x) # 通道注意力输出
  128. y = y_spatial * y_channel # 空间注意力和通道注意力相乘
  129. y = torch.sigmoid(self.conv(y)) # 卷积块输出
  130. return y
  131. '''
  132. 输出的是STN 需要的theta
  133. '''
  134. class HardAttn(nn.Module):
  135. """Hard Attention (Sec. 3.1.II)"""
  136. def __init__(self, in_channels):
  137. super(HardAttn, self).__init__()
  138. self.fc = nn.Linear(in_channels, 4*2)
  139. self.init_params()
  140. def init_params(self):
  141. self.fc.weight.data.zero_()
  142. # 初始化 参数
  143. # if x_t = 0 the performance is very low
  144. self.fc.bias.data.copy_(torch.tensor([0.3, -0.3, 0.3, 0.3, -0.3, 0.3, -0.3, -0.3], dtype=torch.float))
  145. def forward(self, x):
  146. # squeeze operation (global average pooling)
  147. x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), x.size(1))
  148. # predict transformation parameters
  149. theta = torch.tanh(self.fc(x))
  150. theta = theta.view(-1, 4, 2)
  151. return theta
  152. # 返回的是 2T T为区域数量。 因为尺度会固定。 所以只要学位移的值
  153. class HarmAttn(nn.Module):
  154. """Harmonious Attention (Sec. 3.1)"""
  155. # 定义一个名为HarmAttn的类,继承自nn.Module类,表示这是一个神经网络模型
  156. def __init__(self, in_channels):
  157. super(HarmAttn, self).__init__()
  158. # 调用父类的构造函数,初始化神经网络模型
  159. self.soft_attn = SoftAttn(in_channels)
  160. # 定义一个名为soft_attn的属性,其值为SoftAttn(in_channels),表示该属性是一个软注意力机制
  161. self.hard_attn = HardAttn(in_channels)
  162. # 定义一个名为hard_attn的属性,其值为HardAttn(in_channels),表示该属性是一个硬注意力机制
  163. def forward(self, x):
  164. # 定义一个名为forward的函数,表示前向传播过程
  165. y_soft_attn = self.soft_attn(x)
  166. # 定义一个名为y_soft_attn的变量,其值为self.soft_attn(x),表示使用软注意力机制对输入x进行处理
  167. theta = self.hard_attn(x)
  168. # 定义一个名为theta的变量,其值为self.hard_attn(x),表示使用硬注意力机制对输入x进行处理
  169. return y_soft_attn, theta
  170. class MODEL(nn.Module):
  171. '''
  172. cvper2020的主模型
  173. '''
  174. def __init__(self, num_classes, senet154_weight, nchannels=[256,512,1024,2048], multi_scale = False ,learn_region=True, use_gpu=True):
  175. super(MODEL,self).__init__()
  176. self.learn_region=learn_region
  177. self.use_gpu = use_gpu
  178. self.conv = ConvBlock(3, 32, 3, s=2, p=1)
  179. self.senet154_weight = senet154_weight
  180. self.multi_scale = multi_scale
  181. self.num_classes = num_classes
  182. # 构建SEnet154
  183. senet154_ = senet154(num_classes=1000, pretrained=None)
  184. senet154_.load_state_dict(torch.load(self.senet154_weight))
  185. self.extract_feature = senet154_.layer0
  186. #全局backbone
  187. self.global_layer1 = senet154_.layer1
  188. self.global_layer2 = senet154_.layer2
  189. self.global_layer3 = senet154_.layer3
  190. self.global_layer4 = senet154_.layer4
  191. self.classifier_global =nn.Sequential(
  192. nn.Linear(2048*2, 2048), # 将4个区域 融合成一个 需要加上batchnorma1d, 和 relu
  193. nn.BatchNorm1d(2048),
  194. nn.ReLU(),
  195. nn.Dropout(0.2),
  196. nn.Linear(2048, num_classes),
  197. )
  198. if self.multi_scale:
  199. self.global_fc = nn.Sequential(
  200. nn.Linear(2048+512+1024, 2048), # 将4个区域 融合成一个 需要加上batchnorma1d, 和 relu
  201. nn.BatchNorm1d(2048),
  202. nn.ReLU(),
  203. )
  204. self.global_out = nn.Linear(2048,num_classes) # global 分类
  205. else:
  206. self.global_out = nn.Linear(2048,num_classes) # global 分类
  207. self.ha2 = HarmAttn(nchannels[1])
  208. self.ha3 = HarmAttn(nchannels[2])
  209. self.ha4 = HarmAttn(nchannels[3])
  210. self.dropout = nn.Dropout(0.2) # 分类层之前使用dropout
  211. if self.learn_region:
  212. self.init_scale_factors()
  213. self.local_conv1 = InceptionB(nchannels[1], nchannels[1])
  214. self.local_conv2 = InceptionB(nchannels[2], nchannels[2])
  215. self.local_conv3 = InceptionB(nchannels[3], nchannels[3])
  216. self.local_fc = nn.Sequential(
  217. nn.Linear(2048+512+1024, 2048), # 将4个区域 融合成一个 需要加上batchnorma1d, 和 relu
  218. nn.BatchNorm1d(2048),
  219. nn.ReLU(),
  220. )
  221. self.classifier_local = nn.Linear(2048,num_classes)
  222. def init_scale_factors(self):
  223. # 初始化四个区域的缩放因子(s_w,s_h)
  224. # s_w和s_h是固定的。
  225. self.scale_factors = []
  226. self.scale_factors.append(torch.tensor([[0.5, 0], [0, 0.5]], dtype=torch.float))
  227. self.scale_factors.append(torch.tensor([[0.5, 0], [0, 0.5]], dtype=torch.float))
  228. self.scale_factors.append(torch.tensor([[0.5, 0], [0, 0.5]], dtype=torch.float))
  229. self.scale_factors.append(torch.tensor([[0.5, 0], [0, 0.5]], dtype=torch.float))
  230. def stn(self, x, theta):
  231. """执行空间变换
  232. x: (batch, channel, height, width)
  233. theta: (batch, 2, 3)
  234. """
  235. grid = F.affine_grid(theta, x.size())
  236. x = F.grid_sample(x, grid)
  237. return x
  238. def transform_theta(self, theta_i, region_idx):
  239. """将theta转换为包括(s_w,s_h)的形式,结果为(batch,2,3)"""
  240. scale_factors = self.scale_factors[region_idx]
  241. theta = torch.zeros(theta_i.size(0), 2, 3)
  242. theta[:,:,:2] = scale_factors
  243. theta[:,:,-1] = theta_i
  244. if self.use_gpu: theta = theta.cuda()
  245. return theta
  246. def forward(self, x):
  247. batch_size = x.size()[0] # 获取批量大小
  248. x = self.extract_feature(x) # 输出形状为128 * 56 *56 senet154第0层layer 提取特征
  249. # =================layer 1 ===============
  250. # 全局分支
  251. x1 = self.global_layer1(x) # 输出形状为256*56*56
  252. #============layer 2================
  253. #全局分支
  254. x2 = self.global_layer2(x1) # x2是512*28*28
  255. x2_attn, x2_theta = self.ha2(x2)
  256. x2_out = x2 * x2_attn
  257. if self.multi_scale:
  258. # attention global layer1 avg pooling
  259. x2_avg = F
  260. x2_avg = F.adaptive_avg_pool2d(x2_out, (1, 1)).view(x2_out.size(0), -1) #512 向量
  261. # local branch
  262. if self.learn_region:
  263. x2_local_list = []
  264. for region_idx in range(4):
  265. x2_theta_i = x2_theta[:,region_idx,:]
  266. x2_theta_i = self.transform_theta(x2_theta_i, region_idx)
  267. x2_trans_i = self.stn(x2, x2_theta_i) #256*56*26
  268. x2_trans_i = F.upsample(x2_trans_i, (56, 56), mode='bilinear', align_corners=True)
  269. x2_local_i = x2_trans_i
  270. x2_local_i = self.local_conv1(x2_local_i) # 512*28*28
  271. x2_local_list.append(x2_local_i)
  272. #============layer 3================
  273. #global branch
  274. x3 = self.global_layer3(x2_out) # x3 is 1024*14*14
  275. # print('layer3 output')
  276. # print(x3.size())
  277. x3_attn, x3_theta = self.ha3(x3)
  278. x3_out = x3 * x3_attn
  279. if self.multi_scale:
  280. # attention global layer1 avg pooling
  281. x3_avg = F.adaptive_avg_pool2d(x3_out, (1, 1)).view(x3_out.size(0), -1) #1024 向量
  282. # local branch
  283. if self.learn_region:
  284. x3_local_list = []
  285. for region_idx in range(4):
  286. x3_theta_i = x3_theta[:,region_idx,:]
  287. x3_theta_i = self.transform_theta(x3_theta_i, region_idx)
  288. x3_trans_i = self.stn(x3, x3_theta_i) #512*28*28
  289. x3_trans_i = F.upsample(x3_trans_i, (28, 28), mode='bilinear', align_corners=True)
  290. x3_local_i = x3_trans_i
  291. x3_local_i = self.local_conv2(x3_local_i) # 1024*14*14
  292. x3_local_list.append(x3_local_i)
  293. #============layer 4================
  294. #global branch
  295. x4 = self.global_layer4(x3_out) # 2048*7*7
  296. x4_attn, x4_theta = self.ha4(x4)
  297. x4_out = x4 * x4_attn
  298. # local branch
  299. if self.learn_region:
  300. x4_local_list = []
  301. for region_idx in range(4):
  302. x4_theta_i = x4_theta[:,region_idx,:]
  303. x4_theta_i = self.transform_theta(x4_theta_i, region_idx)
  304. x4_trans_i = self.stn(x4, x4_theta_i) #1024*14*14
  305. x4_trans_i = F.upsample(x4_trans_i, (14,14), mode='bilinear', align_corners=True)
  306. x4_local_i = x4_trans_i
  307. x4_local_i = self.local_conv3(x4_local_i) # 2048*7*7
  308. x4_local_list.append(x4_local_i)
  309. # ============== Feature generation ==============
  310. # global branch
  311. x4_avg = F.avg_pool2d(x4_out, x4_out.size()[2:]).view(x4_out.size(0), -1) #全局pooling 2048 之前已经relu过了
  312. if self.multi_scale:
  313. multi_scale_feature = torch.cat([x2_avg, x3_avg, x4_avg],1)
  314. global_fc = self.global_fc(multi_scale_feature)
  315. global_out = self.global_out(self.dropout(global_fc))
  316. else:
  317. global_out = self.global_out(x4_avg) # 2048 -> num_classes
  318. if self.learn_region:
  319. x_local_list = []
  320. local_512 = torch.randn(batch_size, 4, 512).cuda()
  321. local_1024 = torch.randn(batch_size, 4, 1024).cuda()
  322. local_2048 = torch.randn(batch_size, 4, 2048).cuda()
  323. for region_idx in range(4):
  324. x2_local_i = x2_local_list[region_idx]
  325. x2_local_i = F.avg_pool2d(x2_local_i, x2_local_i.size()[2:]).view(x2_local_i.size(0), -1) #每个local 都全局pooling
  326. local_512[:,region_idx] = x2_local_i
  327. x3_local_i = x3_local_list[region_idx]
  328. x3_local_i = F.avg_pool2d(x3_local_i, x3_local_i.size()[2:]).view(x3_local_i.size(0), -1) #每个local 都全局pooling
  329. local_1024[:,region_idx] = x3_local_i
  330. x4_local_i = x4_local_list[region_idx]
  331. x4_local_i = F.avg_pool2d(x4_local_i, x4_local_i.size()[2:]).view(x4_local_i.size(0), -1) #每个local 都全局pooling
  332. local_2048[:,region_idx] = x4_local_i
  333. local_512_maxpooing = local_512.max(1)[0]
  334. local_1024_maxpooing = local_1024.max(1)[0]
  335. local_2048_maxpooing = local_2048.max(1)[0]
  336. local_concate = torch.cat([local_512_maxpooing, local_1024_maxpooing, local_2048_maxpooing], 1)
  337. local_fc = self.local_fc(local_concate)
  338. local_out = self.classifier_local(self.dropout(local_fc))
  339. if self.multi_scale:
  340. out = torch.cat([global_fc,local_fc],1)
  341. else:
  342. out = torch.cat([x4_avg, local_512_maxpooing, local_1024_maxpooing, local_2048_maxpooing], 1) # global 和 local 一起做拼接 2048*2
  343. out = self.classifier_global(out)
  344. if self.learn_region:
  345. return out, global_out,local_out
  346. else:
  347. return global_out

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/362724?site
推荐阅读
相关标签
  

闽ICP备14008679号