赞
踩
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
UNet 是医学图像分割领域的经典论文,因其网络结构形似字母 U 而得名。本文代码在其他博主代码的基础上做了细节上的增改,并增加了测试代码。
下面是该博主的链接,包含了预训练模型:
UNet的Pytorch实现_Natuski_的博客-CSDN博客_pytorch unet
- import os
- import torchvision
- from PIL import Image
- from torch.utils.data import Dataset
- import torch
-
class SEGData(Dataset):
    """Paired image / mask dataset for UNet segmentation training.

    Images and masks are paired by position after sorting each directory's
    file names, so the i-th image (alphabetically) goes with the i-th mask.
    """

    def __init__(self, path1, path2):
        """
        Args:
            path1: directory containing the input images.
            path2: directory containing the label masks.
        """
        self.img_path = path1
        self.label_path = path2
        self.images = sorted(os.listdir(self.img_path))
        self.labels = sorted(os.listdir(self.label_path))
        self.totensor = torchvision.transforms.ToTensor()
        # Generally, larger sizes train better but run slower.
        self.resizer = torchvision.transforms.Resize((512, 512))

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _pad_to_square(pic):
        """Paste `pic` into the top-left corner of a black RGB square of side max(w, h)."""
        w, h = pic.size
        side = max(w, h)
        canvas = Image.new('RGB', (side, side))  # all-zero (black) RGB canvas
        # Top-left vs. centred placement makes no difference to training.
        canvas.paste(pic, (0, 0, w, h))
        return canvas

    def __getitem__(self, i):
        """Return the i-th (image, label) pair as float tensors of shape (3, 512, 512).

        Source images differ in size, and the DataLoader cannot batch them that
        way. So each one is first padded to a square, then resized to 512x512.

        Raises IndexError if the mask directory has fewer files than the image
        directory.
        """
        # os.path.join works whether or not the directory path ends in '/'.
        # The `with` blocks close the underlying files; paste() has already
        # copied the pixel data into the canvas.
        with Image.open(os.path.join(self.img_path, self.images[i])) as src_img:
            img = self._pad_to_square(src_img)
        with Image.open(os.path.join(self.label_path, self.labels[i])) as src_label:
            label = self._pad_to_square(src_label)
        # NOTE(review): Resize uses its default (bilinear) interpolation on the
        # mask as well, which blends class boundaries; consider NEAREST for hard labels.
        img = self.totensor(self.resizer(img))
        label = self.totensor(self.resizer(label))
        return img, label
- from __future__ import print_function, division
-
- import torch
- import torch.nn as nn
-
class UNet(nn.Module):
    """Four-level U-Net encoder/decoder with a sigmoid output head.

    Args:
        in_channels: channels of the input tensor (default 3, as before).
        out_channels: channels of the predicted map (default 3, as before).

    The head ends in ``nn.Sigmoid``, so outputs lie in (0, 1) and can be fed
    directly to ``nn.BCELoss``.
    """

    # Four stride-2 stages: H and W must be divisible by 2**4 for the
    # decoder's doubled maps to line up with the encoder skips in torch.cat.
    _SIZE_MULTIPLE = 16

    def __init__(self, in_channels=3, out_channels=3):
        super(UNet, self).__init__()
        widths = [2 ** (i + 6) for i in range(5)]  # [64, 128, 256, 512, 1024]
        # Encoder (downsampling path)
        self.d1 = DownsampleLayer(in_channels, widths[0])  # in-64
        self.d2 = DownsampleLayer(widths[0], widths[1])  # 64-128
        self.d3 = DownsampleLayer(widths[1], widths[2])  # 128-256
        self.d4 = DownsampleLayer(widths[2], widths[3])  # 256-512
        # Decoder (upsampling path); each stage's input = previous up + skip channels
        self.u1 = UpSampleLayer(widths[3], widths[3])  # 512-1024-512
        self.u2 = UpSampleLayer(widths[4], widths[2])  # 1024-512-256
        self.u3 = UpSampleLayer(widths[3], widths[1])  # 512-256-128
        self.u4 = UpSampleLayer(widths[2], widths[0])  # 256-128-64
        # Output head (layer order is kept so existing checkpoints still load)
        self.o = nn.Sequential(
            nn.Conv2d(widths[1], widths[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(),
            nn.Conv2d(widths[0], widths[0], kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(),
            nn.Conv2d(widths[0], out_channels, 3, 1, 1),
            nn.Sigmoid(),  # pairs with BCELoss
        )

    def forward(self, x):
        """Segment a batch of shape (N, in_channels, H, W); H and W may be any size.

        Inputs whose spatial size is not a multiple of 16 are zero-padded on the
        bottom/right before the network. The output is cropped back to (H, W),
        so it always has shape (N, out_channels, H, W).
        """
        h, w = x.shape[-2:]
        pad_h = (-h) % self._SIZE_MULTIPLE
        pad_w = (-w) % self._SIZE_MULTIPLE
        if pad_h or pad_w:
            # F.pad order is (left, right, top, bottom) for the last two dims.
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h))
        skip1, down1 = self.d1(x)
        skip2, down2 = self.d2(down1)
        skip3, down3 = self.d3(down2)
        skip4, down4 = self.d4(down3)
        up1 = self.u1(down4, skip4)
        up2 = self.u2(up1, skip3)
        up3 = self.u3(up2, skip2)
        up4 = self.u4(up3, skip1)
        out = self.o(up4)
        return out[..., :h, :w]
-
- # 下采样
class DownsampleLayer(nn.Module):
    """Encoder stage of the U-Net.

    Two 3x3 conv-BN-ReLU units (in_ch -> out_ch -> out_ch) are followed by a
    stride-2 3x3 conv-BN-ReLU that shrinks H and W (ceil of half).

    forward returns ``(features, downsampled)``:
      * features    -- the full-resolution map, used as the decoder skip input;
      * downsampled -- the strided map, passed on to the next encoder stage.
    """

    def __init__(self, in_ch, out_ch):
        super(DownsampleLayer, self).__init__()

        def conv_bn_relu(src, dst, stride=1):
            # One 3x3 convolution unit; padding=1 keeps size when stride == 1.
            return [
                nn.Conv2d(in_channels=src, out_channels=dst, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(dst),
                nn.ReLU(),
            ]

        # Attribute names and layer order define the state_dict keys; keep
        # them stable so previously saved checkpoints still load.
        self.Conv_BN_ReLU_2 = nn.Sequential(
            *conv_bn_relu(in_ch, out_ch),
            *conv_bn_relu(out_ch, out_ch),
        )
        self.downsample = nn.Sequential(*conv_bn_relu(out_ch, out_ch, stride=2))

    def forward(self, x):
        features = self.Conv_BN_ReLU_2(x)
        return features, self.downsample(features)
-
- # 上采样
class UpSampleLayer(nn.Module):
    """Decoder stage of the U-Net.

    Two 3x3 conv-BN-ReLU units widen the input to ``2 * out_ch`` channels.
    A stride-2 transposed conv-BN-ReLU then maps back to ``out_ch`` channels
    while doubling H and W. Channel flow used by UNet:
    512-1024-512, 1024-512-256, 512-256-128, 256-128-64.
    """

    def __init__(self, in_ch, out_ch):
        super(UpSampleLayer, self).__init__()
        mid_ch = 2 * out_ch
        # Attribute names and layer order define the state_dict keys.
        self.Conv_BN_ReLU_2 = nn.Sequential(
            nn.Conv2d(in_ch, mid_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(),
            nn.Conv2d(mid_ch, mid_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(),
        )
        # (H-1)*2 - 2 + 3 + 1 == 2H: exact doubling of the spatial size.
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(mid_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def forward(self, x, out):
        """Upsample `x` and concatenate it with the encoder skip tensor `out`.

        :param x: input to the convolution units
        :param out: skip tensor joined with the upsampled map on dim 1 (channels)
        :return: tensor with ``out_ch + out.shape[1]`` channels
        """
        upsampled = self.upsample(self.Conv_BN_ReLU_2(x))
        return torch.cat([upsampled, out], dim=1)
- import torch
- import torch.nn as nn
- from tensorboardX import SummaryWriter
- from torch.utils.data import DataLoader
- import os
- from torchvision.utils import save_image
- from min_unet.Model import UNet
- from min_unet.dataset import SEGData
# Expose only GPU 0 to this process; set at import time, before this
# script makes any CUDA call.
os.environ.update(CUDA_VISIBLE_DEVICES="0")
-
def main(path1, path2, EPOCH, Batch):
    """Train UNet on an image/mask folder pair, checkpointing to SAVE/Unet.pt.

    Args:
        path1: directory of training images.
        path2: directory of the matching mask images.
        EPOCH: number of epochs to run.
        Batch: mini-batch size.

    Training resumes from SAVE/Unet.pt when that file exists; otherwise it
    starts from scratch. The checkpoint is overwritten after every epoch.
    BCE loss is logged to TensorBoard under Log/. On every 10th epoch, the
    prediction for one training sample is written to Log_imgs/.
    """
    ckpt_path = r'SAVE/Unet.pt'
    # Fall back to CPU instead of crashing on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net = UNet().to(device)
    optimizer = torch.optim.Adam(net.parameters())
    loss_func = nn.BCELoss()
    data = SEGData(path1, path2)
    dataloader = DataLoader(data, batch_size=Batch, shuffle=True, num_workers=0, drop_last=True)
    summary = SummaryWriter(r'Log')

    # Only resume when a checkpoint exists; the first run has none.
    if os.path.exists(ckpt_path):
        print('load net')
        net.load_state_dict(torch.load(ckpt_path, map_location=device))
        print('load success')
    else:
        print('no checkpoint at {}, training from scratch'.format(ckpt_path))
    # torch.save and save_image do not create missing directories.
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
    os.makedirs(r"Log_imgs", exist_ok=True)

    global_step = 0
    for epoch in range(EPOCH):
        print('开始第{}轮'.format(epoch))
        net.train()
        loss = None  # stays None if the loader yields no full batch
        for img, label in dataloader:
            img = img.to(device)
            label = label.to(device)
            img_out = net(img)
            loss = loss_func(img_out, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Monotonic step so later epochs don't overwrite earlier points.
            summary.add_scalar('bceloss', loss.item(), global_step)
            global_step += 1

        torch.save(net.state_dict(), ckpt_path)

        if epoch % 10 == 0:
            # Visualise one fixed sample (index 2, or the last one if fewer exist).
            sample, _ = data[min(2, len(data) - 1)]
            net.eval()
            with torch.no_grad():  # inference only: no autograd graph
                out = net(torch.unsqueeze(sample, dim=0).to(device))
            save_image(out, 'Log_imgs/segimg_{}——.png'.format(epoch), nrow=1, scale_each=True)
        if loss is not None:
            print(f"第{epoch}轮train_loss={loss.item()}")
        print('第{}轮结束'.format(epoch))
    summary.close()
-
-
if __name__ == '__main__':
    # Training data: images and masks in sibling folders, paired by sorted file name.
    path1 = r'../data/imgs/'   # training images
    path2 = r'../data/masks/'  # training masks
    EPOCH, Batch = 11, 2
    main(path1=path1, path2=path2, EPOCH=EPOCH, Batch=Batch)
- import torch
- import torchvision
- import os
- from torchvision.utils import save_image
- from min_unet.Model import UNet
- from PIL import Image
-
-
def test(input_path):
    """Segment every image file in `input_path` and save the results to Test_imgs/.

    Each image is pasted into the top-left corner of a black square canvas. The
    canvas side is the longest image edge rounded up to a multiple of 16, so the
    four down/up-sampling stages of UNet line up. The network output is cropped
    back to the original (h, w) and saved as Test_imgs/segimg_<name>.png.

    Weights come from SAVE/Unet.pt when present. Otherwise "no" is printed and
    the randomly initialised network is used.
    """
    size_multiple = 16  # UNet halves the resolution four times (2**4)
    # Fall back to CPU instead of crashing on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = UNet().to(device)
    weight = r'SAVE/Unet.pt'
    if os.path.exists(weight):
        net.load_state_dict(torch.load(weight, map_location=device))
        print("successful")
    else:
        print("no")
    net.eval()

    os.makedirs(r"Test_imgs", exist_ok=True)
    tensor_test = torchvision.transforms.ToTensor()

    for file in os.listdir(input_path):
        path = os.path.join(input_path, file)
        if not os.path.isfile(path):
            continue  # skip sub-directories
        # splitext keeps dotted stems intact ("a.b.png" -> "a.b").
        f = os.path.splitext(file)[0]
        with Image.open(path) as img:
            w, h = img.size
            slide = -(-max(h, w) // size_multiple) * size_multiple  # ceil to multiple
            black_img = Image.new('RGB', (slide, slide))  # black RGB canvas
            black_img.paste(img, (0, 0, w, h))
        image = tensor_test(black_img)
        with torch.no_grad():  # inference only: no autograd graph
            out = net(torch.unsqueeze(image, dim=0).to(device))
        # Drop the black padding so the mask matches the input image size.
        save_image(out[..., :h, :w], f'Test_imgs/segimg_{f}.png', nrow=1, scale_each=True)
-
if __name__ == "__main__":
    input_path = r"../data/test_imgs"  # folder of images to segment
    test(input_path=input_path)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。