当前位置:   article > 正文

PyTorch版UNet实现医学图像分割:基于PyTorch+PyQt5+UNet的医学影像分割可视化

基于PyTorch+PyQt5+UNet的医学影像分割可视化

源码和数据已上传至GitHub,方便下载使用。

GitHub - Z-XQ/unet_pytorch: using PyTorch to implement a UNet network for liver image segmentation. https://github.com/Z-XQ/unet_pytorch

1 model (unet.py)

  1. import torch.nn as nn
  2. import torch
  3. from torch import autograd
  4. class DoubleConv(nn.Module):
  5. def __init__(self, in_ch, out_ch):
  6. super(DoubleConv, self).__init__()
  7. self.conv = nn.Sequential(
  8. nn.Conv2d(in_ch, out_ch, 3, padding=1),
  9. nn.BatchNorm2d(out_ch),
  10. nn.ReLU(inplace=True),
  11. nn.Conv2d(out_ch, out_ch, 3, padding=1),
  12. nn.BatchNorm2d(out_ch),
  13. nn.ReLU(inplace=True)
  14. )
  15. def forward(self, input):
  16. return self.conv(input)
  17. class Unet(nn.Module):
  18. def __init__(self,in_ch,out_ch):
  19. super(Unet, self).__init__()
  20. self.conv1 = DoubleConv(in_ch, 64)
  21. self.pool1 = nn.MaxPool2d(2)
  22. self.conv2 = DoubleConv(64, 128)
  23. self.pool2 = nn.MaxPool2d(2)
  24. self.conv3 = DoubleConv(128, 256)
  25. self.pool3 = nn.MaxPool2d(2)
  26. self.conv4 = DoubleConv(256, 512)
  27. self.pool4 = nn.MaxPool2d(2)
  28. self.conv5 = DoubleConv(512, 1024)
  29. self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
  30. self.conv6 = DoubleConv(1024, 512)
  31. self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
  32. self.conv7 = DoubleConv(512, 256)
  33. self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
  34. self.conv8 = DoubleConv(256, 128)
  35. self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
  36. self.conv9 = DoubleConv(128, 64)
  37. self.conv10 = nn.Conv2d(64,out_ch, 1)
  38. def forward(self,x):
  39. c1=self.conv1(x)
  40. p1=self.pool1(c1)
  41. c2=self.conv2(p1)
  42. p2=self.pool2(c2)
  43. c3=self.conv3(p2)
  44. p3=self.pool3(c3)
  45. c4=self.conv4(p3)
  46. p4=self.pool4(c4)
  47. c5=self.conv5(p4)
  48. up_6= self.up6(c5)
  49. merge6 = torch.cat([up_6, c4], dim=1)
  50. c6=self.conv6(merge6)
  51. up_7=self.up7(c6)
  52. merge7 = torch.cat([up_7, c3], dim=1)
  53. c7=self.conv7(merge7)
  54. up_8=self.up8(c7)
  55. merge8 = torch.cat([up_8, c2], dim=1)
  56. c8=self.conv8(merge8)
  57. up_9=self.up9(c8)
  58. merge9=torch.cat([up_9,c1],dim=1)
  59. c9=self.conv9(merge9)
  60. c10=self.conv10(c9)
  61. out = nn.Sigmoid()(c10)
  62. return out

2 dataset (dataset.py)

  1. import torch.utils.data as data
  2. import PIL.Image as Image
  3. import os
  4. def make_dataset(root):
  5. imgs=[]
  6. n=len(os.listdir(root))//2
  7. for i in range(n):
  8. img=os.path.join(root,"%03d.png"%i)
  9. mask=os.path.join(root,"%03d_mask.png"%i)
  10. imgs.append((img,mask))
  11. return imgs
  12. class LiverDataset(data.Dataset):
  13. def __init__(self, root, transform=None, target_transform=None):
  14. imgs = make_dataset(root)
  15. self.imgs = imgs
  16. self.transform = transform
  17. self.target_transform = target_transform
  18. def __getitem__(self, index):
  19. x_path, y_path = self.imgs[index]
  20. img_x = Image.open(x_path)
  21. img_y = Image.open(y_path)
  22. if self.transform is not None:
  23. img_x = self.transform(img_x)
  24. if self.target_transform is not None:
  25. img_y = self.target_transform(img_y)
  26. return img_x, img_y
  27. def __len__(self):
  28. return len(self.imgs)

3 main (training and testing script)

  1. import numpy as np
  2. import torch
  3. import argparse
  4. from torch.utils.data import DataLoader
  5. from torch import autograd, optim
  6. from torchvision.transforms import transforms
  7. from unet import Unet
  8. from dataset import LiverDataset
  9. # 是否使用cuda
  10. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  11. # 把多个步骤整合到一起, channel=(channel-mean)/std, 因为是分别对三个通道处理
  12. x_transforms = transforms.Compose([
  13. transforms.ToTensor(), # -> [0,1]
  14. transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # ->[-1,1]
  15. ])
  16. # mask只需要转换为tensor
  17. y_transforms = transforms.ToTensor()
  18. # 参数解析器,用来解析从终端读取的命令
  19. parse = argparse.ArgumentParser()
  20. def train_model(model, criterion, optimizer, dataload, num_epochs=20):
  21. for epoch in range(num_epochs):
  22. print('Epoch {}/{}'.format(epoch, num_epochs - 1))
  23. print('-' * 10)
  24. dt_size = len(dataload.dataset)
  25. epoch_loss = 0
  26. step = 0
  27. for x, y in dataload:
  28. step += 1
  29. inputs = x.to(device)
  30. labels = y.to(device)
  31. # zero the parameter gradients
  32. optimizer.zero_grad()
  33. # forward
  34. outputs = model(inputs)
  35. loss = criterion(outputs, labels)
  36. loss.backward()
  37. optimizer.step()
  38. epoch_loss += loss.item()
  39. print("%d/%d,train_loss:%0.3f" % (step, (dt_size - 1) // dataload.batch_size + 1, loss.item()))
  40. print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
  41. torch.save(model.state_dict(), 'weights_%d.pth' % epoch)
  42. return model
  43. # 训练模型
  44. def train():
  45. model = Unet(3, 1).to(device)
  46. batch_size = args.batch_size
  47. criterion = torch.nn.BCELoss()
  48. optimizer = optim.Adam(model.parameters())
  49. liver_dataset = LiverDataset("data/train",transform=x_transforms,target_transform=y_transforms)
  50. dataloaders = DataLoader(liver_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
  51. train_model(model, criterion, optimizer, dataloaders)
  52. # 显示模型的输出结果
  53. def test():
  54. model = Unet(3, 1)
  55. model.load_state_dict(torch.load(args.ckp,map_location='cpu'))
  56. liver_dataset = LiverDataset("data/val", transform=x_transforms,target_transform=y_transforms)
  57. dataloaders = DataLoader(liver_dataset, batch_size=1)
  58. model.eval()
  59. import matplotlib.pyplot as plt
  60. plt.ion()
  61. with torch.no_grad():
  62. for x, _ in dataloaders:
  63. y=model(x)
  64. img_y=torch.squeeze(y).numpy()
  65. plt.imshow(img_y)
  66. plt.pause(0.01)
  67. plt.show()
  68. parse = argparse.ArgumentParser()
  69. # parse.add_argument("action", type=str, help="train or test")
  70. parse.add_argument("--batch_size", type=int, default=1)
  71. parse.add_argument("--ckp", type=str, help="the path of model weight file")
  72. args = parse.parse_args()
  73. # train
  74. #train()
  75. # test()
  76. args.ckp = "weights_19.pth"
  77. test()

测试结果:

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/87919
推荐阅读
相关标签
  

闽ICP备14008679号