class Conv3x3(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    Every convolution uses kernel_size=3, stride=1, padding=1, so the spatial
    size (H, W) is preserved; only the channel count changes.

    Args:
        inputCh: number of channels of the input tensor.
        outputCh: number of channels produced by both stages.
    """

    def __init__(self, inputCh, outputCh):
        super(Conv3x3, self).__init__()
        # First stage: inputCh -> outputCh.
        self.conv1 = nn.Sequential(
            nn.Conv2d(inputCh, outputCh, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(outputCh),  # normalise the outputCh feature maps
            nn.ReLU(inplace=True),
        )
        # Second stage: outputCh -> outputCh (consumes the previous stage's output).
        self.conv2 = nn.Sequential(
            nn.Conv2d(outputCh, outputCh, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(outputCh),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both stages: (N, inputCh, H, W) -> (N, outputCh, H, W)."""
        x = self.conv1(x)
        x = self.conv2(x)
        return x
class TransConv(nn.Module):
    """Learned 2x upsampling: ConvTranspose2d(k=3, s=2, p=1, output_padding=1) -> BatchNorm2d -> ReLU.

    With these settings the output height/width is exactly twice the input's.
    """

    def __init__(self, inputCh, outputCh):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(inputCh, outputCh, kernel_size=3, stride=2,
                               padding=1, output_padding=1, dilation=1),
            nn.BatchNorm2d(outputCh),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class UpSam(nn.Module):
    """Decoder stage: upsample, concatenate encoder skip features, then fuse.

    Args:
        inputCh: channels of the incoming (coarser) feature map.
        outputCh: channels after upsampling and after the final Conv3x3.
    """

    def __init__(self, inputCh, outputCh):
        super().__init__()
        # Transposed convolution ("deconvolution") doubling H and W.
        self.upconv = TransConv(inputCh, outputCh)
        # The concatenation yields 2 * outputCh channels, presuming
        # convfeatures carries outputCh channels; reuses Conv3x3 from above.
        self.conv = Conv3x3(2 * outputCh, outputCh)

    def forward(self, x, convfeatures):
        upsampled = self.upconv(x)
        merged = torch.cat([upsampled, convfeatures], dim=1)  # stack along channels
        return self.conv(merged)
class UNet(nn.Module):
    """2D U-Net built from Conv3x3 encoder stages and UpSam decoder stages.

    Channel widths are ``(2 ** i) * size`` for i = 0..4 (64, 128, 256, 512,
    1024 with the default ``size=64``). The encoder halves H and W four times
    with MaxPool2d(2) (which floors odd sizes) and the decoder doubles them
    back, so H and W should be divisible by 16 for the skip connections to
    line up in ``torch.cat``.

    Args:
        inputCh: input channels (the author intends 4 MRI modalities).
        outputCh: output channels, one score map per class (the author
            intends 5 label values, 0-4).
        size: channel width of the first stage.
        degree: optional alias for ``size`` (the keyword used by ``UNet2D``,
            e.g. ``UNet(4, 5, degree=64)``); when given it overrides ``size``.
    """

    def __init__(self, inputCh=4, outputCh=5, size=64, degree=None):
        super(UNet, self).__init__()
        if degree is not None:
            size = degree
        channels = [(2 ** i) * size for i in range(5)]

        # Encoder: every stage after the first starts with a 2x2 max-pool.
        self.downLayer1 = Conv3x3(inputCh, channels[0])
        self.downLayer2 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        Conv3x3(channels[0], channels[1]))
        self.downLayer3 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        Conv3x3(channels[1], channels[2]))
        self.downLayer4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        Conv3x3(channels[2], channels[3]))
        self.bottomLayer = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                         Conv3x3(channels[3], channels[4]))

        # Decoder: upsample and merge with the matching encoder output.
        self.upLayer1 = UpSam(channels[4], channels[3])
        self.upLayer2 = UpSam(channels[3], channels[2])
        self.upLayer3 = UpSam(channels[2], channels[1])
        self.upLayer4 = UpSam(channels[1], channels[0])

        # Per-pixel class scores.
        self.outLayer = nn.Conv2d(channels[0], outputCh, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Map (N, inputCh, H, W) to (N, outputCh, H, W); C below is channels[0]."""
        # Encoder path
        x1 = self.downLayer1(x)   # (N, C,   H,    W)
        x2 = self.downLayer2(x1)  # (N, 2C,  H/2,  W/2)
        x3 = self.downLayer3(x2)  # (N, 4C,  H/4,  W/4)
        x4 = self.downLayer4(x3)  # (N, 8C,  H/8,  W/8)
        # Bottleneck
        x5 = self.bottomLayer(x4)  # (N, 16C, H/16, W/16)
        # Decoder path with skip connections
        x = self.upLayer1(x5, x4)  # (N, 8C,  H/8,  W/8)
        x = self.upLayer2(x, x3)   # (N, 4C,  H/4,  W/4)
        x = self.upLayer3(x, x2)   # (N, 2C,  H/2,  W/2)
        x = self.upLayer4(x, x1)   # (N, C,   H,    W)
        x = self.outLayer(x)       # (N, outputCh, H, W)
        return x
if __name__ == "__main__":
    # Smoke test: push a random batch through the network and report the output shape.
    net = UNet(4, 5, size=64)  # UNet's channel-width keyword is `size`
    batch_size = 4
    # Random dummy input: 4 channels, 192x192 (divisible by 16 as the U-Net requires).
    a = torch.randn(batch_size, 4, 192, 192)
    with torch.no_grad():  # forward pass only; no autograd graph needed
        b = net(a)
    print(b.shape)
import sys
import math
import torch
import torch.nn as nn


class Conv3x3(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages; H and W are preserved."""

    def __init__(self, in_ch, out_ch):
        super(Conv3x3, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class TransConv(nn.Module):
    """Transposed 3x3 conv with stride 2 (exactly doubles H and W) -> BatchNorm2d -> ReLU."""

    def __init__(self, inputCh, outputCh):
        super(TransConv, self).__init__()
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(inputCh, outputCh, kernel_size=3, stride=2, padding=1,
                               output_padding=1, dilation=1),
            nn.BatchNorm2d(outputCh),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class UpSam(nn.Module):
    """Decoder stage: upsample, concatenate the encoder skip features, then Conv3x3."""

    def __init__(self, inputCh, outputCh):
        super(UpSam, self).__init__()
        self.upconv = TransConv(inputCh, outputCh)  # transposed convolution ("deconvolution")
        self.conv = Conv3x3(2 * outputCh, outputCh)  # fuses upsampled + skip channels

    def forward(self, x, convfeatures):
        x = self.upconv(x)
        x = torch.cat([x, convfeatures], dim=1)
        x = self.conv(x)
        return x


class UNet2D(nn.Module):
    """2D U-Net with channel widths (2 ** i) * degree for i = 0..4.

    Input H and W should be divisible by 16 so the skip connections align.

    Args:
        in_ch: input channels.
        out_ch: output channels (class score maps).
        degree: channel width of the first stage.
    """

    def __init__(self, in_ch=4, out_ch=2, degree=64):
        super(UNet2D, self).__init__()
        chs = [(2 ** i) * degree for i in range(5)]

        # Encoder
        self.downLayer1 = Conv3x3(in_ch, chs[0])
        self.downLayer2 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        Conv3x3(chs[0], chs[1]))
        self.downLayer3 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        Conv3x3(chs[1], chs[2]))
        self.downLayer4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                        Conv3x3(chs[2], chs[3]))
        self.bottomLayer = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
                                         Conv3x3(chs[3], chs[4]))

        # Decoder (the class defined above is `UpSam`)
        self.upLayer1 = UpSam(chs[4], chs[3])
        self.upLayer2 = UpSam(chs[3], chs[2])
        self.upLayer3 = UpSam(chs[2], chs[1])
        self.upLayer4 = UpSam(chs[1], chs[0])

        self.outLayer = nn.Conv2d(chs[0], out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Map (N, in_ch, H, W) to (N, out_ch, H, W)."""
        x1 = self.downLayer1(x)
        x2 = self.downLayer2(x1)
        x3 = self.downLayer3(x2)
        x4 = self.downLayer4(x3)
        x5 = self.bottomLayer(x4)
        x = self.upLayer1(x5, x4)
        x = self.upLayer2(x, x3)
        x = self.upLayer3(x, x2)
        x = self.upLayer4(x, x1)
        x = self.outLayer(x)
        return x


if __name__ == "__main__":
    net = UNet2D(4, 5, degree=64)
    batch_size = 4
    a = torch.randn(batch_size, 4, 192, 192)
    with torch.no_grad():  # smoke test only; no autograd graph needed
        b = net(a)
    print(b.shape)
写好网络之后,接下来编写数据读取部分。这里参考并修改了他人的读取代码,略显繁琐。首先需要获取 HGG/LGG 目录下全部子文件夹的名字(具体做法可以搜索“获取文件夹下的子文件夹名字”);然后执行 `pip install SimpleITK` 安装 SimpleITK(如果下载较慢,可以换用清华镜像源)。用 SimpleITK 读取 nii 文件比较方便。
def load_nii_as_array(img_name):
    """Read an image file (e.g. a .nii / .nii.gz volume) with SimpleITK.

    Args:
        img_name: path of the image file.

    Returns:
        The voxel data as a NumPy ndarray (SimpleITK orders the axes z, y, x;
        the author reports shape (155, 240, 240) for BraTS volumes).
    """
    return sitk.GetArrayFromImage(sitk.ReadImage(img_name))
def norm_vol(data):
    """Min-max normalise the non-zero voxels of ``data`` to [0, 1].

    Zero voxels are treated as background: they are excluded from the min/max
    and stay 0. The statistics are taken over the whole array at once, so for a
    stacked multi-channel volume all channels share one min/max.

    Args:
        data: NumPy array of any shape.

    Returns:
        A float64 copy of ``data``. It is returned without scaling when there
        are no non-zero voxels or when all non-zero voxels have the same value.
    """
    # np.float (an alias of float64) was removed in NumPy 1.24.
    data = data.astype(np.float64)
    index = data.nonzero()  # indices of the foreground (non-zero) voxels
    values = data[index]
    if values.size == 0:
        # All-background volume: np.max/np.min would raise on an empty selection.
        return data
    smax = values.max()
    smin = values.min()
    if smax == smin:
        # Constant foreground: nothing to rescale (and avoids dividing by zero).
        return data
    data[index] = (values - smin) / (smax - smin)
    return data
class DataLoader19(Dataset):
    """In-memory BraTS dataset built from a list of subject directories.

    Every subject listed in ``conf`` (one directory per line, relative to
    ``data_dir``) is loaded eagerly in ``__init__``.

    * ``train=True``: one item per slice along axis 1 of the volume,
      ``(image, label, name)`` with ``image = volume[:, i]`` (channels, H, W),
      ``label = label[i]`` (H, W) and ``name = '<subject>=slice<i>'``.
      Slices whose image is entirely zero are skipped.
    * ``train=False``: one item per subject, ``(volume, label, subject)`` with
      the volume transposed from (channels, slices, H, W) to
      (slices, channels, H, W).

    ``self.freq`` is the fraction of label voxels per class value 0-4 over all
    loaded subjects; ``self.weight`` holds median-frequency class weights.

    NOTE(review): ``get_subject`` only loads files whose name contains 'flair',
    so the channel dimension is 1, while the UNet models in this file default
    to 4 input channels — confirm which is intended.
    """

    def __init__(self, data_dir, conf='../config/train19.conf', train=True):
        # Close the config file promptly and ignore blank lines (a trailing
        # empty line would otherwise turn data_dir itself into a "subject").
        with open(conf) as config_file:
            img_lists = [os.path.join(data_dir, line.strip())
                         for line in config_file if line.strip()]
        self.data = []
        self.freq = np.zeros(5)
        # Kept for backward compatibility; the empty-slice test below no longer needs it.
        self.zero_vol = np.zeros((4, 240, 240))
        for count, subject in enumerate(img_lists, start=1):
            if count % 10 == 0:
                print('loading imageSets %d' %count)
            volume, label = DataLoader19.get_subject(subject)  # (C, slices, H, W), (slices, H, W)
            volume = norm_vol(volume)
            self.freq += self.get_freq(label)
            if train:
                for i in range(volume.shape[1]):
                    name = subject + '=slice' + str(i)
                    if not volume[:, i, :, :].any():  # when training, ignore all-zero slices
                        continue
                    self.data.append([volume[:, i, :, :], label[i, :, :], name])
            else:
                volume = np.transpose(volume, (1, 0, 2, 3))
                self.data.append([volume, label, subject])
        self.freq = self.freq / np.sum(self.freq)
        median_freq = np.median(self.freq)
        # A class that never occurs would get an infinite weight (division by
        # zero); give it weight 0 instead.
        self.weight = np.divide(median_freq, self.freq,
                                out=np.zeros_like(self.freq), where=self.freq > 0)
        print('******** Finish loading data ********')
        print('******** Weight for all classes ********')
        print(self.weight)
        if train:
            print('******** Total number of 2D images is ' + str(len(self.data)) + ' **********')
        else:
            print('******** Total number of subject is ' + str(len(self.data)) + ' **********')

    def __len__(self):
        """Number of items (slices when training, subjects otherwise).

        Required by torch.utils.data.DataLoader's default sampler.
        """
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, label, name)`` as float tensors plus the item name."""
        [image, label, name] = self.data[index]
        image = torch.from_numpy(image).float()
        label = torch.from_numpy(label).float()
        return image, label, name

    @staticmethod
    def get_subject(subject):
        """Load one subject directory.

        Files whose name contains 'flair' are stacked as image channels; the
        file whose name contains 'seg' is the label.

        Returns:
            ``(volume, label)`` where ``volume`` stacks the loaded image arrays
            along a new first axis.

        Raises:
            FileNotFoundError: if no image or no segmentation file is found.
        """
        # os.listdir order is arbitrary; sort for a deterministic channel order.
        files = sorted(os.listdir(subject))
        multi_mode_dir = []
        label_dir = ""
        for f in files:
            if 'flair' in f:  # image data
                multi_mode_dir.append(f)
            elif 'seg' in f:  # label
                label_dir = f
        if not multi_mode_dir:
            raise FileNotFoundError('no image file found in ' + subject)
        if not label_dir:
            raise FileNotFoundError('no segmentation file found in ' + subject)
        multi_mode_imgs = [load_nii_as_array(os.path.join(subject, mod_dir))
                           for mod_dir in multi_mode_dir]
        label = load_nii_as_array(os.path.join(subject, label_dir))
        volume = np.asarray(multi_mode_imgs)
        return volume, label

    def get_freq(self, label):
        """Count the voxels equal to each class value 0-4 (as a float64 array)."""
        return np.array([np.count_nonzero(label == i) for i in range(5)], dtype=np.float64)


if __name__ == "__main__":
    data_dir = 'MICCAI_BraTS_2018_Data_Training/'
    conf = 'MICCAI_BraTS_2018_Data_Training/config/valid18.config'
    # test for training data
    brats19 = DataLoader19(data_dir=data_dir, conf=conf, train=True)
    image2d, label2d, im_name = brats19[5]
    print('image size ......')
    print(image2d.shape)
    print('label size ......')
    print(label2d.shape)
    print(im_name)
    # test for whole-subject (evaluation) data
    test = DataLoader19(data_dir=data_dir, conf=conf, train=False)
    image_volume, label_volume, subject = test[0]
    print(image_volume.shape)
    print(label_volume.shape)
    print(subject)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。