
Semantic Segmentation Series 7: Attention U-Net (PyTorch Implementation)


Following the earlier posts on U-Net and U-Net++, this article introduces Attention U-Net.

Paper: "Attention U-Net: Learning Where to Look for the Pancreas".


Attention U-Net

Attention U-Net was published in 2018 and is aimed mainly at medical image segmentation; the paper demonstrates the method primarily on pancreas segmentation.

Core Idea of the Paper

The central idea of Attention U-Net is the Attention Gate module: it replaces hard attention with soft attention and integrates attention into U-Net's skip connections and upsampling path, realizing a spatial attention mechanism. The attention mechanism suppresses irrelevant information in the image and highlights the locally important features.

Network Architecture

Figure 1: The Attention U-Net model

The model structure of Attention U-Net closely resembles U-Net; it only adds Attention Gate modules that apply attention to the skip connections and the upsampled features (Figure 2).

Figure 2: The Attention Gate module

In the Attention Gate module, g and x^l are, respectively, the gating signal coming from the next (coarser) level and the feature map carried by the skip connection, as shown in Figure 3.

Figure 3: Inputs to the Attention Gate

Note that after applying W_g and W_x, the two results are added together. At this point, however, the spatial dimensions of g and x^l do not match, so one must either downsample x^l or upsample g. (I lean toward upsampling g: in the original U-Net the decoder has to upsample the lower level anyway, so reusing that upsampled result saves computation; see the sketch below.)
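A minimal sketch of the two alignment options, with hypothetical tensor names and shapes taken from the deepest stage of the default network (illustrative, not from the original post). The channel counts need not match, because the 1×1 convolutions W_g and W_x both map to F_int channels before the addition:

import torch
import torch.nn.functional as F

g  = torch.randn(1, 1024, 14, 14)  # gating signal from the coarser level
xl = torch.randn(1, 512, 28, 28)   # skip-connection feature map

# Option 1 (paper): bring x^l down to g's resolution, e.g. with pooling
# (the paper folds a stride-2 operation into W_x instead)
xl_down = F.avg_pool2d(xl, kernel_size=2)                # (1, 512, 14, 14)

# Option 2 (used in the code below): bring g up to x^l's resolution;
# the U-Net decoder performs this upsampling anyway, so its output is reused
g_up = F.interpolate(g, scale_factor=2, mode='nearest')  # (1, 1024, 28, 28)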

The summed outputs of W_g and W_x pass through a ReLU, a 1×1 convolution, and a Sigmoid, producing a weight map (the paper's 3D version uses a 1×1×1 convolution). Multiplying this weight map with the original input x^l yields the attended x^l. This is the idea behind the Attention Gate.
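In equation form (following the paper's notation, with \sigma the sigmoid and \psi the final 1×1 convolution):

\alpha = \sigma\big(\psi(\mathrm{ReLU}(W_g\, g + W_x\, x^l))\big), \qquad \hat{x}^l = \alpha \cdot x^l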

Another important property of the Attention Gate is that its weights are learned by the network itself. Because soft attention is differentiable, gradients can flow through it: the attention weights are learned end to end through the ordinary forward and backward passes, letting the network emphasize the more important features.

Model Implementation

Attention U-Net code

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init


def init_weights(net, init_type='normal', gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1
                                     or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError(
                    'initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)


class conv_block(nn.Module):
    """Two successive 3x3 Conv-BN-ReLU layers (the standard U-Net block)."""

    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True))

    def forward(self, x):
        return self.conv(x)


class up_conv(nn.Module):
    """2x upsampling (transposed conv or nn.Upsample) followed by Conv-BN-ReLU."""

    def __init__(self, ch_in, ch_out, convTranspose=True):
        super(up_conv, self).__init__()
        if convTranspose:
            self.up = nn.ConvTranspose2d(in_channels=ch_in, out_channels=ch_in,
                                         kernel_size=4, stride=2, padding=1)
        else:
            self.up = nn.Upsample(scale_factor=2)  # nearest-neighbor by default
        self.Conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True))

    def forward(self, x):
        x = self.up(x)
        x = self.Conv(x)
        return x


class single_conv(nn.Module):
    """A single 3x3 Conv-BN-ReLU layer (kept for completeness; unused below)."""

    def __init__(self, ch_in, ch_out):
        super(single_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True))

    def forward(self, x):
        return self.conv(x)


class Attention_block(nn.Module):
    """Attention Gate: weights the skip-connection features x by a mask
    computed from the gating signal g and x themselves."""

    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int))
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int))
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)  # attention coefficients in [0, 1]
        return x * psi


class AttU_Net(nn.Module):
    """
    in_channel: input image channels
    num_classes: number of output classes
    channel_list: channel widths, for adjusting the model size
    checkpoint: whether weights come from a checkpoint; if False, apply normal init
    convTranspose: whether to upsample with transposed convolutions.
        True: use nn.ConvTranspose2d; False: use nn.Upsample
    """

    def __init__(self,
                 in_channel=3,
                 num_classes=1,
                 channel_list=[64, 128, 256, 512, 1024],
                 checkpoint=False,
                 convTranspose=True):
        super(AttU_Net, self).__init__()
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Conv1 = conv_block(ch_in=in_channel, ch_out=channel_list[0])
        self.Conv2 = conv_block(ch_in=channel_list[0], ch_out=channel_list[1])
        self.Conv3 = conv_block(ch_in=channel_list[1], ch_out=channel_list[2])
        self.Conv4 = conv_block(ch_in=channel_list[2], ch_out=channel_list[3])
        self.Conv5 = conv_block(ch_in=channel_list[3], ch_out=channel_list[4])
        self.Up5 = up_conv(ch_in=channel_list[4], ch_out=channel_list[3],
                           convTranspose=convTranspose)
        self.Att5 = Attention_block(F_g=channel_list[3],
                                    F_l=channel_list[3],
                                    F_int=channel_list[2])
        self.Up_conv5 = conv_block(ch_in=channel_list[4], ch_out=channel_list[3])
        self.Up4 = up_conv(ch_in=channel_list[3], ch_out=channel_list[2],
                           convTranspose=convTranspose)
        self.Att4 = Attention_block(F_g=channel_list[2],
                                    F_l=channel_list[2],
                                    F_int=channel_list[1])
        self.Up_conv4 = conv_block(ch_in=channel_list[3], ch_out=channel_list[2])
        self.Up3 = up_conv(ch_in=channel_list[2], ch_out=channel_list[1],
                           convTranspose=convTranspose)
        self.Att3 = Attention_block(F_g=channel_list[1],
                                    F_l=channel_list[1],
                                    F_int=channel_list[0])  # was hard-coded 64; use channel_list[0] so the model scales
        self.Up_conv3 = conv_block(ch_in=channel_list[2], ch_out=channel_list[1])
        self.Up2 = up_conv(ch_in=channel_list[1], ch_out=channel_list[0],
                           convTranspose=convTranspose)
        self.Att2 = Attention_block(F_g=channel_list[0],
                                    F_l=channel_list[0],
                                    F_int=channel_list[0] // 2)
        self.Up_conv2 = conv_block(ch_in=channel_list[1], ch_out=channel_list[0])
        self.Conv_1x1 = nn.Conv2d(channel_list[0], num_classes,
                                  kernel_size=1, stride=1, padding=0)
        if not checkpoint:
            init_weights(self)

    def forward(self, x):
        # encoder
        x1 = self.Conv1(x)
        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)
        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)
        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)
        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)
        # decoder: upsample, gate the skip connection, concatenate, convolve
        d5 = self.Up5(x5)
        x4 = self.Att5(g=d5, x=x4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)
        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=x3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)
        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=x2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)
        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=x1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)
        d1 = self.Conv_1x1(d2)
        return d1
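As a quick sanity check (my addition, not in the original post), a dummy forward pass with the default channel_list and a 224×224 input should produce a per-pixel class map of the same spatial size:

# Sanity check: dummy forward pass (assumes the definitions above)
if __name__ == '__main__':
    net = AttU_Net(in_channel=3, num_classes=33)
    out = net(torch.randn(1, 3, 224, 224))
    print(out.shape)  # expected: torch.Size([1, 33, 224, 224])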

Dataset

As before, we use the CamVid dataset; see the earlier post on building and using the CamVid dataset.

# Imports
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
import os.path as osp
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

torch.manual_seed(17)


# Custom dataset: CamVidDataset
class CamVidDataset(torch.utils.data.Dataset):
    """CamVid Dataset. Reads images, applies augmentation and preprocessing.

    Args:
        images_dir (str): path to the images folder
        masks_dir (str): path to the segmentation masks folder
    """

    def __init__(self, images_dir, masks_dir):
        self.transform = A.Compose([
            A.Resize(224, 224),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.Normalize(),
            ToTensorV2(),
        ])
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

    def __getitem__(self, i):
        # read data
        image = np.array(Image.open(self.images_fps[i]).convert('RGB'))
        mask = np.array(Image.open(self.masks_fps[i]).convert('RGB'))
        transformed = self.transform(image=image, mask=mask)
        # labels are stored as RGB images; keep a single channel as the class map
        return transformed['image'], transformed['mask'][:, :, 0]

    def __len__(self):
        return len(self.ids)


# Dataset paths -- adjust to your own layout
DATA_DIR = r'dataset\camvid'
x_train_dir = os.path.join(DATA_DIR, 'train_images')
y_train_dir = os.path.join(DATA_DIR, 'train_labels')
x_valid_dir = os.path.join(DATA_DIR, 'valid_images')
y_valid_dir = os.path.join(DATA_DIR, 'valid_labels')

train_dataset = CamVidDataset(x_train_dir, y_train_dir)
val_dataset = CamVidDataset(x_valid_dir, y_valid_dir)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True, drop_last=True)
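A quick check of the loader output (my addition; the shapes assume the 224×224 resize and batch size 32 above):

# Inspect one batch (assumes train_loader defined above)
images, masks = next(iter(train_loader))
print(images.shape)  # torch.Size([32, 3, 224, 224]), float32
print(masks.shape)   # torch.Size([32, 224, 224]); cast to long for CrossEntropyLoss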

Model Training

model = AttU_Net(num_classes=33).cuda()
# model.load_state_dict(torch.load(r"checkpoints/Unet_100.pth"), strict=False)

from d2l import torch as d2l
from tqdm import tqdm
import pandas as pd

# loss: multi-class cross-entropy
lossf = nn.CrossEntropyLoss(ignore_index=255)
# optimizer: SGD with a step learning-rate schedule
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1, last_epoch=-1)
# train for 50 epochs
epochs_num = 50


def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, scheduler,
               devices=d2l.try_all_gpus()):
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['train loss', 'train acc', 'test acc'])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    loss_list = []
    train_acc_list = []
    test_acc_list = []
    epochs_list = []
    time_list = []
    for epoch in range(num_epochs):
        # Sum of training loss, sum of training accuracy, no. of examples,
        # no. of predictions
        metric = d2l.Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = d2l.train_batch_ch13(
                net, features, labels.long(), loss, trainer, devices)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[3], None))
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
        scheduler.step()
        print(f"epoch {epoch + 1} --- loss {metric[0] / metric[2]:.3f} "
              f"--- train acc {metric[1] / metric[3]:.3f} "
              f"--- test acc {test_acc:.3f} --- cost time {timer.sum()}")

        # -------- save training statistics --------
        df = pd.DataFrame()
        loss_list.append(metric[0] / metric[2])
        train_acc_list.append(metric[1] / metric[3])
        test_acc_list.append(test_acc)
        epochs_list.append(epoch + 1)
        time_list.append(timer.sum())
        df['epoch'] = epochs_list
        df['loss'] = loss_list
        df['train_acc'] = train_acc_list
        df['test_acc'] = test_acc_list
        df['time'] = time_list
        df.to_excel("savefile/AttentionUnet_camvid1.xlsx")

        # -------- save model checkpoints every 5 epochs --------
        if np.mod(epoch + 1, 5) == 0:
            torch.save(model.state_dict(), f'checkpoints/AttentionUnet_{epoch + 1}.pth')

Start training

train_ch13(model, train_loader, val_loader, lossf, optimizer, epochs_num, scheduler)
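After training, a minimal inference sketch could look like the following (my addition; the checkpoint path is hypothetical, matching the naming pattern saved above):

# Inference sketch: load a checkpoint and predict class ids for one batch
model = AttU_Net(num_classes=33).cuda()
model.load_state_dict(torch.load('checkpoints/AttentionUnet_50.pth'))  # hypothetical path
model.eval()
with torch.no_grad():
    images, masks = next(iter(val_loader))
    logits = model(images.cuda())  # (B, 33, 224, 224)
    preds = logits.argmax(dim=1)   # (B, 224, 224) predicted class ids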

Training Results


One last note.

Recently many readers have asked me for the code, and I sometimes miss messages when I don't check for a while. I have uploaded the code and data files to a network drive for everyone to download.

Link: https://pan.baidu.com/s/1taJlov4VvN-Nwp_xoUbgOA?pwd=yumi
Extraction code: yumi
