赞
踩
源码和数据已上传至 GitHub，方便下载使用。
- import torch.nn as nn
- import torch
- from torch import autograd
-
class DoubleConv(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    Each convolution uses padding=1, so H and W are preserved; the channel
    count goes ``in_ch -> out_ch`` in the first stage and stays ``out_ch``.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer order are what saved state_dict keys
        # (conv.0.weight, conv.1.running_mean, ...) refer to.
        self.conv = nn.Sequential(*layers)

    def forward(self, input):
        """Run both conv stages on a 4-D ``input`` batch."""
        return self.conv(input)
-
-
class Unet(nn.Module):
    """U-Net encoder/decoder with a sigmoid output.

    The encoder applies four ``DoubleConv`` + 2x max-pool stages
    (64 -> 128 -> 256 -> 512 channels) and a 1024-channel bottleneck. The
    decoder applies four 2x transposed-conv upsampling stages, each
    concatenated with the matching encoder feature map. A final 1x1 conv and
    a sigmoid give values in (0, 1).

    The output has the same H and W as the input. Because of the four
    poolings, H and W must each be at least 16.

    Args:
        in_ch: channels of the input image.
        out_ch: channels of the output map.
    """

    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        # Attribute names and creation order are unchanged so existing
        # checkpoints load and parameter initialisation is identical.
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _match_size(up, skip):
        """Zero-pad ``up`` (N, C, H, W) so its H and W equal those of ``skip``.

        Max-pooling floors odd sizes, so 2x upsampling can come back one pixel
        smaller than the skip connection; ``torch.cat`` would then fail.
        Returns ``up`` unchanged when the sizes already agree.
        """
        diff_y = skip.size(2) - up.size(2)
        diff_x = skip.size(3) - up.size(3)
        if diff_x == 0 and diff_y == 0:
            return up
        # pad lists the last dimension first: (left, right, top, bottom).
        return nn.functional.pad(
            up,
            [diff_x // 2, diff_x - diff_x // 2, diff_y // 2, diff_y - diff_y // 2],
        )

    def _up_merge(self, up, conv, x, skip):
        """Upsample ``x`` with ``up``, concat with ``skip`` on channels, apply ``conv``."""
        upsampled = self._match_size(up(x), skip)
        return conv(torch.cat([upsampled, skip], dim=1))

    def forward(self, x):
        """Return the sigmoid-activated segmentation map for batch ``x``."""
        # Encoder: keep each stage's output for the skip connections.
        c1 = self.conv1(x)
        c2 = self.conv2(self.pool1(c1))
        c3 = self.conv3(self.pool2(c2))
        c4 = self.conv4(self.pool3(c3))
        c5 = self.conv5(self.pool4(c4))
        # Decoder: upsample, merge with the same-resolution encoder output.
        c6 = self._up_merge(self.up6, self.conv6, c5, c4)
        c7 = self._up_merge(self.up7, self.conv7, c6, c3)
        c8 = self._up_merge(self.up8, self.conv8, c7, c2)
        c9 = self._up_merge(self.up9, self.conv9, c8, c1)
        # torch.sigmoid is the same op as nn.Sigmoid() without building a
        # new module on every call.
        return torch.sigmoid(self.conv10(c9))
- import torch.utils.data as data
- import PIL.Image as Image
- import os
-
-
def make_dataset(root):
    """List the ``(image_path, mask_path)`` pairs present in ``root``.

    An image ``<stem>.png`` is paired with the mask ``<stem>_mask.png``, where
    ``<stem>`` is all digits (e.g. ``000.png`` / ``000_mask.png``). Only pairs
    whose two files both exist are returned, in ascending numeric order of the
    stem. Any other file in the directory is ignored, so stray files no longer
    produce paths to nonexistent images.

    Args:
        root: directory to scan.

    Returns:
        list of ``(img_path, mask_path)`` tuples, each joined onto ``root``.
    """
    suffix = "_mask.png"
    names = set(os.listdir(root))
    stems = []
    for name in names:
        if not name.endswith(suffix):
            continue
        stem = name[: -len(suffix)]
        # isdecimal() guarantees int(stem) below succeeds.
        if stem.isdecimal() and stem + ".png" in names:
            stems.append(stem)
    # Numeric order, with the raw stem as a deterministic tie-breaker.
    stems.sort(key=lambda s: (int(s), s))
    return [
        (os.path.join(root, stem + ".png"), os.path.join(root, stem + suffix))
        for stem in stems
    ]
-
-
class LiverDataset(data.Dataset):
    """Paired image/mask dataset built from ``make_dataset(root)``.

    Each item is ``(image, mask)``: both files are opened with PIL, then passed
    through ``transform`` / ``target_transform`` when those are provided.

    Args:
        root: directory holding the image and mask files.
        transform: optional callable applied to the input image.
        target_transform: optional callable applied to the mask.
    """

    def __init__(self, root, transform=None, target_transform=None):
        self.imgs = make_dataset(root)
        self.transform = transform
        self.target_transform = target_transform

    @staticmethod
    def _load_image(path):
        """Open ``path`` with PIL, read its pixels, and release the file handle.

        ``Image.open`` is lazy and keeps the file open until the pixels are
        read. Without a transform that forces the read, each sample would hold
        a descriptor open, which can exhaust the limit over a long run.
        """
        with Image.open(path) as img:
            img.load()
        return img

    def __getitem__(self, index):
        x_path, y_path = self.imgs[index]
        img_x = self._load_image(x_path)
        img_y = self._load_image(y_path)
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        return img_x, img_y

    def __len__(self):
        return len(self.imgs)
- import numpy as np
- import torch
- import argparse
- from torch.utils.data import DataLoader
- from torch import autograd, optim
- from torchvision.transforms import transforms
- from unet import Unet
- from dataset import LiverDataset
-
-
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Input images: convert to a float tensor, then normalise each of the three
# channels separately with channel = (channel - mean) / std.
x_transforms = transforms.Compose([
    transforms.ToTensor(),  # -> [0, 1] for 8-bit images
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # -> [-1, 1]
])

# Masks only need converting to a tensor (no normalisation).
y_transforms = transforms.ToTensor()
-
-
def train_model(model, criterion, optimizer, dataload, num_epochs=20,
                checkpoint_pattern='weights_%d.pth', device=None):
    """Train ``model`` for ``num_epochs`` epochs, checkpointing after each one.

    Args:
        model: network to optimise; switched to training mode first.
        criterion: loss callable used as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        dataload: DataLoader yielding ``(inputs, labels)`` batches.
        num_epochs: number of passes over ``dataload``.
        checkpoint_pattern: ``%``-format path taking the epoch index; the
            model's ``state_dict`` is saved there after every epoch.
            ``None`` disables checkpointing.
        device: where each batch is moved. Defaults to the device of the
            model's first parameter, or the CPU if it has none.

    Returns:
        The same ``model`` object, trained in place.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    # len(DataLoader) honours drop_last and custom batch samplers, unlike a
    # manual ceil(len(dataset) / batch_size).
    num_batches = len(dataload)
    # The caller may have left the model in eval mode; BatchNorm must update
    # its running statistics during training.
    model.train()
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        epoch_loss = 0
        for step, (x, y) in enumerate(dataload, start=1):
            inputs = x.to(device)
            labels = y.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            print("%d/%d,train_loss:%0.3f" % (step, num_batches, loss.item()))
        # Sum of per-batch losses over the epoch.
        print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
        if checkpoint_pattern is not None:
            torch.save(model.state_dict(), checkpoint_pattern % epoch)
    return model
-
-
# Train the model
def train():
    """Train a fresh ``Unet(3, 1)`` on ``data/train`` with BCE loss and Adam.

    The batch size comes from the module-level ``args.batch_size``; the model
    is moved to the module-level ``device``. Batches are shuffled and loaded
    by 4 worker processes.
    """
    net = Unet(3, 1).to(device)
    loss_fn = torch.nn.BCELoss()
    adam = optim.Adam(net.parameters())
    loader = DataLoader(
        LiverDataset("data/train", transform=x_transforms, target_transform=y_transforms),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
    )
    train_model(net, loss_fn, adam, loader)
-
-
# Display the model's output
def test():
    """Show the model's prediction for every image in ``data/val``.

    Loads weights from ``args.ckp`` onto the CPU, runs ``Unet(3, 1)`` in eval
    mode without gradients (batch size 1), and displays each squeezed output
    map in an interactive matplotlib window.
    """
    model = Unet(3, 1)
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(args.ckp, map_location='cpu'))
    liver_dataset = LiverDataset("data/val", transform=x_transforms, target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=1)
    model.eval()
    import matplotlib.pyplot as plt
    plt.ion()
    with torch.no_grad():
        for x, _ in dataloaders:
            y = model(x)
            img_y = torch.squeeze(y).numpy()
            # Clear first: otherwise every imshow stacks another image artist
            # on the same axes and each redraw repaints all previous ones.
            plt.clf()
            plt.imshow(img_y)
            plt.pause(0.01)
            plt.show()
-
-
if __name__ == "__main__":
    # Run only when executed as a script. DataLoader workers (num_workers > 0)
    # re-import the main module on spawn-based platforms such as Windows, and
    # an unguarded entry point breaks their start-up; importing this module
    # also no longer launches a run.
    parse = argparse.ArgumentParser()
    parse.add_argument("action", type=str, nargs="?", default="test",
                       choices=["train", "test"], help="train or test (default: test)")
    parse.add_argument("--batch_size", type=int, default=1)
    # Default is the last checkpoint written by train_model's default 20
    # epochs (epochs 0..19); an explicit --ckp is now respected.
    parse.add_argument("--ckp", type=str, default="weights_19.pth",
                       help="the path of model weight file")
    # ``args`` stays a module-level global because train() and test() read it.
    args = parse.parse_args()
    if args.action == "train":
        train()
    else:
        test()
测试结果:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。