
ViT-V-Net PyTorch: Code Understanding and Analysis

Paper: ViT-V-Net: Vision Transformer for Unsupervised Volumetric Medical Image Registration

Source code: https://github.com/junyuchen245/ViT-V-Net_for_3D_Image_Registration_Pytorch

Figure 1: Overall framework
Figure 2: Block
Figure 4: Code call relationships

config.py:

Sets up the initial configuration parameters.

import ml_collections

def get_3DReg_config():
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (8, 8, 8)})
    config.patches.grid = (8, 8, 8)
    config.hidden_size = 252
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 3072
    config.transformer.num_heads = 12
    config.transformer.num_layers = 12
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.patch_size = 8
    config.conv_first_channel = 512
    config.encoder_channels = (16, 32, 32)
    config.down_factor = 2
    config.down_num = 2
    config.decoder_channels = (96, 48, 32, 32, 16)
    config.skip_channels = (32, 32, 32, 32, 16)
    config.n_dims = 3
    config.n_skip = 5
    return config
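
Since the title highlights config.patches["grid"]: ml_collections.ConfigDict accepts both dictionary-style and attribute-style access, which is why models.py below mixes config.transformer["num_heads"] with config.hidden_size. A minimal sketch, assuming the function above is in scope:

config = get_3DReg_config()
print(config.patches["grid"])    # (8, 8, 8)  - dict-style access
print(config.patches.grid)       # (8, 8, 8)  - attribute-style access to the same entry
print(config.transformer["num_heads"], config.hidden_size)   # 12 252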

models.py:
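
The snippets below are excerpts from models.py. They rely on roughly the following imports (reconstructed here from how the names are used, so the exact list in the repository may differ slightly; ACT2FN is the activation lookup table from the ViT code, shown here in a simplified form):

import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as nnf
from torch.nn import Conv3d, Dropout, LayerNorm, Linear, Softmax
from torch.nn.modules.utils import _triple
from torch.distributions.normal import Normal

# activation lookup used by Mlp below
ACT2FN = {"gelu": nnf.gelu, "relu": nnf.relu}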

Multi-head attention

Figure 3: Transformer Encoder

The Attention class is, conceptually, the orange part in the figure above: the multi-head self-attention.

class Attention(nn.Module):
    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)
        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights
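
To make the head splitting concrete, here is a small shape check (a sketch, assuming Attention and get_3DReg_config are importable from the repo's models.py and config.py; the (2, 210, 252) token shape is the one quoted for the Embeddings output below). With 12 heads, each head has size 252 / 12 = 21:

import torch
from config import get_3DReg_config
from models import Attention            # the class defined above

attn = Attention(get_3DReg_config(), vis=True)
tokens = torch.randn(2, 210, 252)        # (B, n_patch, hidden)
out, weights = attn(tokens)
print(out.shape)       # torch.Size([2, 210, 252]) - same shape as the input
print(weights.shape)   # torch.Size([2, 12, 210, 210]) - (B, heads, n_patch, n_patch)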

Mlp is the yellow part in Figure 3: the feed-forward network (two Linear layers with GELU activation and dropout).

class Mlp(nn.Module):  # feed-forward network
    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])
        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

Embeddings is the patch + position embedding stage.

The patch embedding here is done with a Conv3d applied to the CNN-encoded feature map. With the training volumes this produces 210 patches, i.e. (B, n_patch, hidden) = (2, 210, 252), and the tensor keeps this shape all the way through the Transformer.
class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings."""
    def __init__(self, config, img_size):
        super(Embeddings, self).__init__()
        self.config = config
        down_factor = config.down_factor
        patch_size = _triple(config.patches["size"])
        n_patches = int((img_size[0]/2**down_factor // patch_size[0]) *
                        (img_size[1]/2**down_factor // patch_size[1]) *
                        (img_size[2]/2**down_factor // patch_size[2]))
        self.hybrid_model = CNNEncoder(config, n_channels=2)
        in_channels = config['encoder_channels'][-1]
        self.patch_embeddings = Conv3d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))
        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        x, features = self.hybrid_model(x)
        x = self.patch_embeddings(x)  # (B, hidden, D', H', W')
        x = x.flatten(2)
        x = x.transpose(-1, -2)  # (B, n_patches, hidden)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings, features
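
A worked check of where the 210 patches come from, assuming input volumes of 160 × 192 × 224 (that this is the training volume size is an assumption here; the default img_size=(64, 256, 256) of ViTVNet below would give 2 × 8 × 8 = 128 patches instead). The CNN encoder downsamples by 2**down_factor = 4 and the Conv3d patch embedding has stride 8, so each axis is divided by 32:

img_size = (160, 192, 224)                 # assumed training volume size
down_factor, patch = 2, 8
grid = [s // (2 ** down_factor * patch) for s in img_size]
print(grid)                                # [5, 6, 7]
print(grid[0] * grid[1] * grid[2])         # 210 -> (B, n_patch, hidden) = (B, 210, 252)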
Block is the grey box in Figure 3, i.e. one Transformer layer (LayerNorm → multi-head attention → residual, then LayerNorm → MLP → residual). The grey box is not the whole Transformer, and the Block class defines only a single layer, not 12; the loop over the 12 blocks lives in the Encoder class.

class Block(nn.Module):
    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h
        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights
Encoder simply stacks 12 Blocks (config.transformer.num_layers) and applies a final LayerNorm.
class Encoder(nn.Module):
    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            layer = Block(config, vis)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, hidden_states):
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights
DecoderBlock is the green part in Figures 1 and 2; it is used by DecoderCup.

class DecoderBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            skip_channels=0,
            use_batchnorm=True,
    ):
        super().__init__()
        self.conv1 = Conv3dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv3dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)

    def forward(self, x, skip=None):
        x = self.up(x)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x
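
A small shape sketch of one decoder step (the sizes are made up for illustration; it assumes DecoderBlock, and the Conv3dReLU it uses further below, are importable from the repo's models.py): the block first doubles the spatial resolution, then concatenates the skip feature along the channel dimension before the two convolutions.

import torch
from models import DecoderBlock

block = DecoderBlock(in_channels=96, out_channels=48, skip_channels=32)
x = torch.randn(1, 96, 10, 12, 14)      # coarse decoder feature
skip = torch.randn(1, 32, 20, 24, 28)   # encoder skip at twice the resolution
print(block(x, skip).shape)             # torch.Size([1, 48, 20, 24, 28])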
DecoderCup is the full decoding path. Only __init__ is shown in this excerpt; its forward method reshapes the Transformer output back into a 3D feature map (via conv_more) and then runs it through the decoder blocks together with the skip features.

class DecoderCup(nn.Module):
    def __init__(self, config, img_size):
        super().__init__()
        self.config = config
        self.down_factor = config.down_factor
        head_channels = config.conv_first_channel
        self.img_size = img_size
        self.conv_more = Conv3dReLU(
            config.hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        decoder_channels = config.decoder_channels
        in_channels = [head_channels] + list(decoder_channels[:-1])
        out_channels = decoder_channels
        self.patch_size = _triple(config.patches["size"])
        skip_channels = self.config.skip_channels
        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch)
            for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)
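
With the default config, the zip above builds five DecoderBlocks. A quick sketch of the resulting (in, out, skip) channel wiring, reusing get_3DReg_config from the top of the post:

from config import get_3DReg_config

cfg = get_3DReg_config()
in_channels = [cfg.conv_first_channel] + list(cfg.decoder_channels[:-1])
for triple in zip(in_channels, cfg.decoder_channels, cfg.skip_channels):
    print(triple)
# (512, 96, 32)
# (96, 48, 32)
# (48, 32, 32)
# (32, 32, 32)
# (32, 16, 16)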
Transformer assembles the whole of Figure 3 (Embeddings followed by Encoder); it is the orange part in Figure 1.
class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        embedding_output, features = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(embedding_output)  # (B, n_patch, hidden)
        return encoded, attn_weights, features
SpatialTransformer is identical to the one in VoxelMorph; it is the blue "Spatial Transformer" block in Figure 1 and warps the moving image according to the predicted flow field.
class SpatialTransformer(nn.Module):
    """
    N-D Spatial Transformer
    Obtained from https://github.com/voxelmorph/voxelmorph
    """
    def __init__(self, size, mode='bilinear'):
        super().__init__()
        self.mode = mode
        # create sampling grid
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors)
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)
        grid = grid.type(torch.FloatTensor)
        # registering the grid as a buffer cleanly moves it to the GPU, but it also
        # adds it to the state dict. this is annoying since everything in the state dict
        # is included when saving weights to disk, so the model files are way bigger
        # than they need to be. so far, there does not appear to be an elegant solution.
        # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict
        self.register_buffer('grid', grid)

    def forward(self, src, flow):
        # new locations
        new_locs = self.grid + flow
        shape = flow.shape[2:]
        # need to normalize grid values to [-1, 1] for resampler
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
        # move channels dim to last position
        # also not sure why, but the channels need to be reversed
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]
        return nnf.grid_sample(src, new_locs, align_corners=True, mode=self.mode)
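
A quick sanity check on this module (a sketch with an arbitrary 8 × 8 × 8 size): a zero flow field leaves the moving image unchanged, because the sampling locations then reduce to the identity grid.

import torch
from models import SpatialTransformer

st = SpatialTransformer(size=(8, 8, 8))
src = torch.rand(1, 1, 8, 8, 8)           # moving image
flow = torch.zeros(1, 3, 8, 8, 8)         # zero displacement field
warped = st(src, flow)
print(torch.allclose(warped, src, atol=1e-5))   # True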
DoubleConv is simply two stacked 3D convolutions; note that the code only uses ReLU, not LeakyReLU.

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)
Down adds a max-pooling layer before the double convolution. Pooling shrinks the spatial size of the feature maps, which reduces the computation in the layers that follow and helps guard against overfitting.
class Down(nn.Module):
    """Downscaling with maxpool then double conv"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool3d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)
Conv3dReLU is used in the decoder; it is a 3D convolution followed by normalization and a ReLU activation. An open question: why does the code use BatchNorm here when the figure shows instance norm? And again, ReLU is used instead of LeakyReLU.
class Conv3dReLU(nn.Sequential):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)
        bn = nn.BatchNorm3d(out_channels)
        super(Conv3dReLU, self).__init__(conv, bn, relu)
CNNEncoder is the three convolution stages; it returns the downsampled feature map together with a list of feature maps used as skip connections, and is followed by the patch embedding.

Question: why does the figure show only one pooling per level, while the code applies two extra max-pooling steps (the down_num = 2 loop) after the three stages? Also, the instance norm shown in the figure does not appear in the code at all.

class CNNEncoder(nn.Module):
    def __init__(self, config, n_channels=2):
        super(CNNEncoder, self).__init__()
        self.n_channels = n_channels
        decoder_channels = config.decoder_channels
        encoder_channels = config.encoder_channels
        self.down_num = config.down_num
        self.inc = DoubleConv(n_channels, encoder_channels[0])
        self.down1 = Down(encoder_channels[0], encoder_channels[1])
        self.down2 = Down(encoder_channels[1], encoder_channels[2])
        self.width = encoder_channels[-1]

    def forward(self, x):
        features = []
        x1 = self.inc(x)
        features.append(x1)
        x2 = self.down1(x1)
        features.append(x2)
        feats = self.down2(x2)
        features.append(feats)
        feats_down = feats
        for i in range(self.down_num):
            feats_down = nn.MaxPool3d(2)(feats_down)
            features.append(feats_down)
        return feats, features[::-1]
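
A shape walkthrough with an arbitrary 64 × 64 × 64 input pair (the size is chosen only for illustration; it assumes CNNEncoder and get_3DReg_config are importable as above): feats is the 1/4-resolution map that feeds the patch embedding, and the reversed features list supplies the decoder skips from coarse to fine.

import torch
from config import get_3DReg_config
from models import CNNEncoder

enc = CNNEncoder(get_3DReg_config(), n_channels=2)
x = torch.randn(1, 2, 64, 64, 64)        # moving + fixed volume, concatenated on channels
feats, skips = enc(x)
print(tuple(feats.shape))                # (1, 32, 16, 16, 16)  - 1/4 resolution
for s in skips:
    print(tuple(s.shape))
# (1, 32, 4, 4, 4)      1/16
# (1, 32, 8, 8, 8)      1/8
# (1, 32, 16, 16, 16)   1/4
# (1, 32, 32, 32, 32)   1/2
# (1, 16, 64, 64, 64)   full resolution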
RegistrationHead produces the registration (flow) field and is used in the main framework. Its convolution weights are initialized from Normal(0, 1e-5) and its bias to zero, so the predicted flow starts out near zero.
class RegistrationHead(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        conv3d = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        conv3d.weight = nn.Parameter(Normal(0, 1e-5).sample(conv3d.weight.shape))
        conv3d.bias = nn.Parameter(torch.zeros(conv3d.bias.shape))
        super().__init__(conv3d)
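
A tiny check of that initialization (a sketch; exact magnitudes will vary): with weights drawn from Normal(0, 1e-5), the flow predicted before any training is essentially zero, so the network starts from a near-identity warp.

import torch
from models import RegistrationHead

head = RegistrationHead(in_channels=16, out_channels=3, kernel_size=3)
feat = torch.randn(1, 16, 8, 8, 8)
flow0 = head(feat)
print(flow0.shape)               # torch.Size([1, 3, 8, 8, 8])
print(flow0.abs().max().item())  # very small (roughly 1e-3 or less)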
ViTVNet: the overall network framework. The input x is the moving and fixed images concatenated along the channel dimension; the first channel (the moving image) is warped by the predicted flow.
class ViTVNet(nn.Module):
    def __init__(self, config, img_size=(64, 256, 256), int_steps=7, vis=False, mode='bilinear'):
        super(ViTVNet, self).__init__()
        self.transformer = Transformer(config, img_size, vis)
        self.decoder = DecoderCup(config, img_size)
        self.reg_head = RegistrationHead(
            in_channels=config.decoder_channels[-1],
            out_channels=config['n_dims'],
            kernel_size=3,
        )
        self.spatial_trans = SpatialTransformer(img_size, mode)
        self.config = config
        #self.integrate = VecInt(img_size, int_steps)

    def forward(self, x):
        source = x[:, 0:1, :, :]
        x, attn_weights, features = self.transformer(x)  # (B, n_patch, hidden)
        x = self.decoder(x, features)
        flow = self.reg_head(x)
        #flow = self.integrate(x)
        out = self.spatial_trans(source, flow)
        return out, flow
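
A minimal end-to-end smoke test (a sketch, assuming the repository's models.py and config.py are importable as modules; the 64 × 64 × 64 size is an assumption chosen to keep memory low, and any size divisible by 2**down_factor * patch_size = 32 should work):

import torch
from config import get_3DReg_config    # config.py shown at the top of this post
from models import ViTVNet             # the repo's models.py

config = get_3DReg_config()
model = ViTVNet(config, img_size=(64, 64, 64))

moving = torch.randn(1, 1, 64, 64, 64)
fixed = torch.randn(1, 1, 64, 64, 64)
x = torch.cat([moving, fixed], dim=1)   # (B, 2, D, H, W) input pair

warped, flow = model(x)
print(warped.shape)   # torch.Size([1, 1, 64, 64, 64]) - warped moving image
print(flow.shape)     # torch.Size([1, 3, 64, 64, 64]) - dense displacement field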

                