
UNet Basic Code (Patched Version): UNet Training and Testing Code


1. Overview

UNet is a classic paper in medical image segmentation, named after its U-shaped architecture. The code in this post makes small fixes to the details of another blogger's implementation and adds a test script.

The original blogger's post, which also provides a pretrained model, is linked below:

UNet的Pytorch实现_Natuski_的博客-CSDN博客_pytorch unet

I. dataset.py

import os
import torch
import torchvision
from PIL import Image
from torch.utils.data import Dataset


class SEGData(Dataset):
    def __init__(self, path1, path2):
        '''
        Read the images according to the annotation files.
        '''
        self.img_path = path1
        self.label_path = path2
        self.images = sorted(os.listdir(self.img_path))
        self.labels = sorted(os.listdir(self.label_path))
        self.totensor = torchvision.transforms.ToTensor()
        # In general, a larger input size trains better but slower.
        self.resizer = torchvision.transforms.Resize((512, 512))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        '''
        The raw images have different sizes, so they must be brought to a common size:
        first pad each image to a square, then resize to 512x512.
        Otherwise the DataLoader will fail when stacking the batch.
        '''
        # Load the image and its mask
        img = Image.open(os.path.join(self.img_path, self.images[i]))
        label = Image.open(os.path.join(self.label_path, self.labels[i]))
        w, h = img.size
        # Pad to a square whose side equals the longer edge, filled with zeros
        slide = max(h, w)
        black_img = torchvision.transforms.ToPILImage()(torch.zeros(3, slide, slide))
        black_label = torchvision.transforms.ToPILImage()(torch.zeros(3, slide, slide))
        black_img.paste(img, (0, 0, int(w), int(h)))  # pasting at the top-left is equivalent to pasting in the center here
        black_label.paste(label, (0, 0, int(w), int(h)))
        # Resize to a uniform 512x512, then convert to tensors
        img = self.resizer(black_img)
        label = self.resizer(black_label)
        img = self.totensor(img)
        label = self.totensor(label)
        return img, label
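A quick way to sanity-check the dataset class is to load one sample and one small batch and print the shapes. The sketch below assumes the same folder layout as train.py ('../data/imgs/' and '../data/masks/') and the same `min_unet` package name; adjust both to your setup.

from torch.utils.data import DataLoader
from min_unet.dataset import SEGData

if __name__ == '__main__':
    data = SEGData('../data/imgs/', '../data/masks/')
    img, label = data[0]
    print(img.shape, label.shape)              # expected: torch.Size([3, 512, 512]) for both
    loader = DataLoader(data, batch_size=2, shuffle=True, drop_last=True)
    batch_img, batch_label = next(iter(loader))
    print(batch_img.shape, batch_label.shape)  # expected: torch.Size([2, 3, 512, 512])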

II. Model.py

from __future__ import print_function, division
import torch
import torch.nn as nn


class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        out_channels = [2 ** (i + 6) for i in range(5)]  # [64, 128, 256, 512, 1024]
        # Downsampling path
        self.d1 = DownsampleLayer(3, out_channels[0])                # 3 -> 64
        self.d2 = DownsampleLayer(out_channels[0], out_channels[1])  # 64 -> 128
        self.d3 = DownsampleLayer(out_channels[1], out_channels[2])  # 128 -> 256
        self.d4 = DownsampleLayer(out_channels[2], out_channels[3])  # 256 -> 512
        # Upsampling path
        self.u1 = UpSampleLayer(out_channels[3], out_channels[3])    # 512 -> 1024 -> 512
        self.u2 = UpSampleLayer(out_channels[4], out_channels[2])    # 1024 -> 512 -> 256
        self.u3 = UpSampleLayer(out_channels[3], out_channels[1])    # 512 -> 256 -> 128
        self.u4 = UpSampleLayer(out_channels[2], out_channels[0])    # 256 -> 128 -> 64
        # Output head
        self.o = nn.Sequential(
            nn.Conv2d(out_channels[1], out_channels[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels[0]),
            nn.ReLU(),
            nn.Conv2d(out_channels[0], out_channels[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels[0]),
            nn.ReLU(),
            nn.Conv2d(out_channels[0], 3, 3, 1, 1),
            nn.Sigmoid(),  # outputs in [0, 1] for BCELoss
        )

    def forward(self, x):
        out_1, out1 = self.d1(x)
        out_2, out2 = self.d2(out1)
        out_3, out3 = self.d3(out2)
        out_4, out4 = self.d4(out3)
        out5 = self.u1(out4, out_4)
        out6 = self.u2(out5, out_3)
        out7 = self.u3(out6, out_2)
        out8 = self.u4(out7, out_1)
        out = self.o(out8)
        return out


# Downsampling block
class DownsampleLayer(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DownsampleLayer, self).__init__()
        self.Conv_BN_ReLU_2 = nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )
        self.downsample = nn.Sequential(
            nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )

    def forward(self, x):
        """
        :param x: input feature map
        :return: out is kept for the skip connection, out_2 is fed into the next (deeper) layer
        """
        out = self.Conv_BN_ReLU_2(x)
        out_2 = self.downsample(out)
        return out, out_2


# Upsampling block
class UpSampleLayer(nn.Module):
    def __init__(self, in_ch, out_ch):
        # 512 -> 1024 -> 512
        # 1024 -> 512 -> 256
        # 512 -> 256 -> 128
        # 256 -> 128 -> 64
        super(UpSampleLayer, self).__init__()
        self.Conv_BN_ReLU_2 = nn.Sequential(
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch * 2, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch * 2),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_ch * 2, out_channels=out_ch * 2, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch * 2),
            nn.ReLU()
        )
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(in_channels=out_ch * 2, out_channels=out_ch, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU()
        )

    def forward(self, x, out):
        '''
        :param x: input from the previous decoder stage
        :param out: skip-connection feature map, concatenated after upsampling
        :return: concatenated feature map
        '''
        x_out = self.Conv_BN_ReLU_2(x)
        x_out = self.upsample(x_out)
        cat_out = torch.cat((x_out, out), dim=1)
        return cat_out
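To confirm the encoder/decoder channel bookkeeping, a dummy forward pass on a random 512x512 tensor should return a tensor of the same spatial size with 3 channels. This is just a quick sketch and runs on CPU; adjust the `min_unet` import to your package layout.

import torch
from min_unet.Model import UNet

if __name__ == '__main__':
    net = UNet()
    x = torch.randn(1, 3, 512, 512)
    with torch.no_grad():
        y = net(x)
    print(y.shape)                          # expected: torch.Size([1, 3, 512, 512])
    print(y.min().item(), y.max().item())   # Sigmoid output, so values lie in [0, 1]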

III. train.py

import os
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from min_unet.Model import UNet
from min_unet.dataset import SEGData

os.environ["CUDA_VISIBLE_DEVICES"] = "0"


def main(path1, path2, EPOCH, Batch):
    net = UNet().cuda()
    optimizer = torch.optim.Adam(net.parameters())
    loss_func = nn.BCELoss()
    data = SEGData(path1, path2)
    dataloader = DataLoader(data, batch_size=Batch, shuffle=True, num_workers=0, drop_last=True)
    summary = SummaryWriter(r'Log')
    os.makedirs(r'SAVE', exist_ok=True)
    # Resume from the previous checkpoint if one exists
    if os.path.exists(r'SAVE/Unet.pt'):
        print('load net')
        net.load_state_dict(torch.load(r'SAVE/Unet.pt'))
        print('load success')
    for epoch in range(EPOCH):
        print('Starting epoch {}'.format(epoch))
        net.train()
        for i, (img, label) in enumerate(dataloader):
            img = img.cuda()
            label = label.cuda()
            img_out = net(img)
            loss = loss_func(img_out, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Use a global step so the curve is not reset every epoch
            summary.add_scalar('bceloss', loss.item(), epoch * len(dataloader) + i)
        torch.save(net.state_dict(), r'SAVE/Unet.pt')
        # Visualize the prediction on one fixed training sample
        img, label = data[2]
        img = torch.unsqueeze(img, dim=0).cuda()
        net.eval()
        with torch.no_grad():
            out = net(img)
        if not os.path.exists(r"Log_imgs"):
            os.mkdir(r"Log_imgs")
        if epoch % 10 == 0:
            save_image(out, 'Log_imgs/segimg_{}.png'.format(epoch), nrow=1, scale_each=True)
        print(f"Epoch {epoch} train_loss={loss.item()}")
        print('Epoch {} finished'.format(epoch))


if __name__ == '__main__':
    path1 = r'../data/imgs/'   # training images
    path2 = r'../data/masks/'  # training masks (labels)
    EPOCH = 11
    Batch = 2
    main(path1, path2, EPOCH, Batch)
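The loop writes the BCE loss to the `Log` directory through tensorboardX and saves checkpoints to `SAVE/Unet.pt`. Two quick checks after (or during) training are to open the loss curve in TensorBoard (assuming the TensorBoard viewer is installed) and to verify that the checkpoint loads back into a fresh model; the snippet below is a minimal sketch of the latter.

# Loss curve (run in a shell):  tensorboard --logdir Log
import torch
from min_unet.Model import UNet

if __name__ == '__main__':
    net = UNet()
    state = torch.load('SAVE/Unet.pt', map_location='cpu')
    net.load_state_dict(state)   # raises if the keys or shapes do not match
    print('checkpoint OK:', len(state), 'tensors')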

IV. test.py

import os
import torch
import torchvision
from torchvision.utils import save_image
from min_unet.Model import UNet
from PIL import Image


def test(input_path):
    net = UNet().cuda()
    weight = r'SAVE/Unet.pt'
    if os.path.exists(weight):
        net.load_state_dict(torch.load(weight))
        print("successful")
    else:
        print("no")
    if not os.path.exists(r"Test_imgs"):
        os.mkdir(r"Test_imgs")
    net.eval()
    for file in os.listdir(input_path):
        f = file.split('.')[0]
        path = os.path.join(input_path, file)
        img = Image.open(path)
        w, h = img.size
        slide = max(h, w)
        # Pad to a square, same preprocessing as in the dataset
        black_img = torchvision.transforms.ToPILImage()(torch.zeros(3, slide, slide))
        black_img.paste(img, (0, 0, int(w), int(h)))  # pasting at the top-left is equivalent to pasting in the center here
        tensor_test = torchvision.transforms.ToTensor()
        image = tensor_test(black_img)
        img = torch.unsqueeze(image, dim=0).cuda()
        with torch.no_grad():
            out = net(img)
        save_image(out, f'Test_imgs/segimg_{f}.png', nrow=1, scale_each=True)


if __name__ == "__main__":
    input_path = r"../data/test_imgs"
    test(input_path)
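The saved test outputs are the raw sigmoid probabilities. If a hard binary mask is wanted instead, a simple post-processing step is to threshold them; the `mask_from_output` helper and the 0.5 threshold below are illustrative assumptions, not part of the original code.

import torch
from torchvision.utils import save_image

def mask_from_output(out: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    # `out` is the (1, 3, H, W) sigmoid output of the network, with values in [0, 1];
    # thresholding (0.5 here, an arbitrary choice) turns it into a hard 0/1 mask.
    return (out > threshold).float()

# Usage inside the test loop (sketch):
#   out = net(img)
#   save_image(mask_from_output(out), f'Test_imgs/mask_{f}.png', nrow=1)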
