
TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation - Paper Walkthrough (with Code)


1. Abstract

Medical image segmentation is an essential prerequisite for developing healthcare systems, especially for disease diagnosis and treatment planning. Among various medical image segmentation tasks, the U-shaped architecture, also known as U-Net, has become the de facto standard and achieved tremendous success. However, due to the intrinsic locality of convolution operations, U-Net generally shows limitations in explicitly modeling long-range dependencies. Transformers, designed for sequence-to-sequence prediction, have emerged as alternative architectures with an innate global self-attention mechanism, but they can suffer from limited localization ability because of insufficient low-level detail.

In this paper, we propose TransUNet, which merits both Transformers and U-Net, as a strong alternative for medical image segmentation. On the one hand, the Transformer encodes tokenized image patches from a convolutional neural network (CNN) feature map as the input sequence for extracting global context. On the other hand, the decoder upsamples the encoded features and combines them with high-resolution CNN feature maps to enable precise localization.

We argue that Transformers can serve as strong encoders for medical image segmentation tasks, and that combining them with U-Net enhances finer details by recovering localized spatial information. TransUNet achieves superior performance to various competing methods on different medical applications, including multi-organ segmentation and cardiac segmentation.

2. Introduction

Despite the exceptional representational power of CNN-based approaches, they generally exhibit limitations in modeling explicit long-range relations, owing to the intrinsic locality of convolution operations. As a result, these architectures often yield weaker performance, especially for target structures with large inter-patient variation in texture, shape, and size. To overcome this limitation, existing studies propose building self-attention mechanisms on top of CNN features. On the other hand, Transformers, designed for sequence-to-sequence prediction, have emerged as alternative architectures that dispense with convolution operators entirely and rely solely on attention mechanisms. Unlike prior CNN-based methods, Transformers are not only powerful at modeling global context but also demonstrate superior transferability to downstream tasks under large-scale pre-training. This success has been widely confirmed in machine translation and natural language processing (NLP). More recently, attempts on various image recognition tasks have also matched or exceeded the state of the art.

This paper presents the first study to explore the potential of Transformers in the context of medical image segmentation. Interestingly, however, we found that a naive usage (i.e., using a Transformer to encode tokenized image patches and then directly upsampling the hidden feature representations into a full-resolution dense output) does not produce satisfactory results.

This is because the Transformer treats the input as a 1D sequence and focuses exclusively on modeling global context at every stage, so the resulting low-resolution features lack detailed localization information. That information cannot be effectively recovered by direct upsampling to full resolution, which leads to coarse segmentation results. CNN architectures (e.g., U-Net), in contrast, provide an avenue for extracting low-level visual cues that can compensate well for such fine spatial details.

To this end, we propose TransUNet, the first medical image segmentation framework that establishes self-attention mechanisms from the perspective of sequence-to-sequence prediction. To compensate for the loss of feature resolution brought by the Transformer, TransUNet adopts a hybrid CNN-Transformer architecture that leverages both the detailed high-resolution spatial information from CNN features and the global context encoded by the Transformer. Inspired by the U-shaped architectural design, the self-attentive feature maps from the Transformer encoder are upsampled and combined with different high-resolution CNN feature maps shortcut from the encoding path to enable precise localization. We show that such a design allows our framework to preserve the advantages of Transformers while also benefiting medical image segmentation. Empirical results suggest that our Transformer-based architecture offers a better way to exploit self-attention than previous CNN-based self-attention methods. In addition, we observe that a more intensive incorporation of low-level features generally leads to better segmentation accuracy. Extensive experiments demonstrate the superiority of the method on a variety of medical image segmentation tasks.

Given an image of size H × W × C, with spatial resolution H × W and C channels, the goal is to predict the corresponding pixel-wise label map of size H × W. The most common approach is to directly train a CNN (e.g., U-Net) that first encodes the image into high-level feature representations and then decodes them back to the full spatial resolution. Unlike existing approaches, our method introduces a self-attention mechanism into the encoder design by using Transformers.

3. Network Architecture

The overall TransUNet architecture, shown in the figure above, is fairly easy to follow.

1) A CT image is fed in first. If the image is single-channel, it is duplicated with the repeat function to expand it to three channels.

2) The three-channel image [B,3,224,224] is downsampled by ResNetV2, which encodes the image into high-level feature representations. A features list stores each downsampled feature map along the way. The final ResNetV2 output is [B,16C,14,14], where C = 64 × width_factor. The features list holds feature maps at three scales: [B,C,112,112], [B,4C,56,56], and [B,8C,28,28].

The ResNetV2 structure in detail: a 7×7 convolution with stride 2 and padding 3 first downsamples the input at the root to [B,C,112,112]; a max-pooling layer (kernel 3, stride 2, no padding) then downsamples it to [B,C,55,55] (block outputs that come out one pixel short, e.g. 55×55, are zero-padded to the expected skip size such as 56×56 in the forward pass); finally, three bottleneck blocks downsample further and produce the output feature map.
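To make the shape bookkeeping concrete, here is a minimal sketch of the backbone shape flow, assuming img_size=224, width_factor=1 (so C=64) and a (3, 4, 9) block layout for the R50 hybrid; the import path is illustrative and refers to the listing in Section 4.

import torch
from vit_seg_modeling_resnet_skip import ResNetV2   # illustrative import path

backbone = ResNetV2(block_units=(3, 4, 9), width_factor=1)   # R50-style layout, C = 64
x = torch.randn(2, 3, 224, 224)
out, features = backbone(x)

print(out.shape)          # torch.Size([2, 1024, 14, 14])  -> [B, 16C, 14, 14]
for f in features:        # skip features, coarsest first (features[::-1] in the code)
    print(f.shape)        # [2, 512, 28, 28], [2, 256, 56, 56], [2, 64, 112, 112]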

3) The downsampled feature map [B,16C,14,14] is first passed through the embedding step.

A convolution with kernel size patch_size × patch_size and stride patch_size slices the feature map into patches, giving [B, hidden_channel, num_patches] after flattening, where num_patches = (14/patch_size)**2. This is then transposed to [B, n_p, h_c], and a position embedding of shape [1, n_p, h_c] is added, yielding the output feature map [B, n_p, h_c].
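A minimal sketch of this patch-embedding arithmetic, assuming the hybrid setting where the Transformer sees the 14×14 ResNetV2 output, hidden_size=768 and patch_size=1 (so n_p = 196); the variable names are illustrative:

import torch
import torch.nn as nn

B, in_ch, H, W = 2, 1024, 14, 14                    # ResNetV2 output [B, 16C, 14, 14]
hidden, patch_size = 768, 1                         # ViT-B hidden size; 1x1 patches on the CNN grid
n_patches = (H // patch_size) * (W // patch_size)   # (14 / patch_size)**2 = 196

patch_embed = nn.Conv2d(in_ch, hidden, kernel_size=patch_size, stride=patch_size)
pos_embed = nn.Parameter(torch.zeros(1, n_patches, hidden))   # position embedding [1, n_p, h_c]

x = torch.randn(B, in_ch, H, W)
tokens = patch_embed(x).flatten(2).transpose(-1, -2)   # [B, n_p, h_c]
tokens = tokens + pos_embed
print(tokens.shape)                                     # torch.Size([2, 196, 768])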

4) Transformer Encoder. There is not much to add here: it is the familiar Transformer computing self-attention over the token sequence. The output is [B, n_p, h_c].
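For reference, a minimal sketch of the scaled dot-product attention computed inside each encoder block, assuming 12 heads and hidden_size=768 (ViT-B); it mirrors the Attention class in Section 4:

import math
import torch

B, n_p, h_c, heads = 2, 196, 768, 12
d_head = h_c // heads                                   # 64

# q, k, v after the Linear projections and the head split in transpose_for_scores
q = k = v = torch.randn(B, heads, n_p, d_head)

scores = q @ k.transpose(-1, -2) / math.sqrt(d_head)    # [B, heads, n_p, n_p]
attn = scores.softmax(dim=-1)
out = (attn @ v).transpose(1, 2).reshape(B, n_p, h_c)   # back to [B, n_p, h_c]
print(out.shape)                                        # torch.Size([2, 196, 768])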

5) Decoder. [B, n_p, h_c] is first reshaped to [B, h_c, sqrt(n_p), sqrt(n_p)], then a Conv2d maps it to [B, 512, 14, 14]. It then passes through four decoder blocks and finally outputs [B, 16, 224, 224].

Each decoder block first doubles the feature-map size with bilinear upsampling, then concatenates it with the corresponding shortcut feature saved earlier during the convolutional downsampling, and finally applies two Conv2d layers to project the features to fewer channels.
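A minimal sketch of the first decoder stage, assuming the default config values decoder_channels=(256, 128, 64, 16), skip_channels=[512, 256, 64, 0] and n_skip=3; the later stages repeat the same pattern at 56, 112 and 224 resolution:

import torch
import torch.nn as nn

# One decoder stage: bilinear 2x upsample -> concat the skip feature -> two 3x3 conv+BN+ReLU.
up = nn.UpsamplingBilinear2d(scale_factor=2)
conv1 = nn.Sequential(nn.Conv2d(512 + 512, 256, 3, padding=1, bias=False),
                      nn.BatchNorm2d(256), nn.ReLU(inplace=True))
conv2 = nn.Sequential(nn.Conv2d(256, 256, 3, padding=1, bias=False),
                      nn.BatchNorm2d(256), nn.ReLU(inplace=True))

x = torch.randn(2, 512, 14, 14)        # tokens reshaped and passed through conv_more
skip = torch.randn(2, 512, 28, 28)     # coarsest ResNetV2 skip feature

x = up(x)                              # [2, 512, 28, 28]
x = torch.cat([x, skip], dim=1)        # [2, 1024, 28, 28]
x = conv2(conv1(x))                    # [2, 256, 28, 28]
print(x.shape)
# Stages 2-4 continue 28 -> 56 -> 112 -> 224, ending at [B, 16, 224, 224] (no skip on the last stage).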

6) SegmentationHead. The feature map produced by the decoder is turned into the segmentation prediction: a Conv2d outputs [B, n_classes, 224, 224].
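A minimal sketch of the head, assuming n_classes=9 (e.g., a Synapse-style multi-organ setting with 8 organs plus background):

import torch
import torch.nn as nn

n_classes = 9
head = nn.Conv2d(16, n_classes, kernel_size=3, padding=1)   # the default head: a single 3x3 conv

x = torch.randn(2, 16, 224, 224)        # decoder output
print(head(x).shape)                    # torch.Size([2, 9, 224, 224])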

4. Full TransUNet Code

The code below is the full model code, excluding the ResNetV2 backbone.

# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

from . import vit_seg_configs as configs
from .vit_seg_modeling_resnet_skip import ResNetV2

logger = logging.getLogger(__name__)

ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"


def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)


# swish activation: x * sigmoid(x)
def swish(x):
    return x * torch.sigmoid(x)


# mapping from name to activation function
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}


class Attention(nn.Module):
    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # linear layers producing the q, k, v projections
        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)
        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        # project hidden states to q, k, v
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # q times k transposed
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        # scale by sqrt(d_head)
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # softmax over the key dimension
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        # weight the value vectors
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        # output projection
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights


class Mlp(nn.Module):
    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])
        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings."""

    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        self.config = config
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:   # ResNet
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
            n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))
        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        if self.hybrid:
            x, features = self.hybrid_model(x)
        else:
            features = None
        # x: [B, 16C, 14, 14]
        x = self.patch_embeddings(x)   # (B, hidden, n_patches^(1/2), n_patches^(1/2))
        x = x.flatten(2)               # (B, hidden, n_patches)
        x = x.transpose(-1, -2)        # (B, n_patches, hidden)

        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings, features


class Block(nn.Module):
    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights

    def load_from(self, weights, n_block):
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))


class Encoder(nn.Module):
    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            layer = Block(config, vis)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, hidden_states):
        attn_weights = []
        # pass through the stacked Blocks in sequence
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights


class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        embedding_output, features = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(embedding_output)   # (B, n_patch, hidden)
        return encoded, attn_weights, features


class Conv2dReLU(nn.Sequential):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)
        bn = nn.BatchNorm2d(out_channels)
        super(Conv2dReLU, self).__init__(conv, bn, relu)


class DecoderBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            skip_channels=0,
            use_batchnorm=True,
    ):
        super().__init__()
        self.conv1 = Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None):
        x = self.up(x)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class SegmentationHead(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsampling)


class DecoderCup(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        head_channels = 512
        self.conv_more = Conv2dReLU(
            config.hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        decoder_channels = config.decoder_channels
        in_channels = [head_channels] + list(decoder_channels[:-1])
        out_channels = decoder_channels

        if self.config.n_skip != 0:
            skip_channels = self.config.skip_channels
            for i in range(4 - self.config.n_skip):   # re-select the skip channels according to n_skip
                skip_channels[3 - i] = 0
        else:
            skip_channels = [0, 0, 0, 0]

        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, hidden_states, features=None):
        B, n_patch, hidden = hidden_states.size()   # reshape from (B, n_patch, hidden) to (B, h, w, hidden)
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        x = self.conv_more(x)
        for i, decoder_block in enumerate(self.blocks):
            if features is not None:
                skip = features[i] if (i < self.config.n_skip) else None
            else:
                skip = None
            x = decoder_block(x, skip=skip)
        return x


class VisionTransformer(nn.Module):
    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size, vis)
        self.decoder = DecoderCup(config)
        self.segmentation_head = SegmentationHead(
            in_channels=config['decoder_channels'][-1],
            out_channels=config['n_classes'],
            kernel_size=3,
        )
        self.config = config

    def forward(self, x):
        # [B, C, H, W]; a single-channel input [B, 1, H, W] is repeated to [B, 3, H, W]
        if x.size()[1] == 1:
            # repeat takes per-dimension repetition counts: (1, 3, 1, 1) copies the channel dim three times
            x = x.repeat(1, 3, 1, 1)
        x, attn_weights, features = self.transformer(x)   # (B, n_patch, hidden)
        x = self.decoder(x, features)                     # [B, 16, H, W]
        logits = self.segmentation_head(x)
        return logits

    def load_from(self, weights):
        with torch.no_grad():
            res_weight = weights
            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))

            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
            posemb_new = self.transformer.embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            elif posemb.size()[1] - 1 == posemb_new.size()[1]:
                posemb = posemb[:, 1:]
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            else:
                logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
                ntok_new = posemb_new.size(1)
                if self.classifier == "seg":
                    _, posemb_grid = posemb[:, :1], posemb[0, 1:]
                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)   # th2np
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb = posemb_grid
                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))

            # Encoder whole
            for bname, block in self.transformer.encoder.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, n_block=uname)

            if self.transformer.embeddings.hybrid:
                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True))
                gn_weight = np2th(res_weight["gn_root/scale"]).view(-1)
                gn_bias = np2th(res_weight["gn_root/bias"]).view(-1)
                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(res_weight, n_block=bname, n_unit=uname)


CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),
    'ViT-B_32': configs.get_b32_config(),
    'ViT-L_16': configs.get_l16_config(),
    'ViT-L_32': configs.get_l32_config(),
    'ViT-H_14': configs.get_h14_config(),
    'R50-ViT-B_16': configs.get_r50_b16_config(),
    'R50-ViT-L_16': configs.get_r50_l16_config(),
    'testing': configs.get_testing(),
}

The code below is the ResNetV2 backbone.

import math

from os.path import join as pjoin
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F


def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)


class StdConv2d(nn.Conv2d):
    def forward(self, x):
        w = self.weight
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-5)
        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)


def conv3x3(cin, cout, stride=1, groups=1, bias=False):
    return StdConv2d(cin, cout, kernel_size=3, stride=stride,
                     padding=1, bias=bias, groups=groups)


def conv1x1(cin, cout, stride=1, bias=False):
    return StdConv2d(cin, cout, kernel_size=1, stride=stride,
                     padding=0, bias=bias)


class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block."""

    def __init__(self, cin, cout=None, cmid=None, stride=1):
        super().__init__()
        cout = cout or cin
        cmid = cmid or cout // 4

        self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv1 = conv1x1(cin, cmid, bias=False)
        self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv2 = conv3x3(cmid, cmid, stride, bias=False)   # Original code has it on conv1!!
        self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)
        self.conv3 = conv1x1(cmid, cout, bias=False)
        self.relu = nn.ReLU(inplace=True)

        if (stride != 1 or cin != cout):
            # Projection also with pre-activation according to paper.
            self.downsample = conv1x1(cin, cout, stride, bias=False)
            self.gn_proj = nn.GroupNorm(cout, cout)

    def forward(self, x):
        # Residual branch
        residual = x
        # apply the projection shortcut if this block downsamples or changes width
        if hasattr(self, 'downsample'):
            residual = self.downsample(x)
            residual = self.gn_proj(residual)

        # Unit's branch
        y = self.relu(self.gn1(self.conv1(x)))
        y = self.relu(self.gn2(self.conv2(y)))
        y = self.gn3(self.conv3(y))

        y = self.relu(residual + y)
        return y

    def load_from(self, weights, n_block, n_unit):
        conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True)
        conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True)
        conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True)

        gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")])
        gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")])

        gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")])
        gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")])

        gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")])
        gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")])

        self.conv1.weight.copy_(conv1_weight)
        self.conv2.weight.copy_(conv2_weight)
        self.conv3.weight.copy_(conv3_weight)

        self.gn1.weight.copy_(gn1_weight.view(-1))
        self.gn1.bias.copy_(gn1_bias.view(-1))

        self.gn2.weight.copy_(gn2_weight.view(-1))
        self.gn2.bias.copy_(gn2_bias.view(-1))

        self.gn3.weight.copy_(gn3_weight.view(-1))
        self.gn3.bias.copy_(gn3_bias.view(-1))

        if hasattr(self, 'downsample'):
            proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True)
            proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")])
            proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")])

            self.downsample.weight.copy_(proj_conv_weight)
            self.gn_proj.weight.copy_(proj_gn_weight.view(-1))
            self.gn_proj.bias.copy_(proj_gn_bias.view(-1))


class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet mode."""

    def __init__(self, block_units, width_factor):
        super().__init__()
        width = int(64 * width_factor)
        self.width = width

        self.root = nn.Sequential(OrderedDict([
            ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),
            ('gn', nn.GroupNorm(32, width, eps=1e-6)),
            ('relu', nn.ReLU(inplace=True)),
            # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0))
        ]))
        # root output: [B, C, 112, 112]; the max pool applied in forward() brings this to [B, C, 55, 55]

        self.body = nn.Sequential(OrderedDict([
            ('block1', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],
            ))),
            ('block2', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],
            ))),
            ('block3', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],
            ))),
        ]))

    def forward(self, x):
        # x: [B, 3, 224, 224]
        features = []
        b, c, in_size, _ = x.size()
        x = self.root(x)   # [B, C, 112, 112]
        features.append(x)
        x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)   # [B, C, 55, 55]
        for i in range(len(self.body) - 1):
            x = self.body[i](x)
            right_size = int(in_size / 4 / (i + 1))
            if x.size()[2] != right_size:
                pad = right_size - x.size()[2]
                assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size)
                feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device)
                feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]
            else:
                feat = x
            features.append(feat)
        x = self.body[-1](x)   # [B, 16C, 14, 14]
        return x, features[::-1]
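Finally, a minimal usage sketch showing how the pieces above fit together, assuming the first file is importable as vit_seg_modeling and using the R50-ViT-B_16 config with hypothetical task settings (n_classes=9, n_skip=3):

import torch
from vit_seg_modeling import VisionTransformer, CONFIGS   # illustrative import path

config = CONFIGS['R50-ViT-B_16']
config.n_classes = 9                 # e.g. 8 organs + background (assumed task setting)
config.n_skip = 3                    # use all three ResNetV2 skip connections
config.patches.grid = (14, 14)       # 224 / 16 = 14 tokens per side

model = VisionTransformer(config, img_size=224, num_classes=config.n_classes)
x = torch.randn(1, 1, 224, 224)      # single-channel CT slice; repeated to 3 channels internally
logits = model(x)
print(logits.shape)                  # torch.Size([1, 9, 224, 224])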
