当前位置:   article > 正文

【代码复现】ResUNet++进行语义分割(含图像切片预处理)

resunet++


参考资料

文章地址:https://arxiv.org/pdf/1911.07067.pdf
代码地址:https://github.com/DebeshJha/ResUNetPlusPlus

1. preprocess.py

前言:
可能由于显卡内存不够的原因,导致尺寸很大的图片进行训练时,导致GPU显存不够的情况,一个简单的方法:对图片进行切片操作。对图片进行切片处理:将尺寸很大的图片裁剪成尺寸固定且大小适中的图片,方便后续进行训练。

该部分代码的功能:将训练集和测试集分别进行224×224裁剪,存储到新的文件夹中

1.1. 参数声明

1.1.1. 执行命令的形参

python preprocess.py --config "configs/default.yaml" --train ./DataSet_png512/train --valid ./DataSet_png512/test
  • 1

--train:训练集路径
--vaild:验证集路径
--config:配置文件,具体内容如下:

train: "./DataPreprocess/train" # 训练数据文件夹路径
valid: "./DataPreprocess/test"  # 验证数据文件夹路径
log: "logs"                     # tensorboard的events存储路径: ./logs
logging_step: 100
validation_interval: 20 # Save and valid have same interval
checkpoints: "checkpoints"

batch_size: 4
lr: 0.001
RESNET_PLUS_PLUS: True  # 使用ResUNet++模型;若该值为False则使用ResUNet模型
IMAGE_SIZE: 512         # 1500
CROP_SIZE: 224          # 224
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

1.1.2. 代码中的参数声明

if __name__ == '__main__':
    # 这部分在上面已经赋值过
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, required=True,
                        help="yaml file for configuration")
    parser.add_argument('-t', '--train', type=str, required=True,
                        help="Training Folder.")
    parser.add_argument('-v', '--valid', type=str, required=True,
                        help="Validation Folder")
    args = parser.parse_args()


    # 将--config参数赋值给hp,由hp来调用其中的参数
    hp = HParam(args.config)
    with open(args.config, 'r') as f:
        hp_str = ''.join(f.readlines())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

参数赋值:

    # 数据集路径
    train_dir = args.train   # './DataSet_png512/train'
    valid_dir = args.valid   # './DataSet_png512/test'
    
    #start_points这个函数具体作用下面介绍 
    X_points = start_points(hp.IMAGE_SIZE, hp.CROP_SIZE, 0)  # [0,192,288]
    Y_points = start_points(hp.IMAGE_SIZE, hp.CROP_SIZE, 0)  # [0,192,288]

    ## 训练集图片和掩码的文件夹路径
    train_img_dir = os.path.join(train_dir, "images") # './DataSet_png512/train/images'
    train_mask_dir = os.path.join(train_dir, "masks") # './DataSet_png512/train/masks'

    # 经过preprocess处理后图片的保存路径(如果事先没创建文件夹现在创建)
    train_img_crop_dir = os.path.join(hp.train, "images_crop") # './DataPreprocess/train/images_crop'
    os.makedirs(train_img_crop_dir, exist_ok=True)
    train_mask_crop_dir = os.path.join(hp.train, "masks_crop") # './DataPreprocess/train/masks_crop'
    os.makedirs(train_mask_crop_dir, exist_ok=True)

    # 遍历所有图片,然后打印图片数量
    img_files = glob.glob(os.path.join(train_img_dir, '**', '*.png'), recursive=True)
    mask_files = glob.glob(os.path.join(train_mask_dir, '**', '*.png'), recursive=True)
    print("Length of image :", len(img_files))
    print("Length of mask :", len(mask_files))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

     上面代码中出现的start_points()函数,得到X_points和Y_points都为0,192,288,这三个点是图片裁剪的起始点,裁剪图片大小为224×224,具体实现方法见下面crop_image_mask()函数。

def crop_image_mask(image_dir, mask_dir, mask_path, X_points, Y_points, split_height=224, split_width=224):
    img_id = os.path.basename(mask_path).split(".")[0]
    mask = load_image(mask_path)
    img = load_image(mask_path.replace("masks", "images"))

    count = 0
    num_skipped = 1
    for i in Y_points:
        for j in X_points:
        
            # img[0:224,0:244],[0:224,192:416],[0:224,288:512]
            # img[192:416,0:244],[192:416,192:416],[192:416,288:512]
            # img[288:512,0:244],[288:512,192:416],[288:512,288:512]
            new_image = img[i:i + split_height, j:j + split_width] 
            new_mask = mask[i:i + split_height, j:j + split_width]
            new_mask[new_mask > 100] = 255
            new_mask[new_mask <= 100] = 0

            # 如果白色像素点/黑色像素点<0.01,就将图片设置成全黑。
            # 这种方式不适合用作小目标分割(眼底渗出物分割不适用)
            if np.any(new_mask):
                num_black_pixels, num_white_pixels = np.unique(new_mask, return_counts=True)[1]

                if num_white_pixels / num_black_pixels < 0.01:
                    num_skipped += 1
                    continue

            mask_ = Image.fromarray(new_mask.astype(np.uint8))
            mask_.save("{}/{}_{}.jpg".format(mask_dir, img_id, count), "JPEG")
            im = Image.fromarray(new_image.astype(np.uint8))
            im.save("{}/{}_{}.jpg".format(image_dir, img_id, count), "JPEG")
            count = count + 1
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

     到这里图片预处理便完成了,将训练集和测试集分别进行224×224裁剪,存储到新的文件夹中,后面train.py就是在这个新的文件夹中读取数据的。

2. train.py

2.1. 参数声明

python train.py --name "default" --config "configs/default.yaml"
  • 1

--name:1.保存权重的文件夹名称;2.保存events的文件夹名称
--config:配置文件,具体内容如下:

train: "./DataPreprocess/train" # 训练数据文件夹路径
valid: "./DataPreprocess/test"  # 验证数据文件夹路径
log: "logs"                     # tensorboard的events存储路径: ./logs
logging_step: 100
validation_interval: 20 # Save and valid have same interval
checkpoints: "checkpoints"

batch_size: 4
lr: 0.001
RESNET_PLUS_PLUS: True  # 使用ResUNet++模型;若该值为False则使用ResUNet模型
IMAGE_SIZE: 512         # 1500
CROP_SIZE: 224          # 224
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

参数声明完成后,跳到main主函数

2.2. main函数(不包括训练阶段)

2.2.1 参数说明

main(hp, num_epochs=args.epochs, resume=args.resume, name=args.name)
  • 1

hp:就是configs/default.yaml里面的参数
num_epochs:默认为 75
resume:默认空字符串‘ ’
name:字符串:‘default’

def main(hp, num_epochs, resume, name):
checkpoint_dir:'checkpoint/default'  # 保存的权重路径
writer = MyWriter("{}/{}".format(hp.log, name)) # logdir: 'log/default'
model = ResUnetPlusPlus(3).cuda()
criterion = metrics.BCEDiceLoss()  # 采用binary cross entropy 和 dice 损失
optimizer = torch.optim.Adam(model.parameters(), lr=hp.lr)  # Adam优化器
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

2.2.2. 读取数据部分

mass_dataset_train = dataloader.ImageDataset(      # 这里没有False表示对验证集进行处理
        hp, transform=transforms.Compose([dataloader.ToTensorTarget()]))
mass_dataset_val = dataloader.ImageDataset(         # 这里False表示对验证集进行处理
        hp, False, transform=transforms.Compose([dataloader.ToTensorTarget()]))
  • 1
  • 2
  • 3
  • 4

     调用dataloader.ImageDataset类,要注意的是这里读取的是经过数据预处理的图片,对应文件夹名称为DataPreprocess。

class ImageDataset(Dataset):
该代码实现功能:读取图片和掩码,将其放入sample,如果self.transform==Ture,则对sample进行self.transform。
最后返回值为sample。
  • 1
  • 2
  • 3

2.2.3. 创建 loaders

    train_dataloader = DataLoader(
        mass_dataset_train, batch_size=hp.batch_size, num_workers=2, shuffle=True)
        
    val_dataloader = DataLoader(
        mass_dataset_val, batch_size=1, num_workers=2, shuffle=False)
  • 1
  • 2
  • 3
  • 4
  • 5

2.3. 训练阶段

    step = 0
    for epoch in range(start_epoch, num_epochs):
    lr_scheduler.step()   # 更新学习率
    
    # 记录准确度和损失,后面会调用update来更新值。
    train_acc = metrics.MetricTracker()   
    train_loss = metrics.MetricTracker()   
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

载入数据,模型进行训练:

        loader = tqdm(train_dataloader, desc="training")
        for idx, data in enumerate(loader):

            # 获取输入图像和掩码
            inputs = data["sat_img"].cuda()
            labels = data["map_img"].cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            # 前向传播
            outputs = model(inputs)
            loss = criterion(outputs, labels)  # 采用binary cross entropy 和 dice 损失,前面声明过

            # 后向传播
            loss.backward()
            optimizer.step()

            # 更新acc和loss值
            train_acc.update(metrics.dice_coeff(outputs, labels), outputs.size(0))
            train_loss.update(loss.data.item(), outputs.size(0))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

紧接着,tensorboard可视化训练阶段

            # tensorboard logging:其中,hp.logging_step=100
            if step % hp.logging_step == 0:   #每100step更新一次
                writer.log_training(train_loss.avg, train_acc.avg, step)

                # 每隔100step,进度条打印一次(tqdm)
                loader.set_description(
                    "Training Loss: {:.4f} Acc: {:.4f}".format(
                        train_loss.avg, train_acc.avg )   )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

2.4. validation阶段

这部分中的validation()函数是核心:

            # hp.validation=20
            if step % hp.validation_interval == 0:
                
                # 进入validation()函数,验证阶段
                valid_metrics = validation(
                    val_dataloader, model, criterion, writer, step )

                # checkpoint_dir:'checkpoint/default/default_checkpoint_xx.pt'  # 保存的权重文件路径
                save_path = os.path.join(
                    checkpoint_dir, "%s_checkpoint_%04d.pt" % (name, step)  )
                    
                # get最小损失,后面进行保存
                best_loss = min(valid_metrics["valid_loss"], best_loss)

                # 保存参数,保存在上面save_path中
                torch.save(
                    {
                        "step": step,
                        "epoch": epoch,
                        "arch": "ResUnet++",
                        "state_dict": model.state_dict(),
                        "best_loss": best_loss,
                        "optimizer": optimizer.state_dict(),
                    },
                    save_path, )
                print("Saved checkpoint to: %s" % save_path)

            step += 1
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28

validation()的实现代码:

def validation(valid_loader, model, criterion, logger, step):

    # 同上
    valid_acc = metrics.MetricTracker()
    valid_loss = metrics.MetricTracker()

    # 进入验证模式
    model.eval()

    # Iterate over data.
    for idx, data in enumerate(tqdm(valid_loader, desc="validation")):

        # get the inputs and wrap in Variable
        inputs = data["sat_img"].cuda()
        labels = data["map_img"].cuda()

        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # 更新acc和loss参数
        valid_acc.update(metrics.dice_coeff(outputs, labels), outputs.size(0))
        valid_loss.update(loss.data.item(), outputs.size(0))
        
        if idx == 0:
            logger.log_images(inputs.cpu(), labels.cpu(), outputs.cpu(), step)
            
    # 将验证阶段的acc和loss写入tensorboard
    logger.log_validation(valid_loss.avg, valid_acc.avg, step)

    print("Validation Loss: {:.4f} Acc: {:.4f}".format(valid_loss.avg, valid_acc.avg))

    # 
    model.train()
    return {"valid_loss": valid_loss.avg, "valid_acc": valid_acc.avg}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

这部分代码倒数第二行model.train()的作用是:
     在验证阶段结束后调用 model.train() 是为了将模型切换回训练模式
在深度学习中,有些层(例如 Dropout、Batch Normalization 等)在训练模式和评估模式下具有不同的行为。在训练模式下,这些层会执行特定的操作来增强模型的泛化能力和稳定性。而在评估模式下,这些层的行为会发生变化,以保持一致性和可重复性。
     总之,加上 model.train() 是为了确保模型在验证阶段结束后切换回训练模式,以保持训练和评估的行为一致。

3. 其他相关代码

3.1. model.py

ResUNet++模型框架:
在这里插入图片描述

具体实现如下:

3.1.1. res_unet_plus.py

import torch.nn as nn
import torch
from core.modules import (
    ResidualConv,
    ASPP,
    AttentionBlock,
    Upsample_,
    Squeeze_Excite_Block,
)


class ResUnetPlusPlus(nn.Module):
    def __init__(self, channel, filters=[32, 64, 128, 256, 512]):
        super(ResUnetPlusPlus, self).__init__()

        self.input_layer = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(filters[0]),
            nn.ReLU(),
            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
        )

        self.squeeze_excite1 = Squeeze_Excite_Block(filters[0])

        self.residual_conv1 = ResidualConv(filters[0], filters[1], 2, 1)

        self.squeeze_excite2 = Squeeze_Excite_Block(filters[1])

        self.residual_conv2 = ResidualConv(filters[1], filters[2], 2, 1)

        self.squeeze_excite3 = Squeeze_Excite_Block(filters[2])

        self.residual_conv3 = ResidualConv(filters[2], filters[3], 2, 1)

        self.aspp_bridge = ASPP(filters[3], filters[4])

        self.attn1 = AttentionBlock(filters[2], filters[4], filters[4])
        self.upsample1 = Upsample_(2)
        self.up_residual_conv1 = ResidualConv(filters[4] + filters[2], filters[3], 1, 1)

        self.attn2 = AttentionBlock(filters[1], filters[3], filters[3])
        self.upsample2 = Upsample_(2)
        self.up_residual_conv2 = ResidualConv(filters[3] + filters[1], filters[2], 1, 1)

        self.attn3 = AttentionBlock(filters[0], filters[2], filters[2])
        self.upsample3 = Upsample_(2)
        self.up_residual_conv3 = ResidualConv(filters[2] + filters[0], filters[1], 1, 1)

        self.aspp_out = ASPP(filters[1], filters[0])

        self.output_layer = nn.Sequential(nn.Conv2d(filters[0], 1, 1), nn.Sigmoid())

    def forward(self, x):
        x1 = self.input_layer(x) + self.input_skip(x)

        x2 = self.squeeze_excite1(x1)
        x2 = self.residual_conv1(x2)

        x3 = self.squeeze_excite2(x2)
        x3 = self.residual_conv2(x3)

        x4 = self.squeeze_excite3(x3)
        x4 = self.residual_conv3(x4)

        x5 = self.aspp_bridge(x4)

        x6 = self.attn1(x3, x5)
        x6 = self.upsample1(x6)
        x6 = torch.cat([x6, x3], dim=1)
        x6 = self.up_residual_conv1(x6)

        x7 = self.attn2(x2, x6)
        x7 = self.upsample2(x7)
        x7 = torch.cat([x7, x2], dim=1)
        x7 = self.up_residual_conv2(x7)

        x8 = self.attn3(x1, x7)
        x8 = self.upsample3(x8)
        x8 = torch.cat([x8, x1], dim=1)
        x8 = self.up_residual_conv3(x8)

        x9 = self.aspp_out(x8)
        out = self.output_layer(x9)

        return out

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
3.1.1.1. Squeeze and Excitation Units

该模块的输入是上一层的通道数,一个可设置参数reduction

class Squeeze_Excite_Block(nn.Module):
    def __init__(self, channel, reduction=16):
        super(Squeeze_Excite_Block, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

这个模块有什么作用呢?文献中是这样解释的:

squeeze and excitation block与residual block堆叠在一起,以增加对不同数据集的有效泛化并提高网络的性能。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/677800
推荐阅读
相关标签
  

闽ICP备14008679号