为此,我们提出了首个从序列到序列预测角度引入自注意力机制的医学图像分割框架TransUNet。为了弥补Transformer带来的特征分辨率损失,TransUNet采用了一种混合的CNN-Transformer架构,同时利用CNN特征中详细的高分辨率空间信息和Transformer编码的全局上下文。受U型架构设计的启发,Transformer Encoder输出的自注意力特征图被上采样,并与编码路径上通过跳跃连接(shortcut)传来的不同高分辨率CNN特征相结合,以实现精确定位。我们证明,这样的设计使我们的框架既保留了Transformer的优点,又有利于医学图像分割。实验结果表明,与之前基于CNN的自注意力方法相比,我们基于Transformer的架构能够更好地利用自注意力。此外,我们观察到更密集地融合低层特征通常会带来更高的分割精度。大量实验证明了该方法在各种医学图像分割任务中的优越性。

给定图像 x ∈ R^{H×W×C},其空间分辨率为H×W,通道数为C。我们的目标是预测相应的大小为H×W的像素级标签图。最常见的方法是直接训练CNN(例如UNet),首先将图像编码为高级特征表示,然后将其解码回完整的空间分辨率。与现有方法不同,我们的方法通过使用Transformer在编码器设计中引入了自注意力机制。




2)将三通道图像[B,3,224,224]通过ResNetV2进行下采样,将图像编码为高级特征表示,并创建一个feature列表保存每次下采样后的特征图。ResNetV2最终输出[B,16C,14,14],其中C = 64 × width_factor。feature列表包含三种尺寸的特征图,分别是[B,C,112,112]、[B,4C,56,56]、[B,8C,28,28]。



3)Patch Embedding。利用卷积核大小为patch_size × patch_size、步长s=patch_size的卷积将特征图切成一个个patch,得到[B,h_c,14/patch_size,14/patch_size],展平后为[B,h_c,n_p],其中h_c为hidden_size,n_p=num_patches=(14/patch_size)**2。然后转置为[B,n_p,h_c],并加上position embedding=[1,n_p,h_c],最终得到输出特征图[B,n_p,h_c]。

4)Transformer Encoder。这步没什么说的,就是大家非常熟悉的Transformer,计算自注意力。输出[B,n_p,h_c]






  1. # coding=utf-8
  2. from __future__ import absolute_import
  3. from __future__ import division
  4. from __future__ import print_function
  5. import copy
  6. import logging
  7. import math
  8. from os.path import join as pjoin
  9. import torch
  10. import torch.nn as nn
  11. import numpy as np
  12. from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
  13. from torch.nn.modules.utils import _pair
  14. from scipy import ndimage
  15. from . import vit_seg_configs as configs
  16. from .vit_seg_modeling_resnet_skip import ResNetV2
  17. logger = logging.getLogger(__name__)
  18. ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
  19. ATTENTION_K = "MultiHeadDotProductAttention_1/key"
  20. ATTENTION_V = "MultiHeadDotProductAttention_1/value"
  21. ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
  22. FC_0 = "MlpBlock_3/Dense_0"
  23. FC_1 = "MlpBlock_3/Dense_1"
  24. ATTENTION_NORM = "LayerNorm_0"
  25. MLP_NORM = "LayerNorm_2"
  26. def np2th(weights, conv=False):
  27. """Possibly convert HWIO to OIHW."""
  28. if conv:
  29. weights = weights.transpose([3, 2, 0, 1])
  30. return torch.from_numpy(weights)
  31. # sigmoid激活函数
  32. def swish(x):
  33. return x * torch.sigmoid(x)
  34. # 激活函数的选择
  35. ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
  36. class Attention(nn.Module):
  37. def __init__(self, config, vis):
  38. super(Attention, self).__init__()
  39. self.vis = vis
  40. self.num_attention_heads = config.transformer["num_heads"]
  41. self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
  42. self.all_head_size = self.num_attention_heads * self.attention_head_size
  43. # 线性层得到q、k、v向量
  44. self.query = Linear(config.hidden_size, self.all_head_size)
  45. self.key = Linear(config.hidden_size, self.all_head_size)
  46. self.value = Linear(config.hidden_size, self.all_head_size)
  47. self.out = Linear(config.hidden_size, config.hidden_size)
  48. self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
  49. self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
  50. self.softmax = Softmax(dim=-1)
  51. def transpose_for_scores(self, x):
  52. new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
  53. x = x.view(*new_x_shape)
  54. return x.permute(0, 2, 1, 3)
  55. def forward(self, hidden_states):
  56. # 线性层得到q、k、v矩阵
  57. mixed_query_layer = self.query(hidden_states)
  58. mixed_key_layer = self.key(hidden_states)
  59. mixed_value_layer = self.value(hidden_states)
  60. query_layer = self.transpose_for_scores(mixed_query_layer)
  61. key_layer = self.transpose_for_scores(mixed_key_layer)
  62. value_layer = self.transpose_for_scores(mixed_value_layer)
  63. # q乘k的转置
  64. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  65. # 除以根号下d
  66. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  67. # 进行softmax
  68. attention_probs = self.softmax(attention_scores)
  69. weights = attention_probs if self.vis else None
  70. attention_probs = self.attn_dropout(attention_probs)
  71. # 乘以v矩阵
  72. context_layer = torch.matmul(attention_probs, value_layer)
  73. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  74. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  75. context_layer = context_layer.view(*new_context_layer_shape)
  76. # 映射层
  77. attention_output = self.out(context_layer)
  78. attention_output = self.proj_dropout(attention_output)
  79. return attention_output, weights
  80. class Mlp(nn.Module):
  81. def __init__(self, config):
  82. super(Mlp, self).__init__()
  83. self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
  84. self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
  85. self.act_fn = ACT2FN["gelu"]
  86. self.dropout = Dropout(config.transformer["dropout_rate"])
  87. self._init_weights()
  88. def _init_weights(self):
  89. nn.init.xavier_uniform_(self.fc1.weight)
  90. nn.init.xavier_uniform_(self.fc2.weight)
  91. nn.init.normal_(self.fc1.bias, std=1e-6)
  92. nn.init.normal_(self.fc2.bias, std=1e-6)
  93. def forward(self, x):
  94. x = self.fc1(x)
  95. x = self.act_fn(x)
  96. x = self.dropout(x)
  97. x = self.fc2(x)
  98. x = self.dropout(x)
  99. return x
  100. class Embeddings(nn.Module):
  101. """Construct the embeddings from patch, position embeddings.
  102. """
  103. def __init__(self, config, img_size, in_channels=3):
  104. super(Embeddings, self).__init__()
  105. self.hybrid = None
  106. self.config = config
  107. img_size = _pair(img_size)
  108. if config.patches.get("grid") is not None: # ResNet
  109. grid_size = config.patches["grid"]
  110. patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
  111. patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
  112. n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
  113. self.hybrid = True
  114. else:
  115. patch_size = _pair(config.patches["size"])
  116. n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
  117. self.hybrid = False
  118. if self.hybrid:
  119. self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor)
  120. in_channels = self.hybrid_model.width * 16
  121. self.patch_embeddings = Conv2d(in_channels=in_channels,
  122. out_channels=config.hidden_size,
  123. kernel_size=patch_size,
  124. stride=patch_size)
  125. self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))
  126. self.dropout = Dropout(config.transformer["dropout_rate"])
  127. def forward(self, x):
  128. if self.hybrid:
  129. x, features = self.hybrid_model(x)
  130. else:
  131. features = None
  132. # [B,16C,14,14]
  133. x = self.patch_embeddings(x) # (B, hidden. n_patches^(1/2), n_patches^(1/2))
  134. x = x.flatten(2) # (B, hidden, n_patches)
  135. x = x.transpose(-1, -2) # (B, n_patches, hidden)
  136. embeddings = x + self.position_embeddings
  137. embeddings = self.dropout(embeddings)
  138. return embeddings, features
  139. class Block(nn.Module):
  140. def __init__(self, config, vis):
  141. super(Block, self).__init__()
  142. self.hidden_size = config.hidden_size
  143. self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
  144. self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
  145. self.ffn = Mlp(config)
  146. self.attn = Attention(config, vis)
  147. def forward(self, x):
  148. h = x
  149. x = self.attention_norm(x)
  150. x, weights = self.attn(x)
  151. x = x + h
  152. h = x
  153. x = self.ffn_norm(x)
  154. x = self.ffn(x)
  155. x = x + h
  156. return x, weights
  157. def load_from(self, weights, n_block):
  158. ROOT = f"Transformer/encoderblock_{n_block}"
  159. with torch.no_grad():
  160. query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
  161. key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
  162. value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
  163. out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()
  164. query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
  165. key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
  166. value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
  167. out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)
  168. self.attn.query.weight.copy_(query_weight)
  169. self.attn.key.weight.copy_(key_weight)
  170. self.attn.value.weight.copy_(value_weight)
  171. self.attn.out.weight.copy_(out_weight)
  172. self.attn.query.bias.copy_(query_bias)
  173. self.attn.key.bias.copy_(key_bias)
  174. self.attn.value.bias.copy_(value_bias)
  175. self.attn.out.bias.copy_(out_bias)
  176. mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
  177. mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
  178. mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
  179. mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()
  180. self.ffn.fc1.weight.copy_(mlp_weight_0)
  181. self.ffn.fc2.weight.copy_(mlp_weight_1)
  182. self.ffn.fc1.bias.copy_(mlp_bias_0)
  183. self.ffn.fc2.bias.copy_(mlp_bias_1)
  184. self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
  185. self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
  186. self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
  187. self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
  188. class Encoder(nn.Module):
  189. def __init__(self, config, vis):
  190. super(Encoder, self).__init__()
  191. self.vis = vis
  192. self.layer = nn.ModuleList()
  193. self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
  194. for _ in range(config.transformer["num_layers"]):
  195. layer = Block(config, vis)
  196. self.layer.append(copy.deepcopy(layer))
  197. def forward(self, hidden_states):
  198. attn_weights = []
  199. # 依次经过几个Block
  200. for layer_block in self.layer:
  201. hidden_states, weights = layer_block(hidden_states)
  202. if self.vis:
  203. attn_weights.append(weights)
  204. encoded = self.encoder_norm(hidden_states)
  205. return encoded, attn_weights
  206. class Transformer(nn.Module):
  207. def __init__(self, config, img_size, vis):
  208. super(Transformer, self).__init__()
  209. self.embeddings = Embeddings(config, img_size=img_size)
  210. self.encoder = Encoder(config, vis)
  211. def forward(self, input_ids):
  212. embedding_output, features = self.embeddings(input_ids)
  213. encoded, attn_weights = self.encoder(embedding_output) # (B, n_patch, hidden)
  214. return encoded, attn_weights, features
  215. class Conv2dReLU(nn.Sequential):
  216. def __init__(
  217. self,
  218. in_channels,
  219. out_channels,
  220. kernel_size,
  221. padding=0,
  222. stride=1,
  223. use_batchnorm=True,
  224. ):
  225. conv = nn.Conv2d(
  226. in_channels,
  227. out_channels,
  228. kernel_size,
  229. stride=stride,
  230. padding=padding,
  231. bias=not (use_batchnorm),
  232. )
  233. relu = nn.ReLU(inplace=True)
  234. bn = nn.BatchNorm2d(out_channels)
  235. super(Conv2dReLU, self).__init__(conv, bn, relu)
  236. class DecoderBlock(nn.Module):
  237. def __init__(
  238. self,
  239. in_channels,
  240. out_channels,
  241. skip_channels=0,
  242. use_batchnorm=True,
  243. ):
  244. super().__init__()
  245. self.conv1 = Conv2dReLU(
  246. in_channels + skip_channels,
  247. out_channels,
  248. kernel_size=3,
  249. padding=1,
  250. use_batchnorm=use_batchnorm,
  251. )
  252. self.conv2 = Conv2dReLU(
  253. out_channels,
  254. out_channels,
  255. kernel_size=3,
  256. padding=1,
  257. use_batchnorm=use_batchnorm,
  258. )
  259. self.up = nn.UpsamplingBilinear2d(scale_factor=2)
  260. def forward(self, x, skip=None):
  261. x = self.up(x)
  262. if skip is not None:
  263. x = torch.cat([x, skip], dim=1)
  264. x = self.conv1(x)
  265. x = self.conv2(x)
  266. return x
  267. class SegmentationHead(nn.Sequential):
  268. def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
  269. conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
  270. upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
  271. super().__init__(conv2d, upsampling)
  272. class DecoderCup(nn.Module):
  273. def __init__(self, config):
  274. super().__init__()
  275. self.config = config
  276. head_channels = 512
  277. self.conv_more = Conv2dReLU(
  278. config.hidden_size,
  279. head_channels,
  280. kernel_size=3,
  281. padding=1,
  282. use_batchnorm=True,
  283. )
  284. decoder_channels = config.decoder_channels
  285. in_channels = [head_channels] + list(decoder_channels[:-1])
  286. out_channels = decoder_channels
  287. if self.config.n_skip != 0:
  288. skip_channels = self.config.skip_channels
  289. for i in range(4-self.config.n_skip): # re-select the skip channels according to n_skip
  290. skip_channels[3-i]=0
  291. else:
  292. skip_channels=[0,0,0,0]
  293. blocks = [
  294. DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
  295. ]
  296. self.blocks = nn.ModuleList(blocks)
  297. def forward(self, hidden_states, features=None):
  298. B, n_patch, hidden = hidden_states.size() # reshape from (B, n_patch, hidden) to (B, h, w, hidden)
  299. h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
  300. x = hidden_states.permute(0, 2, 1)
  301. x = x.contiguous().view(B, hidden, h, w)
  302. x = self.conv_more(x)
  303. for i, decoder_block in enumerate(self.blocks):
  304. if features is not None:
  305. skip = features[i] if (i < self.config.n_skip) else None
  306. else:
  307. skip = None
  308. x = decoder_block(x, skip=skip)
  309. return x
  310. class VisionTransformer(nn.Module):
  311. def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
  312. super(VisionTransformer, self).__init__()
  313. self.num_classes = num_classes
  314. self.zero_head = zero_head
  315. self.classifier = config.classifier
  316. self.transformer = Transformer(config, img_size, vis)
  317. self.decoder = DecoderCup(config)
  318. self.segmentation_head = SegmentationHead(
  319. in_channels=config['decoder_channels'][-1],
  320. out_channels=config['n_classes'],
  321. kernel_size=3,
  322. )
  323. self.config = config
  324. def forward(self, x):
  325. # [B,C,H,W],当图片为单通道[B,1,H,W]时,将通道数复制三次得到[B,3,H,W]
  326. if x.size()[1] == 1:
  327. # repeat的参数是对应维度的复制个数,参数为(2,2)时,0维复制两次,1维复制两次
  328. x = x.repeat(1,3,1,1)
  329. x, attn_weights, features = self.transformer(x) # (B, n_patch, hidden)
  330. x = self.decoder(x, features) # [B,16,H,W]
  331. logits = self.segmentation_head(x)
  332. return logits
  333. def load_from(self, weights):
  334. with torch.no_grad():
  335. res_weight = weights
  336. self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
  337. self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
  338. self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
  339. self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))
  340. posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
  341. posemb_new = self.transformer.embeddings.position_embeddings
  342. if posemb.size() == posemb_new.size():
  343. self.transformer.embeddings.position_embeddings.copy_(posemb)
  344. elif posemb.size()[1]-1 == posemb_new.size()[1]:
  345. posemb = posemb[:, 1:]
  346. self.transformer.embeddings.position_embeddings.copy_(posemb)
  347. else:
  348. logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
  349. ntok_new = posemb_new.size(1)
  350. if self.classifier == "seg":
  351. _, posemb_grid = posemb[:, :1], posemb[0, 1:]
  352. gs_old = int(np.sqrt(len(posemb_grid)))
  353. gs_new = int(np.sqrt(ntok_new))
  354. print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
  355. posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
  356. zoom = (gs_new / gs_old, gs_new / gs_old, 1)
  357. posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1) # th2np
  358. posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
  359. posemb = posemb_grid
  360. self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))
  361. # Encoder whole
  362. for bname, block in self.transformer.encoder.named_children():
  363. for uname, unit in block.named_children():
  364. unit.load_from(weights, n_block=uname)
  365. if self.transformer.embeddings.hybrid:
  366. self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True))
  367. gn_weight = np2th(res_weight["gn_root/scale"]).view(-1)
  368. gn_bias = np2th(res_weight["gn_root/bias"]).view(-1)
  369. self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
  370. self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)
  371. for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
  372. for uname, unit in block.named_children():
  373. unit.load_from(res_weight, n_block=bname, n_unit=uname)
  374. CONFIGS = {
  375. 'ViT-B_16': configs.get_b16_config(),
  376. 'ViT-B_32': configs.get_b32_config(),
  377. 'ViT-L_16': configs.get_l16_config(),
  378. 'ViT-L_32': configs.get_l32_config(),
  379. 'ViT-H_14': configs.get_h14_config(),
  380. 'R50-ViT-B_16': configs.get_r50_b16_config(),
  381. 'R50-ViT-L_16': configs.get_r50_l16_config(),
  382. 'testing': configs.get_testing(),
  383. }


  1. import math
  2. from os.path import join as pjoin
  3. from collections import OrderedDict
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. def np2th(weights, conv=False):
  8. """Possibly convert HWIO to OIHW."""
  9. if conv:
  10. weights = weights.transpose([3, 2, 0, 1])
  11. return torch.from_numpy(weights)
  12. class StdConv2d(nn.Conv2d):
  13. def forward(self, x):
  14. w = self.weight
  15. v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
  16. w = (w - m) / torch.sqrt(v + 1e-5)
  17. return F.conv2d(x, w, self.bias, self.stride, self.padding,
  18. self.dilation, self.groups)
  19. def conv3x3(cin, cout, stride=1, groups=1, bias=False):
  20. return StdConv2d(cin, cout, kernel_size=3, stride=stride,
  21. padding=1, bias=bias, groups=groups)
  22. def conv1x1(cin, cout, stride=1, bias=False):
  23. return StdConv2d(cin, cout, kernel_size=1, stride=stride,
  24. padding=0, bias=bias)
  25. class PreActBottleneck(nn.Module):
  26. """Pre-activation (v2) bottleneck block.
  27. """
  28. def __init__(self, cin, cout=None, cmid=None, stride=1):
  29. super().__init__()
  30. cout = cout or cin
  31. cmid = cmid or cout//4
  32. self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)
  33. self.conv1 = conv1x1(cin, cmid, bias=False)
  34. self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)
  35. self.conv2 = conv3x3(cmid, cmid, stride, bias=False) # Original code has it on conv1!!
  36. self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)
  37. self.conv3 = conv1x1(cmid, cout, bias=False)
  38. self.relu = nn.ReLU(inplace=True)
  39. if (stride != 1 or cin != cout):
  40. # Projection also with pre-activation according to paper.
  41. self.downsample = conv1x1(cin, cout, stride, bias=False)
  42. self.gn_proj = nn.GroupNorm(cout, cout)
  43. def forward(self, x):
  44. # Residual branch
  45. residual = x
  46. # 是否有下采样模块
  47. if hasattr(self, 'downsample'):
  48. residual = self.downsample(x)
  49. residual = self.gn_proj(residual)
  50. # Unit's branch
  51. y = self.relu(self.gn1(self.conv1(x)))
  52. y = self.relu(self.gn2(self.conv2(y)))
  53. y = self.gn3(self.conv3(y))
  54. y = self.relu(residual + y)
  55. return y
  56. def load_from(self, weights, n_block, n_unit):
  57. conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True)
  58. conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True)
  59. conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True)
  60. gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")])
  61. gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")])
  62. gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")])
  63. gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")])
  64. gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")])
  65. gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")])
  66. self.conv1.weight.copy_(conv1_weight)
  67. self.conv2.weight.copy_(conv2_weight)
  68. self.conv3.weight.copy_(conv3_weight)
  69. self.gn1.weight.copy_(gn1_weight.view(-1))
  70. self.gn1.bias.copy_(gn1_bias.view(-1))
  71. self.gn2.weight.copy_(gn2_weight.view(-1))
  72. self.gn2.bias.copy_(gn2_bias.view(-1))
  73. self.gn3.weight.copy_(gn3_weight.view(-1))
  74. self.gn3.bias.copy_(gn3_bias.view(-1))
  75. if hasattr(self, 'downsample'):
  76. proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True)
  77. proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")])
  78. proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")])
  79. self.downsample.weight.copy_(proj_conv_weight)
  80. self.gn_proj.weight.copy_(proj_gn_weight.view(-1))
  81. self.gn_proj.bias.copy_(proj_gn_bias.view(-1))
  82. class ResNetV2(nn.Module):
  83. """Implementation of Pre-activation (v2) ResNet mode."""
  84. def __init__(self, block_units, width_factor):
  85. super().__init__()
  86. width = int(64 * width_factor)
  87. self.width = width
  88. self.root = nn.Sequential(OrderedDict([
  89. ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),
  90. ('gn', nn.GroupNorm(32, width, eps=1e-6)),
  91. ('relu', nn.ReLU(inplace=True)),
  92. # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0))
  93. ]))
  94. # [B,C,55,55]
  95. self.body = nn.Sequential(OrderedDict([
  96. ('block1', nn.Sequential(OrderedDict(
  97. [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +
  98. [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],
  99. ))),
  100. ('block2', nn.Sequential(OrderedDict(
  101. [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +
  102. [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],
  103. ))),
  104. ('block3', nn.Sequential(OrderedDict(
  105. [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +
  106. [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],
  107. ))),
  108. ]))
  109. def forward(self, x):
  110. # [B,3,224,224]
  111. features = []
  112. b, c, in_size, _ = x.size()
  113. x = self.root(x) # [B,C,112,112]
  114. features.append(x)
  115. x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x) # [B,C,55,55]
  116. for i in range(len(self.body)-1):
  117. x = self.body[i](x)
  118. right_size = int(in_size / 4 / (i+1))
  119. if x.size()[2] != right_size:
  120. pad = right_size - x.size()[2]
  121. assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size)
  122. feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device)
  123. feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]
  124. else:
  125. feat = x
  126. features.append(feat)
  127. x = self.body[-1](x) # [B,16C,14,14]
  128. return x, features[::-1]

