赞
踩
最近在读U-Net论文时,网上看到从零构建网络模型的代码。代码足够间接,而且结构比较完整,因此记录一下学习结果。
本文重点在于如何代码的实现,对于U-Net论文中的细节未涉略,关于论文的讨论可移步。
学习的资源链接在文章末尾。
首先对于模型有一个简单的认识:
对于U-net模型的构建,主要的在于卷积层和转置卷积(下采样和上采样)的实现,以及如何实现镜像对应部分的连接。请各位读者理解U-net模型,并且牢记每一步的通道数。
按照工业上或者竞赛上常见的解决问题的步骤,主要包括数据集的获取、模型的构建、模型的训练(损失函数的选择、模型的优化)、训练结果的验证。因此接下来将从这几方面对代码进行解读。
数据集网址:Carvana Image Masking Challenge | Kaggle
百度网链接:https://pan.baidu.com/s/1bhKCyd226__fDhWbYLGPJQ 提取码:4t3y
其中需要读者根据自己的需求先训练集中分出部分的数据用作验证集。
- import os
- from PIL import Image
- from torch.utils.data import Dataset
- import numpy as np
- class CarvanaDataset(Dataset):
- def __init__(self, image_dir, mask_dir, transform=None):
- self.image_dir = image_dir
- self.mask_dir = mask_dir
- self.transform = transform
- self.images = os.listdir(image_dir)
-
- def __len__(self):
- return len(self.images)
-
- def __getitem__(self, idx):
- image_path = os.path.join(self.image_dir, self.images[idx])
- mask_path = os.path.join(self.mask_dir, self.images[idx].replace('.jpg', '_mask.gif'))
- image = np.array(Image.open(image_path).convert('RGB'))
- mask = np.array(Image.open(mask_path).convert('L'), dtype=np.float32)
- mask[mask == 255.0] = 1.0
-
- if self.transform is not None:
- augmentations = self.transform(image=image, mask=mask)
- image = augmentations['image']
- mask = augmentations['mask']
-
- return image, mask
为了后续操作比较方便,直接继承Dataset,然后返回image和对应的mask。
os.listdir(path) 返回指定路径下的文件(文件夹),在上面代码中,返回整个训练集图片对应的列表。
os.path.join() 该操作直接获得每一张图片对应的储存路径
image.open().convert(), 该函数将图片按照指定的模式转变图片,例如RGB图像,或者灰度图像。(具体的官方释义我还没找到,如果有官网的解释,请赐教)
mask[mask==255.0] = 1.0 方便后续的sigmoid()函数的计算?(存疑)
首先观察U-net模型的构建,在pool层之前,总会有有两次卷积,将原图片的通道数增加。
因此首先建立类DoubleConv。
- import torch
- import torch.nn as nn
- import torchvision.transforms.functional as TF
-
- class DoubleConv(nn.Module):
- def __init__(self,in_channels,out_channels):
- super(DoubleConv,self).__init__()
- self.conv=nn.Sequential(
- nn.Conv2d(in_channels,out_channels,3,1,1,bias=False),
- nn.BatchNorm2d(out_channels),
- nn.ReLU(inplace=True),
- nn.Conv2d(out_channels,out_channels,3,1,1,bias=False),
- nn.BatchNorm2d(out_channels),
- nn.ReLU(inplace=True),
- )
-
- def forward(self,x):
- return self.conv(x)
下一步,观察U-net模型,由于具有对称性显得格外的优雅,而且每一步的处理显得很有规律,正是因为有这样的规律,因此我们在写代码的时候可以不那么繁琐,重复的卷积-池化-卷积-池化。
- class UNET(nn.Module):
- def __init__(
- self, in_channels=3, out_channels=1, features=[64,128,256,512]
- ):
- super(UNET,self).__init__()
- self.ups = nn.ModuleList()
- self.downs = nn.ModuleList()
- self.pool = nn.MaxPool2d(kernel_size=2,stride=2)
-
- for feature in features:
- self.downs.append(DoubleConv(in_channels,feature))
- in_channels = feature
-
- for feature in reversed(features):
- self.ups.append(
- nn.ConvTranspose2d(feature*2,feature,kernel_size=2,stride=2)
- )
- self.ups.append(DoubleConv(feature*2, feature))
-
- self.bottleneck = DoubleConv(features[-1], features[-1]*2)
- self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
上述代码中,将整个u-Net模型分为卷积层(下采样)、转置卷积层(上采样)、池化层、瓶颈层以及最后的卷积层。
对于下采样阶段,使用ModelList(),然后确定每一次卷积的输入、输出通道,然后使用循环结构。
- features=[64,128,256,512]
-
- self.downs = nn.ModuleList()
-
- for feature in features:
- self.downs.append(DoubleConv(in_channels,feature))
- in_channels = feature
- def forward(self,x):
- skip_connections = []
-
- for down in self.downs:
- x = down(x)
- skip_connections.append(x)
- x = self.pool(x)
-
- x = self.bottleneck(x)
- skip_connections = skip_connections[::-1]
-
- for idx in range(0, len(self.ups) ,2):
- x = self.ups[idx](x)
- skip_connection = skip_connections[idx//2]
-
- if x.shape != skip_connection.shape:
- x=TF.resize(x,size=skip_connection.shape[2:])
-
- concat_skip = torch.cat((skip_connection,x), dim=1)
- x = self.ups[idx+1](concat_skip)
-
- return self.final_conv(x)
在前向传播的时候,需要注意的是,U-net每一层都有一个skip—connnection
skip-connections=[] ,将经过卷积的x保存到列表中,在上采样的时候进行连接
skip_connections=skip_connections[::-1], 保存顺序与使用顺序相反,因此需要反序
concat_skip=torch.cat((skip_connection, x),dim=1) 对两者进行连接
我觉得我们在写代码的时候,为什么代码结构看的比较凌乱,主要因为我们没有能够将每一个功能、操作整合起来,下面给一个具体的例子。
- def save_checkpoint(state,filename='my_checkpoint.pth.tar'):
- print('=>Saving checkpoint')
- torch.save(state, filename)
将训练模型保存起来的函数
torch.save() 官网torch.save()注释
- def load_checkpoint(checkpoint, model):
- print('=>Loading checkpoint')
- model.load_state_dict(checkpoint['state_dict'])
加载模型,可以将上次未训练完的模型再次进行训练
- def get_loader(
- train_dir,
- train_maskdir,
- val_dir,
- val_maskdir,
- batch_size,
- train_transform,
- val_transform,
- num_workers=1,
- pin_momory=True,
- ):
- train_ds = CarvanaDataset(
- image_dir=train_dir,
- mask_dir=train_maskdir,
- transform=train_transform
- )
-
- train_loader = DataLoader(
- train_ds,
- batch_size=batch_size,
- num_workers=num_workers,
- pin_memory=pin_momory,
- shuffle=True
- )
-
- val_ds = CarvanaDataset(
- image_dir=val_dir,
- mask_dir=val_maskdir,
- transform=val_transform
- )
-
- val_loader = DataLoader(
- val_ds,
- batch_size=batch_size,
- num_workers=num_workers,
- pin_memory=pin_momory,
- shuffle=False
- )
-
- return train_loader,val_loader
加载数据的常用函数,其中CarvanaDataset 自定义,也可以直接使用Dataset()
DataLoader()函数中参数:
pin_memory (bool, optional) – If True
, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn
returns a batch that is a custom type, see the example below.
超参数的确定:
- LEARNING_RATE = 1e-4
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- IMAGE_HEIGHT = 160
- IMAGE_WIDTH = 240
- BATCH_SIZE = 16
- NUM_EPOCHS = 3
- NUM_WORKER = 2
- PIN_MEMORY = True
- LOAD_MODEL = False
- TRAIN_IMG_DIR = "data/train/"
- TRAIN_MASK_DIR = "data/train_masks/"
- VAL_IMG_DIR = "data/val/"
- VAL_MASK_DIR = "data/val_masks/"
训练函数train_fn()
- def train_fn(loader, model, optimizer, loss_fn, scaler):
- loop = tqdm(loader)
-
- for batch_idx, (data, targets) in enumerate(loop):
- data = data.to(device=DEVICE)
- targets = targets.float().unsqueeze(1).to(device=DEVICE)
-
- #forward
- '''混合精度训练'''
- with torch.cuda.amp.autocast():
- preds = model(data)
- loss = loss_fn(preds,targets)
-
- #backward
- optimizer.zero_grad()
- scaler.scale(loss).backward()
- scaler.step(optimizer)
- scaler.update()
-
- #update tqdm loop
- loop.set_postfix(loss=loss.item())
loop = tqdm(loader) 简单理解为快速、可扩展的python进度条
loop.set_postfix() 设置进度条的输出内容
具体关于tqdm的使用,本人未深入研究
在上面的代码中,需要注意的是前向传播和反向传播时的代码,与常见的代码不同,因为该代码引入了混合精度训练,具体的请自行查阅。
源代码:https://github.com/aladdinpersson/Machine-Learning-Collection
视频资源:https://www.youtube.com/watch?v=IHq1t7NxS8k
哔哩哔哩:【CV教程】从零开始:Pytorch图像分割教程与U-NET_哔哩哔哩_bilibili
视频中可以清楚的了解到如何从零开始构建一个模型,如何运行,在使用的过程中一些附加的功能如何实现,对于我这种小白来讲,还是大有裨益的。
而且up主的GitHub网页上还有许多其他的项目,基本上都是从零开始的,以后可以试试自己去一步步的来参加kaggle竞赛。
如果还有深度学习刚入门的小伙伴,也可以一起交流学习。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。