
[Code Reproduction] Semantic segmentation with MSNet and M2SNet



Abstract: segmenting hard exudates in fundus images with the MSNet and M2SNet models.

1. Parameter settings

The parameters to modify in the train.py script are as follows:

train_path = 'path/to/Datasets/IDRiD/TrainDataset'  # training-set path
savepath = './saved_model/msnet'  # where weights are saved
mode = 'train'  # enable training mode
batch = 8
lr = 0.05
momen = 0.9
decay = 5e-4
epoch = 50

The paper notes that each dataset uses its own number of training epochs:
polyp segmentation: 50
COVID-19 lung infection: 200
breast tumor segmentation: 100
OCT layer segmentation: 100

Note: for fundus exudate segmentation I use 50 epochs for now.

Besides that, the utils/config.py script also needs its parameters updated; add two lines for your own dataset:

import os

IDRiD_root_test = 'path/to/SmallSeg/MSNet-M2SNet/Datasets/IDRiD/TestDataset' # test-set path
IDRiD = os.path.join(IDRiD_root_test)  # single-argument os.path.join simply returns the path unchanged

2. train.py

2.1. Parameter assignment

Entering train.py, the parameters are first assigned as follows:

if __name__=='__main__':
    train(dataset_medical, MSNet, LossNet)

def train(Dataset, Network, Network1):
    ## dataset
    train_path = '/path/to/Datasets/IDRiD/TrainDataset'

    # assign parameters; Config additionally adds mean and std internally
    cfg = Dataset.Config(datapath=train_path, savepath='./saved_model/msnet', mode='train', batch=8, lr=0.05, momen=0.9, decay=5e-4, epoch=50)

In the code above, dataset_medical.Config() first turns the keyword arguments into self attributes, then adds two more: self.mean and self.std, i.e. the per-channel mean and standard deviation; their concrete values should be set according to your dataset.
dataset_medical.Config() is as follows:

class Config(object):
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.mean   = np.array([[[124.55, 118.90, 102.94]]])
        self.std    = np.array([[[ 56.77,  55.97,  57.50]]])
        print('\nParameters...')
        for k, v in self.kwargs.items():
            print('%-10s: %s'%(k, v))
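
The concrete mean/std values depend on your dataset. A quick way to estimate them over the training images is sketched below (a hypothetical helper, assuming the IDRiD layout used above; values come out on the 0-255 RGB scale, matching Config's defaults):

import os
import cv2
import numpy as np

# Hypothetical helper: per-channel mean/std over the training images (RGB, 0-255 scale).
img_dir = 'path/to/Datasets/IDRiD/TrainDataset/image'
pixels = []
for f in os.listdir(img_dir):
    img = cv2.imread(os.path.join(img_dir, f))[:, :, ::-1]  # BGR -> RGB, matching the data loader
    pixels.append(img.reshape(-1, 3).astype(np.float64))
pixels = np.concatenate(pixels, axis=0)
print('mean:', pixels.mean(axis=0))  # plug into Config's self.mean
print('std :', pixels.std(axis=0))   # plug into Config's self.std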


2.2. Data processing

The relevant lines in train.py:

    # data-processing pipeline
    data = Dataset.Data(cfg)  # i.e. dataset_medical.Data(), which holds the preprocessing logic
    
    loader = DataLoader(data, collate_fn=data.collate, batch_size=cfg.batch, shuffle=True, num_workers=8)
    if not os.path.exists(cfg.savepath):
        os.makedirs(cfg.savepath)

dataset_medical.Data() contains the data-processing logic:

class Data(Dataset):
    def __init__(self, cfg):
        self.cfg = cfg

        # preprocessing transforms
        self.normalize  = Normalize(mean=cfg.mean, std=cfg.std) # mean and std come from dataset_medical.Config()
        self.randomcrop = RandomCrop()
        self.randomflip = RandomFlip()
        self.randomrotate = RandomRotate()
        self.resize     = Resize(352, 352)
        self.totensor   = ToTensor()

        self.root = cfg.datapath  # root directory of the training set

        img_path = os.path.join(self.root, 'image')  # path to the training image directory
        gt_path = os.path.join(self.root, 'mask')    # path to the training mask directory
        self.samples = [os.path.splitext(f)[0]       # collect all sample names for later lookup
                    for f in os.listdir(gt_path) if f.endswith('.png')]
                    
    def __getitem__(self, idx):
        name  = self.samples[idx]
        image = cv2.imread(self.root+'/image/'+name+'.jpg')[:,:,::-1].astype(np.float32) # [:,:,::-1]: BGR -> RGB
        mask  = cv2.imread(self.root+'/mask/' +name+'.png', 0).astype(np.float32)  # 0: read as grayscale


        shape = mask.shape

        if self.cfg.mode=='train':
            image, mask = self.normalize(image, mask)
            image, mask = self.resize(image, mask)
            # image, mask = self.randomcrop(image, mask)
            image, mask = self.randomflip(image, mask)
            image, mask = self.randomrotate(image, mask)
            return image, mask
        else:
            image, mask = self.normalize(image, mask)
            image, mask = self.resize(image, mask)
            image, mask = self.totensor(image, mask)
            return image, mask, shape, name

    def collate(self, batch):
        size = [224, 256, 288, 320, 352][np.random.randint(0, 5)]
        image, mask = [list(item) for item in zip(*batch)]
        for i in range(len(batch)):
            image[i] = cv2.resize(image[i], dsize=(size, size), interpolation=cv2.INTER_LINEAR)
            mask[i]  = cv2.resize(mask[i],  dsize=(size, size), interpolation=cv2.INTER_LINEAR)
        image = torch.from_numpy(np.stack(image, axis=0)).permute(0, 3, 1, 2)
        mask = torch.from_numpy(np.stack(mask, axis=0)).unsqueeze(1)
        return image, mask

    def __len__(self):
        return len(self.samples)
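
Note that collate resizes each whole batch to one randomly chosen scale (224-352), so batch shapes vary across iterations; this is the multi-scale training trick. A quick smoke test (hypothetical, reusing the cfg from section 2.1):

from torch.utils.data import DataLoader

# Hypothetical check of the multi-scale collate:
data = Data(cfg)
loader = DataLoader(data, collate_fn=data.collate, batch_size=4, shuffle=True)
image, mask = next(iter(loader))
print(image.shape, mask.shape)  # e.g. torch.Size([4, 3, 288, 288]) torch.Size([4, 1, 288, 288])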

2.3. Defining the loss network and the model network

Model definition:

    net = Network()    # model network: MSNet or M2SNet (MSNet by default)
    net1 = Network1()  # loss network (LossNet): VGG-16
    net.train(True)    # training mode
    net1.eval()        # eval mode
    net.cuda()
    net1.cuda()

2.3.1. Loss network: LossNet

The loss network is built on VGG-16:

class LossNet(torch.nn.Module):
    def __init__(self, resize=True):
        super(LossNet, self).__init__()
        
        # take the first four blocks of the VGG-16 feature extractor (layers 0-22)
        blocks = []
        blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
        print(blocks)
        
        # freeze every parameter in the four blocks so they take no part in gradient computation
        # (note: iterating bl.parameters() is required; iterating the block itself only yields sub-modules)
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False

        # combine the blocks
        self.blocks = torch.nn.ModuleList(blocks)
        
        # preprocessing helpers, used later in the loss computation
        self.transform = torch.nn.functional.interpolate
        self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))
        self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1))
        self.resize = resize

    # forward pass
    def forward(self, input, target):

        # grayscale masks are single-channel, so replicate them to three channels
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
         
        # normalize the images with mean and std
        input = (input-self.mean) / self.std
        target = (target-self.mean) / self.std

        # self.resize defaults to True
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(512, 512), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(512, 512), align_corners=False)
            
        loss = 0.0
        x = input
        y = target

        # loss computation: each of the four blocks contributes one term, capturing multi-scale information (see the figure in the paper)
        for block in self.blocks:
            x = block(x)
            y = block(y)
            loss += torch.nn.functional.mse_loss(x, y)  # accumulate the four block losses
        return loss
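
A minimal smoke test for LossNet (hypothetical usage; assumes a CUDA device and the class defined above):

import torch

# Hypothetical usage: perceptual loss between a sigmoid-activated prediction and a mask.
net1 = LossNet().cuda().eval()
pred = torch.rand(2, 1, 352, 352).cuda()  # stand-in for torch.sigmoid(net(image))
mask = torch.rand(2, 1, 352, 352).cuda()  # stand-in for a ground-truth mask in [0, 1]
with torch.no_grad():
    loss = net1(pred, mask)
print(loss.item())  # one scalar, summed over the four VGG blocks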

Visualization of the VGG-based loss computation:
[Figure: VGG loss computation]

Then, train.py handles the loss network as follows:

    # whether to use cuDNN acceleration
    torch.backends.cudnn.enabled = False  # res2net does not support cudnn in py1.7
    
    # exclude LossNet's parameters from gradient computation
    for param in net1.parameters():
        param.requires_grad = False

Parameter handling for the model network:

    # split the model's parameters into head and base groups, which are later passed to the optimizer
    base, head = [], []
    for name, param in net.named_parameters():
        if 'bkbone.conv1' in name or 'bkbone.bn1' in name:
            print(name)
        elif 'bkbone' in name:
            base.append(param)
        else:
            head.append(param)
    # after this step, base ends up empty and head holds the parameters

Optimizer setup:
[{'params':base}, {'params':head}] is a list of two dicts; each dict has a params key whose value is a group of model parameters to optimize. The first dict is the base part of the model, the second is the head part.
Because the parameters were already split into groups above, each group can be given its own optimization settings here.

    optimizer = torch.optim.SGD([{'params':base}, {'params':head}], lr=cfg.lr, momentum=cfg.momen, weight_decay=cfg.decay, nesterov=True)  # base is empty here, so in effect only head is optimized, with SGD plus Nesterov momentum

2.4. Training loop

Besides LossNet, the training objective also sums the weighted binary cross-entropy loss and the weighted IoU loss and takes their mean as the final structure loss; the function is named structure_loss(pred, mask).

    # global_step is a global counter of training steps; it increases by 1 per step (backward + optimizer.step)
    global_step    = 0
    
    for epoch in range(cfg.epoch):

        # two learning rates, one for each of the base and head parameter groups above
        optimizer.param_groups[0]['lr'] = (1-abs((epoch+1)/(cfg.epoch+1)*2-1))*cfg.lr*0.1
        # optimizer.param_groups[0]: the first parameter group in the optimizer, i.e. base.
        # (1-abs((epoch+1)/(cfg.epoch+1)*2-1)): a learning-rate factor derived from the ratio of the current epoch to the total; it varies between 0 and 1 as a triangle that peaks at mid-training.
        # cfg.lr: the base learning rate from the config.
        # *0.1: a scaling factor that shrinks this group's learning rate by an order of magnitude.
        optimizer.param_groups[1]['lr'] = (1-abs((epoch+1)/(cfg.epoch+1)*2-1))*cfg.lr
        # optimizer.param_groups[1]: the second parameter group, i.e. head.

        # start training
        for step, (image, mask) in enumerate(loader):
        
            image, mask = image.cuda().float(), mask.cuda().float()
            with amp.autocast(enabled=use_fp16):     # mixed-precision computation
                output = net(image)
                loss2u = net1(F.sigmoid(output), mask)  # F.sigmoid is deprecated; torch.sigmoid is the modern equivalent
                loss1u = structure_loss(output, mask) # weighted BCE plus weighted IoU, averaged as the final structure loss
                loss = loss1u + 0.1 * loss2u   # final loss
                
            optimizer.zero_grad()   # clear cached gradients before the next backward pass

            # scaler is a torch.cuda.amp.GradScaler for mixed-precision training.
            # scaler.scale(loss) multiplies the loss by a scale factor so FP16 gradients do not underflow.
            # .backward() runs backpropagation and computes the gradients.
            scaler.scale(loss).backward()
            scaler.step(optimizer)   # unscale the gradients and update the model parameters
            scaler.update()          # adjust the scale factor dynamically during training to guard against FP16 precision loss
            
            global_step += 1
            
            if step %10 == 0:   # step: the step index within the current epoch
                print('%s | step:%d/%d/%d | lr=%.6f | loss1u=%.6f | loss2u=%.6f '%(datetime.datetime.now(), global_step, epoch+1, cfg.epoch, optimizer.param_groups[0]['lr'], loss1u.item(), loss2u.item()))
            # global_step: total number of training steps so far
            # epoch+1: the current epoch number
            # cfg.epoch: the total number of epochs
            
        # save the model weights once training passes two-thirds of the total epochs
        if epoch>cfg.epoch/3*2:
            torch.save(net.state_dict(), cfg.savepath+'/model-'+str(epoch+1))
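
The factor (1-abs((epoch+1)/(cfg.epoch+1)*2-1)) traces a triangle: it ramps up from near 0 to about 1 at mid-training, then back down. A standalone check of the schedule (assuming cfg.epoch=50 and cfg.lr=0.05 as above):

# Standalone sketch of the triangular learning-rate schedule used above:
total_epochs, base_lr = 50, 0.05
for epoch in (0, 12, 24, 37, 49):
    factor = 1 - abs((epoch + 1) / (total_epochs + 1) * 2 - 1)
    print(epoch + 1, round(factor * base_lr * 0.1, 5), round(factor * base_lr, 5))
# the head LR rises from ~0.002 to ~0.049 around epoch 25, then decays symmetrically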

That completes train.py.

The structure-loss function structure_loss(pred, mask) is implemented as follows:

def structure_loss(pred, mask):
    # boundary-aware weight: avg_pool blurs the mask, so |blur - mask| is largest near mask edges
    weit  = 1+5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15)-mask)
    wbce  = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')  # note: the repo's reduce='none' silently falls back to mean reduction; reduction='none' is the correct keyword
    wbce  = (weit*wbce).sum(dim=(2,3))/weit.sum(dim=(2,3))  # weighted BCE per image

    pred  = torch.sigmoid(pred)
    inter = ((pred*mask)*weit).sum(dim=(2,3))
    union = ((pred+mask)*weit).sum(dim=(2,3))
    wiou  = 1-(inter+1)/(union-inter+1)  # weighted IoU loss per image
    return (wbce+wiou).mean()
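
In formula form (my transcription of the code above), with $m$ the mask, $p = \sigma(\text{pred})$, and the weight map $w$:

$$w_{ij} = 1 + 5\,\bigl|\mathrm{AvgPool}_{31\times31}(m)_{ij} - m_{ij}\bigr|, \qquad
L_{\mathrm{wbce}} = \frac{\sum_{ij} w_{ij}\,\mathrm{BCE}(\text{pred}_{ij}, m_{ij})}{\sum_{ij} w_{ij}}, \qquad
L_{\mathrm{wiou}} = 1 - \frac{\sum_{ij} w_{ij}\,p_{ij} m_{ij} + 1}{\sum_{ij} w_{ij}\,(p_{ij} + m_{ij} - p_{ij} m_{ij}) + 1}$$

and the returned value is the batch mean of $L_{\mathrm{wbce}} + L_{\mathrm{wiou}}$.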

3. test.py

In the repo, test.py corresponds to prediction_rgb.py. The code:

3.1. Model and parameter settings

First choose the model to use, then point to the trained weights. The relevant parameters:

ckpt_path = './saved_model'
exp_name = 'msnet'
args = {'snapshot': 'model-50', 'crf_refine': False, 'save_results': True}

if __name__ == '__main__':
    main()

def main():

    # choose the MSNet model
    net = MSNet().cuda()
    
    # print which snapshot is used
    print ('load snapshot \'%s\' for testing' % args['snapshot'])

    # load the weights: './saved_model/msnet/model-50'
    net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot']),map_location={'cuda:1': 'cuda:1'}))
    # map_location={'cuda:1': 'cuda:1'} controls which device the parameters are loaded onto;
    # here, tensors saved on 'cuda:1' (the second GPU) are mapped back to 'cuda:1'.
    # Without map_location, torch.load restores tensors to the device they were saved from,
    # which fails if that device is unavailable.

    # switch to eval mode
    net.eval()

3.2. Evaluation

3.2.1. Parameter setup and image preprocessing

(1) The dict to_test is {IDRiD: '/home_lv/guanyu.zhu/python/SmallSeg/MSNet-M2SNet/Datasets/IDRiD/TestDataset'}.
(2) This code takes images at their original size [4288, 2848] and resizes them to [512, 512], which is likely to hurt prediction quality badly; a more suitable scheme is needed later.

    with torch.no_grad():
        
        # name: IDRiD
        # root: '/home_lv/guanyu.zhu/python/SmallSeg/MSNet-M2SNet/Datasets/IDRiD/TestDataset'
        for name, root in to_test.items():
            print(root)

            # check/create the folder: './saved_model/msnet/(msnet) IDRiD_model-50'
            check_mkdir(os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot'])))

            # path to the image folder
            root1 = os.path.join(root,'image')   # '/home_lv/guanyu.zhu/python/SmallSeg/MSNet-M2SNet/Datasets/IDRiD/TestDataset/image'
            img_list = [os.path.splitext(f) for f in os.listdir(root1)]  # each image name split into (name, extension), e.g. ('IDRiD_75', '.jpg')

            # idx: image index ; img_name: tuple('name', '.extension')
            for idx, img_name in enumerate(img_list):
                # report which image is being processed
                print ('predicting for %s: %d / %d' % (name, idx + 1, len(img_list)))
                # read the image
                img = Image.open(os.path.join(root,'image',img_name[0]+img_name[1])).convert('RGB')
                w_,h_ = img.size  # w_: 4288 ; h_: 2848  
                img_resize = img.resize([512,512], Image.BILINEAR)  

                # the previous line resizes the image to [512, 512];
                # img_transform then applies ToTensor and Normalize.
                # Variable(..., volatile=True) is deprecated in modern PyTorch; torch.no_grad() above already disables autograd.
                img_var = Variable(img_transform(img_resize).unsqueeze(0), volatile=True).cuda()
                n, c, h, w = img_var.size()        # n:1    c:3    h:512    w:512
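
img_transform itself is not shown in this excerpt; from the described behavior (ToTensor + Normalize) it is presumably built with torchvision, along the lines of the sketch below. The normalization constants here are the common ImageNet values and are an assumption:

import torchvision.transforms as transforms

# Hypothetical definition consistent with the ToTensor + Normalize behavior described above.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])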

3.2.2. Inference

After an image is read, each TTA transformer augments it with transformer.augment_image(); the network output model_output is then mapped back by transformer.deaugment_mask() to give deaug_mask, and the de-augmented masks are averaged into the final prediction.

                mask = []
                for transformer in transforms:  # custom transforms or e.g. tta.aliases.d4_transform()

                    rgb_trans = transformer.augment_image(img_var)
                    model_output = net(rgb_trans)
                    deaug_mask = transformer.deaugment_mask(model_output)
                    mask.append(deaug_mask)

                prediction = torch.mean(torch.stack(mask, dim=0), dim=0) 
                prediction = prediction.sigmoid()                        
                prediction = to_pil(prediction.data.squeeze(0).cpu())    # 512*512
                prediction = prediction.resize((w_, h_), Image.BILINEAR) # resize(4288,2848)
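
The transforms object is not defined in this excerpt; judging by the augment_image/deaugment_mask interface, it presumably comes from the ttach library. A minimal sketch of one possible setup (an assumption, not necessarily the repo's exact composition):

import ttach as tta

# Hypothetical TTA composition; tta.aliases.d4_transform() is another option, as the comment above notes.
transforms = tta.Compose([
    tta.HorizontalFlip(),
    tta.Rotate90(angles=[0, 90, 180, 270]),
])
# Iterating a tta.Compose yields transformers exposing augment_image() and deaugment_mask(),
# exactly the interface used in the loop above.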

3.3. Saving results

Save the prediction maps to the folder './saved_model/msnet/model-50epoch/IDRiD':

                if args['save_results']:
                    check_mkdir(os.path.join(ckpt_path, exp_name,args['snapshot']+'epoch',name))
                    prediction.save(os.path.join(ckpt_path, exp_name ,args['snapshot']+'epoch',name, img_name[0] + '.png'))

4. MSNet model diagram

[Figure: MSNet architecture]

5. M2SNet model

5.1. Full code

class M2SNet(nn.Module):
    # res2net based encoder decoder
    def __init__(self):
        super(M2SNet, self).__init__()
        # ---- ResNet Backbone ----
        self.resnet = res2net50_v1b_26w_4s(pretrained=True)
        self.conv_3 = CNN1(64,3,1)
        self.conv_5 = CNN1(64, 5, 2)


        self.x5_dem_1 = nn.Sequential(nn.Conv2d(2048, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x4_dem_1 = nn.Sequential(nn.Conv2d(1024, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x3_dem_1 = nn.Sequential(nn.Conv2d(512, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x2_dem_1 = nn.Sequential(nn.Conv2d(256, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x5_x4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x4_x3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x3_x2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x2_x1 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))

        self.x5_x4_x3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x4_x3_x2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x3_x2_x1 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))

        self.x5_x4_x3_x2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x4_x3_x2_x1 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64),
                                         nn.ReLU(inplace=True))
        self.x5_dem_4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x5_x4_x3_x2_x1 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64),
                                         nn.ReLU(inplace=True))

        self.level3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.level2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.level1 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x5_dem_5 = nn.Sequential(nn.Conv2d(2048, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64),
                                      nn.ReLU(inplace=True))
        self.output4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.output3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.output2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.output1 = nn.Sequential(nn.Conv2d(64, 1, kernel_size=3, padding=1))

    def forward(self, x):
        input = x                      # [b,3,512,512]
        # '''
        x = self.resnet.conv1(x)       # [b,64,256,256]
        x = self.resnet.bn1(x)         # [b,64,256,256]
        x = self.resnet.relu(x)        # [b,64,256,256]
        x1 = self.resnet.maxpool(x)    # [b,64,128,128]
        # ---- low-level features ----
        x2 = self.resnet.layer1(x1)     # [b, 256,  128, 128]
        x3 = self.resnet.layer2(x2)     # [b, 512,  64,  64]
        x4 = self.resnet.layer3(x3)     # [b, 1024, 32,  32]
        x5 = self.resnet.layer4(x4)     # [b, 2048, 16,  16]
        # '''


        x5_dem_1 = self.x5_dem_1(x5)
        x4_dem_1 = self.x4_dem_1(x4)
        x3_dem_1 = self.x3_dem_1(x3)
        x2_dem_1 = self.x2_dem_1(x2)
        
        # multi-scale subtraction unit
        x5_dem_1_up = F.upsample(x5_dem_1, size=x4.size()[2:], mode='bilinear')  # upsample x5_dem_1 to x4's height and width (F.upsample is deprecated; F.interpolate is equivalent)
        x5_dem_1_up_map1 = self.conv_3(x5_dem_1_up)   # [b,64,32,32]   self.conv_3 keeps the feature-map size
        x4_dem_1_map1 = self.conv_3(x4_dem_1)         # [b,64,32,32]
        x5_dem_1_up_map2 = self.conv_5(x5_dem_1_up)   # [b,64,32,32] -> [b,64,32,32]
        x4_dem_1_map2 = self.conv_5(x4_dem_1)         # [b,64,32,32]
        x5_4 = self.x5_x4(
            abs(x5_dem_1_up - x4_dem_1)+abs(x5_dem_1_up_map1-x4_dem_1_map1)+abs(x5_dem_1_up_map2-x4_dem_1_map2))


        x4_dem_1_up = F.upsample(x4_dem_1, size=x3.size()[2:], mode='bilinear')
        x4_dem_1_up_map1 = self.conv_3(x4_dem_1_up)
        x3_dem_1_map1 = self.conv_3(x3_dem_1)
        x4_dem_1_up_map2 = self.conv_5(x4_dem_1_up)
        x3_dem_1_map2 = self.conv_5(x3_dem_1)
        x4_3 = self.x4_x3(
            abs(x4_dem_1_up - x3_dem_1)+abs(x4_dem_1_up_map1-x3_dem_1_map1)+abs(x4_dem_1_up_map2-x3_dem_1_map2) )


        x3_dem_1_up = F.upsample(x3_dem_1, size=x2.size()[2:], mode='bilinear')
        x3_dem_1_up_map1 = self.conv_3(x3_dem_1_up)
        x2_dem_1_map1 = self.conv_3(x2_dem_1)
        x3_dem_1_up_map2 = self.conv_5(x3_dem_1_up)
        x2_dem_1_map2 = self.conv_5(x2_dem_1)
        x3_2 = self.x3_x2(
            abs(x3_dem_1_up - x2_dem_1)+abs(x3_dem_1_up_map1-x2_dem_1_map1)+abs(x3_dem_1_up_map2-x2_dem_1_map2) )


        x2_dem_1_up = F.upsample(x2_dem_1, size=x1.size()[2:], mode='bilinear')
        x2_dem_1_up_map1 = self.conv_3(x2_dem_1_up)
        x1_map1 = self.conv_3(x1)
        x2_dem_1_up_map2 = self.conv_5(x2_dem_1_up)
        x1_map2 = self.conv_5(x1)
        x2_1 = self.x2_x1(abs(x2_dem_1_up - x1)+abs(x2_dem_1_up_map1-x1_map1)+abs(x2_dem_1_up_map2-x1_map2) )


        x5_4_up = F.upsample(x5_4, size=x4_3.size()[2:], mode='bilinear')
        x5_4_up_map1 = self.conv_3(x5_4_up)
        x4_3_map1 = self.conv_3(x4_3)
        x5_4_up_map2 = self.conv_5(x5_4_up)
        x4_3_map2 = self.conv_5(x4_3)
        x5_4_3 = self.x5_x4_x3(abs(x5_4_up - x4_3) +abs(x5_4_up_map1-x4_3_map1)+abs(x5_4_up_map2-x4_3_map2))


        x4_3_up = F.upsample(x4_3, size=x3_2.size()[2:], mode='bilinear')
        x4_3_up_map1 = self.conv_3(x4_3_up)
        x3_2_map1 = self.conv_3(x3_2)
        x4_3_up_map2 = self.conv_5(x4_3_up)
        x3_2_map2 = self.conv_5(x3_2)
        x4_3_2 = self.x4_x3_x2(abs(x4_3_up - x3_2)+abs(x4_3_up_map1-x3_2_map1)+abs(x4_3_up_map2-x3_2_map2) )


        x3_2_up = F.upsample(x3_2, size=x2_1.size()[2:], mode='bilinear')
        x3_2_up_map1 = self.conv_3(x3_2_up)
        x2_1_map1 = self.conv_3(x2_1)
        x3_2_up_map2 = self.conv_5(x3_2_up)
        x2_1_map2 = self.conv_5(x2_1)
        x3_2_1 = self.x3_x2_x1(abs(x3_2_up - x2_1)+abs(x3_2_up_map1-x2_1_map1)+abs(x3_2_up_map2-x2_1_map2) )


        x5_4_3_up = F.upsample(x5_4_3, size=x4_3_2.size()[2:], mode='bilinear')
        x5_4_3_up_map1 = self.conv_3(x5_4_3_up)
        x4_3_2_map1 = self.conv_3(x4_3_2)
        x5_4_3_up_map2 = self.conv_5(x5_4_3_up)
        x4_3_2_map2 = self.conv_5(x4_3_2)
        x5_4_3_2 = self.x5_x4_x3_x2(
            abs(x5_4_3_up - x4_3_2)+abs(x5_4_3_up_map1-x4_3_2_map1)+abs(x5_4_3_up_map2-x4_3_2_map2) )


        x4_3_2_up = F.upsample(x4_3_2, size=x3_2_1.size()[2:], mode='bilinear')
        x4_3_2_up_map1 = self.conv_3(x4_3_2_up)
        x3_2_1_map1 = self.conv_3(x3_2_1)
        x4_3_2_up_map2 = self.conv_5(x4_3_2_up)
        x3_2_1_map2 = self.conv_5(x3_2_1)
        x4_3_2_1 = self.x4_x3_x2_x1(
            abs(x4_3_2_up - x3_2_1) +abs(x4_3_2_up_map1-x3_2_1_map1)+abs(x4_3_2_up_map2-x3_2_1_map2))


        x5_dem_4 = self.x5_dem_4(x5_4_3_2)
        x5_dem_4_up = F.upsample(x5_dem_4, size=x4_3_2_1.size()[2:], mode='bilinear')
        x5_dem_4_up_map1 = self.conv_3(x5_dem_4_up)
        x4_3_2_1_map1 = self.conv_3(x4_3_2_1)
        x5_dem_4_up_map2 = self.conv_5(x5_dem_4_up)
        x4_3_2_1_map2 = self.conv_5(x4_3_2_1)
        x5_4_3_2_1 = self.x5_x4_x3_x2_x1(
            abs(x5_dem_4_up - x4_3_2_1)+abs(x5_dem_4_up_map1-x4_3_2_1_map1)+abs(x5_dem_4_up_map2-x4_3_2_1_map2) )

        level4 = x5_4
        level3 = self.level3(x4_3 + x5_4_3)
        level2 = self.level2(x3_2 + x4_3_2 + x5_4_3_2)
        level1 = self.level1(x2_1 + x3_2_1 + x4_3_2_1 + x5_4_3_2_1)

        x5_dem_5 = self.x5_dem_5(x5)
        output4 = self.output4(F.upsample(x5_dem_5,size=level4.size()[2:], mode='bilinear') + level4)
        output3 = self.output3(F.upsample(output4,size=level3.size()[2:], mode='bilinear') + level3)
        output2 = self.output2(F.upsample(output3,size=level2.size()[2:], mode='bilinear') + level2)
        output1 = self.output1(F.upsample(output2,size=level1.size()[2:], mode='bilinear') + level1)
        output = F.upsample(output1, size=input.size()[2:], mode='bilinear')
        if self.training:
            return output
        return output  # both branches return the same tensor; the branch is redundant here
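
CNN1 (used above as self.conv_3 and self.conv_5) is not shown in this excerpt. Per the M2SNet paper's description of fixed, all-one filters of different sizes in the subtraction unit, a minimal sketch might look like the following; the official repo's definition may differ in detail:

import torch
import torch.nn as nn

class CNN1(nn.Module):
    # Hypothetical sketch: a fixed (non-trainable) all-one k x k per-channel filter,
    # matching the paper's description of the multi-scale subtraction unit.
    def __init__(self, channel, kernel_size, padding):
        super(CNN1, self).__init__()
        self.conv = nn.Conv2d(channel, channel, kernel_size,
                              padding=padding, groups=channel, bias=False)
        nn.init.constant_(self.conv.weight, 1.0)   # all-one kernel
        self.conv.weight.requires_grad = False     # fixed, not learned

    def forward(self, x):
        return self.conv(x)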

The overall framework:
[Figure: M2SNet architecture]

5.2. Multi-scale subtraction unit

The multi-scale subtraction unit, excerpted from the code above:

        # multi-scale subtraction unit
        x5_dem_1_up = F.upsample(x5_dem_1, size=x4.size()[2:], mode='bilinear')  # upsample x5_dem_1 to x4's height and width
        x5_dem_1_up_map1 = self.conv_3(x5_dem_1_up)   # [b,64,32,32]   self.conv_3 keeps the feature-map size
        x4_dem_1_map1 = self.conv_3(x4_dem_1)         # [b,64,32,32]
        x5_dem_1_up_map2 = self.conv_5(x5_dem_1_up)   # [b,64,32,32] -> [b,64,32,32]
        x4_dem_1_map2 = self.conv_5(x4_dem_1)         # [b,64,32,32]
        x5_4 = self.x5_x4(
            abs(x5_dem_1_up - x4_dem_1)+abs(x5_dem_1_up_map1-x4_dem_1_map1)+abs(x5_dem_1_up_map2-x4_dem_1_map2))
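
In formula form (my transcription of this excerpt), with $F_h$ the upsampled higher-level feature, $F_l$ the lower-level feature, and $\mathcal{F}_3$, $\mathcal{F}_5$ the conv_3 / conv_5 filters:

$$\mathrm{MSU}(F_h, F_l) = \mathrm{ConvBR}\bigl(\,|F_h - F_l| + |\mathcal{F}_3(F_h) - \mathcal{F}_3(F_l)| + |\mathcal{F}_5(F_h) - \mathcal{F}_5(F_l)|\,\bigr)$$

where ConvBR is the 3x3 convolution + BatchNorm + ReLU block (self.x5_x4 above).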

The module visualization:

[Figure: multi-scale subtraction unit]
