Summary: segmenting hard exudates in fundus images with the MSNet and M2SNet models.
The parameters to modify in the train.py script are as follows:
train_path = 'path/to/Datasets/IDRiD/TrainDataset'  # training set path
savepath = './saved_model/msnet'                    # where to save the weights
mode = 'train'                                      # enable training mode
batch = 8
lr = 0.05
momen = 0.9
decay = 5e-4
epoch = 50
The paper states that each dataset uses its own number of training epochs:
polyp segmentation: 50
COVID-19 lung infection: 200
breast tumor segmentation: 100
OCT layer segmentation: 100
Note: for fundus exudate segmentation I use 50 epochs for now.
In addition, the utils/config.py script needs its parameters changed: add two lines for your own dataset:
import os
IDRiD_root_test = 'path/to/SmallSeg/MSNet-M2SNet/Datasets/IDRiD/TestDataset'  # test set path
IDRiD = os.path.join(IDRiD_root_test)
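For reference, these config entries are presumably what the test script (prediction_rgb.py, shown further below) uses to build its to_test dictionary, along the lines of this hypothetical sketch:
from utils import config
to_test = {'IDRiD': config.IDRiD}  # assumed usage: maps the dataset name to its test-set path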
Now enter train.py. The parameters are assigned first, as follows:
if __name__=='__main__':
    train(dataset_medical, MSNet, LossNet)

def train(Dataset, Network, Network1):
    ## dataset
    train_path = '/path/to/Datasets/IDRiD/TrainDataset'
    # assign the parameters
    cfg = Dataset.Config(datapath=train_path, savepath='./saved_model/msnet', mode='train', batch=8, lr=0.05, momen=0.9, decay=5e-4, epoch=50)  # mean and std are added on top of these
The code above first enters dataset_medical.Config(). Besides storing the keyword arguments as attributes, this function adds two more: self.mean and self.std, i.e. the per-channel mean and standard deviation, whose concrete values should be set according to the dataset.
The dataset_medical.Config() code is as follows:
class Config(object):
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.mean = np.array([[[124.55, 118.90, 102.94]]])
        self.std = np.array([[[ 56.77, 55.97, 57.50]]])
        print('\nParameters...')
        for k, v in self.kwargs.items():
            print('%-10s: %s'%(k, v))

    def __getattr__(self, name):
        # settings such as cfg.datapath, cfg.mode, cfg.batch are looked up in kwargs
        return self.kwargs[name] if name in self.kwargs else None
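Since mean and std should match your own dataset, here is a minimal sketch (my own helper, not part of the repo) that computes the per-channel RGB statistics of the training images on the 0-255 scale used above:
import os
import cv2
import numpy as np

def compute_mean_std(img_dir):
    # running per-channel sums so we never hold all pixels in memory
    s, sq, n = np.zeros(3), np.zeros(3), 0
    for f in os.listdir(img_dir):
        if not f.endswith('.jpg'):
            continue
        img = cv2.imread(os.path.join(img_dir, f))[:, :, ::-1].astype(np.float64)  # BGR -> RGB
        s += img.sum(axis=(0, 1))
        sq += (img ** 2).sum(axis=(0, 1))
        n += img.shape[0] * img.shape[1]
    mean = s / n
    std = np.sqrt(sq / n - mean ** 2)
    return mean, std

mean, std = compute_mean_std('path/to/Datasets/IDRiD/TrainDataset/image')
print(mean, std)  # plug these into self.mean and self.std in Config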
Back in the train.py script, the related lines are:
# data pipeline
data = Dataset.Data(cfg)  # i.e. dataset_medical.Data(), which holds the preprocessing
loader = DataLoader(data, collate_fn=data.collate, batch_size=cfg.batch, shuffle=True, num_workers=8)
if not os.path.exists(cfg.savepath):
    os.makedirs(cfg.savepath)
Enter dataset_medical.Data(), which implements the preprocessing; the code is as follows:
class Data(Dataset):
    def __init__(self, cfg):
        self.cfg = cfg
        # preprocessing transforms
        self.normalize = Normalize(mean=cfg.mean, std=cfg.std)  # mean and std come from dataset_medical.Config()
        self.randomcrop = RandomCrop()
        self.randomflip = RandomFlip()
        self.randomrotate = RandomRotate()
        self.resize = Resize(352, 352)
        self.totensor = ToTensor()
        self.root = cfg.datapath  # root directory of the training set
        img_path = os.path.join(self.root, 'image')  # path of the training image directory
        gt_path = os.path.join(self.root, 'mask')    # path of the training mask directory
        self.samples = [os.path.splitext(f)[0]       # collect every sample name for later use
                        for f in os.listdir(gt_path) if f.endswith('.png')]

    def __getitem__(self, idx):
        name = self.samples[idx]
        image = cv2.imread(self.root+'/image/'+name+'.jpg')[:,:,::-1].astype(np.float32)  # [:,:,::-1]: BGR -> RGB
        mask = cv2.imread(self.root+'/mask/' +name+'.png', 0).astype(np.float32)          # 0: read as grayscale
        shape = mask.shape
        if self.cfg.mode=='train':
            image, mask = self.normalize(image, mask)
            image, mask = self.resize(image, mask)
            # image, mask = self.randomcrop(image, mask)
            image, mask = self.randomflip(image, mask)
            image, mask = self.randomrotate(image, mask)
            return image, mask
        else:
            image, mask = self.normalize(image, mask)
            image, mask = self.resize(image, mask)
            image, mask = self.totensor(image, mask)
            return image, mask, shape, name

    def collate(self, batch):
        size = [224, 256, 288, 320, 352][np.random.randint(0, 5)]  # multi-scale training: one size per batch
        image, mask = [list(item) for item in zip(*batch)]
        for i in range(len(batch)):
            image[i] = cv2.resize(image[i], dsize=(size, size), interpolation=cv2.INTER_LINEAR)
            mask[i] = cv2.resize(mask[i], dsize=(size, size), interpolation=cv2.INTER_LINEAR)
        image = torch.from_numpy(np.stack(image, axis=0)).permute(0, 3, 1, 2)  # NHWC -> NCHW
        mask = torch.from_numpy(np.stack(mask, axis=0)).unsqueeze(1)
        return image, mask

    def __len__(self):
        return len(self.samples)
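To check the pipeline end to end, a minimal sketch (paths are placeholders) that builds the loader and inspects one batch:
cfg = Config(datapath='path/to/Datasets/IDRiD/TrainDataset', savepath='./saved_model/msnet',
             mode='train', batch=8, lr=0.05, momen=0.9, decay=5e-4, epoch=50)
data = Data(cfg)
loader = DataLoader(data, collate_fn=data.collate, batch_size=cfg.batch, shuffle=True, num_workers=8)
image, mask = next(iter(loader))
print(image.shape, mask.shape)  # [8, 3, S, S] and [8, 1, S, S], with S drawn from {224, 256, 288, 320, 352}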
Model definition:
net = Network()    # segmentation network: MSNet or M2SNet, MSNet by default
net1 = Network1()  # loss network (LossNet): VGG-16
net.train(True)    # training mode
net1.eval()        # evaluation mode
net.cuda()
net1.cuda()
The loss network is built on a VGG-16 model:
class LossNet(torch.nn.Module):
    def __init__(self, resize=True):
        super(LossNet, self).__init__()
        # take the first four blocks of VGG-16 features (up to layer 23)
        blocks = []
        blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
        blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
        print(blocks)
        # freeze every parameter in the four blocks so they receive no gradients
        for bl in blocks:
            for p in bl.parameters():
                p.requires_grad = False
        # combine the blocks
        self.blocks = torch.nn.ModuleList(blocks)
        # used later to resize the inputs before the loss computation
        self.transform = torch.nn.functional.interpolate
        self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))
        self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1))
        self.resize = resize

    # forward pass
    def forward(self, input, target):
        # the grayscale mask is single-channel, so replicate it to three channels
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        # normalize the images with the ImageNet mean and std
        input = (input-self.mean) / self.std
        target = (target-self.mean) / self.std
        # self.resize defaults to True
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(512, 512), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(512, 512), align_corners=False)
        loss = 0.0
        x = input
        y = target
        # loss computation: each of the 4 blocks contributes an MSE term, capturing
        # multi-scale information (see the figure in the paper)
        for block in self.blocks:
            x = block(x)
            y = block(y)
            loss += torch.nn.functional.mse_loss(x, y)  # the 4 losses are summed
        return loss
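A minimal usage sketch of LossNet (my example; it assumes both inputs are single-channel maps in [0, 1], e.g. a sigmoid output and a binary mask):
loss_net = LossNet().cuda().eval()
for p in loss_net.parameters():
    p.requires_grad = False                 # the same freezing that train.py performs below
pred = torch.rand(2, 1, 352, 352).cuda()    # stands in for torch.sigmoid(net(image))
mask = torch.rand(2, 1, 352, 352).cuda()
print(loss_net(pred, mask).item())          # scalar: sum of the 4 per-block MSE terms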
Figure: visual illustration of the VGG loss function.
Then, in train.py, the loss network is handled as follows:
# whether to use cudnn acceleration
torch.backends.cudnn.enabled = False  # res2net does not support cudnn in py1.7
# exclude the LossNet parameters from gradient computation
for param in net1.parameters():
    param.requires_grad = False
Handling the segmentation network's parameters:
# split the model's parameters into head and base groups, to be handed to the optimizer
base, head = [], []
for name, param in net.named_parameters():
    if 'bkbone.conv1' in name or 'bkbone.bn1' in name:
        print(name)
    elif 'bkbone' in name:
        base.append(param)
    else:
        head.append(param)
# after this loop, base is empty and head holds all the parameters
Optimizer setup: [{'params': base}, {'params': head}] is a list of two dictionaries. Each dictionary has a 'params' key whose value is a group of model parameters to optimize: the first is the base part of the model, the second the head part. Because the parameters were split into these two groups above, each group can be given its own optimization settings.
optimizer = torch.optim.SGD([{'params':base}, {'params':head}], lr=cfg.lr, momentum=cfg.momen, weight_decay=cfg.decay, nesterov=True)  # since base is empty here, effectively only head is optimized (SGD with Nesterov momentum)
Besides LossNet, the training loss also uses a structure loss, which adds a weighted binary cross-entropy loss to a weighted IoU loss and averages the result; the function is structure_loss(pred, mask).
# global_step is a global counter of training steps; it increases by 1 after each
# training step (backward + optimizer.step)
global_step = 0
for epoch in range(cfg.epoch):
    # set two learning rates, one for each of the base and head parameter groups above
    optimizer.param_groups[0]['lr'] = (1-abs((epoch+1)/(cfg.epoch+1)*2-1))*cfg.lr*0.1
    # optimizer.param_groups[0]: the first parameter group, i.e. base.
    # (1-abs((epoch+1)/(cfg.epoch+1)*2-1)): a dynamic learning-rate factor computed from the ratio of the
    #   current epoch to the total; it varies between 0 and 1 as a triangular function peaking at mid-training.
    # cfg.lr: the base learning rate from the config.
    # *0.1: a scaling factor that shrinks this group's learning rate by one order of magnitude.
    optimizer.param_groups[1]['lr'] = (1-abs((epoch+1)/(cfg.epoch+1)*2-1))*cfg.lr
    # optimizer.param_groups[1]: the second parameter group, i.e. head.

    # start training
    for step, (image, mask) in enumerate(loader):
        image, mask = image.cuda().float(), mask.cuda().float()
        with amp.autocast(enabled=use_fp16):  # mixed-precision forward pass; use_fp16 and scaler are defined earlier in the script
            output = net(image)
            loss2u = net1(F.sigmoid(output), mask)
            loss1u = structure_loss(output, mask)  # weighted BCE + weighted IoU, averaged: the structure loss
            loss = loss1u + 0.1 * loss2u           # final loss
        optimizer.zero_grad()  # clear the gradient buffers before the next backward pass
        # scaler is a torch.cuda.amp.GradScaler for mixed-precision training.
        # scaler.scale(loss) multiplies the loss by a scale factor so that small FP16 gradients
        #   do not underflow; .backward() then runs backpropagation on the scaled loss.
        scaler.scale(loss).backward()
        scaler.step(optimizer)  # unscale the gradients and update the model parameters
        scaler.update()         # adjust the scale factor dynamically to avoid FP16 underflow/overflow
        global_step += 1
        if step % 10 == 0:  # step: index of the current batch within the epoch
            print('%s | step:%d/%d/%d | lr=%.6f | loss1u=%.6f | loss2u=%.6f '%(datetime.datetime.now(), global_step, epoch+1, cfg.epoch, optimizer.param_groups[0]['lr'], loss1u.item(), loss2u.item()))
            # global_step: total number of training steps so far
            # epoch+1: the current epoch number
            # cfg.epoch: the total number of epochs

    # save the model weights once training has passed 2/3 of the total epochs
    if epoch > cfg.epoch/3*2:
        torch.save(net.state_dict(), cfg.savepath+'/model-'+str(epoch+1))
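To see the shape of the learning-rate schedule, a standalone sketch (plain Python, using lr=0.05 and epoch=50 from above) that prints the rate of each group per epoch:
lr, total = 0.05, 50
for epoch in range(total):
    factor = 1 - abs((epoch + 1) / (total + 1) * 2 - 1)  # triangular: rises to ~1 at mid-training, then falls back toward 0
    print('epoch %2d | base lr %.5f | head lr %.5f' % (epoch + 1, factor * lr * 0.1, factor * lr))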
This concludes train.py.
The structure loss function structure_loss(pred, mask) is as follows:
def structure_loss(pred, mask):
    # boundary-aware weights: pixels that differ from their 31x31 neighborhood average get up to 6x weight
    weit = 1+5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15)-mask)
    wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
    wbce = (weit*wbce).sum(dim=(2,3))/weit.sum(dim=(2,3))  # weighted BCE per image
    pred = torch.sigmoid(pred)
    inter = ((pred*mask)*weit).sum(dim=(2,3))
    union = ((pred+mask)*weit).sum(dim=(2,3))
    wiou = 1-(inter+1)/(union-inter+1)                     # weighted IoU loss per image
    return (wbce+wiou).mean()
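A quick sanity check of the loss (a sketch with random tensors; pred carries raw logits, mask is binary, and torch/F are imported as in train.py):
pred = torch.randn(2, 1, 352, 352)                  # raw logits from the network
mask = (torch.rand(2, 1, 352, 352) > 0.5).float()   # binary ground-truth mask
print(structure_loss(pred, mask).item())            # one scalar: mean of weighted BCE + weighted IoU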
Testing: in this codebase, test.py corresponds to prediction_rgb.py.
First choose the model to use, then set the path of the trained weights; the relevant parameters are:
ckpt_path = './saved_model'
exp_name = 'msnet'
args = {'snapshot': 'model-50', 'crf_refine': False, 'save_results': True}
if __name__ == '__main__':
    main()

def main():
    # choose the MSNet model
    net = MSNet().cuda()
    # print which snapshot is used
    print('load snapshot \'%s\' for testing' % args['snapshot'])
    # load the weights from './saved_model/msnet/model-50'
    net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot']), map_location={'cuda:1': 'cuda:1'}))
    # map_location={'cuda:1': 'cuda:1'} controls the device the parameters are loaded onto:
    # here, tensors saved on 'cuda:1' (the GPU with index 1) stay on 'cuda:1'.
    # Without map_location, torch.load restores tensors to the device they were saved from.
    # enable evaluation mode
    net.eval()
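As an aside, to load the same checkpoint on a machine without that GPU, map_location can redirect the tensors instead (a common variant, not from the repo):
state = torch.load(os.path.join(ckpt_path, exp_name, args['snapshot']), map_location='cpu')  # force everything onto the CPU
net.load_state_dict(state)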
(1) The dictionary to_test is {'IDRiD': '/home_lv/guanyu.zhu/python/SmallSeg/MSNet-M2SNet/Datasets/IDRiD/TestDataset'}.
(2) This code takes images at their original size [4288, 2848] and resizes them to [512, 512], which is likely to degrade the predictions badly; a more suitable approach is needed later.
with torch.no_grad():
    # name: IDRiD
    # root: '/home_lv/guanyu.zhu/python/SmallSeg/MSNet-M2SNet/Datasets/IDRiD/TestDataset'
    for name, root in to_test.items():
        print(root)
        # check/create the folder './saved_model/msnet/(msnet) IDRiD_model-50'
        check_mkdir(os.path.join(ckpt_path, exp_name, '(%s) %s_%s' % (exp_name, name, args['snapshot'])))
        # path of the image folder
        root1 = os.path.join(root, 'image')  # '.../TestDataset/image'
        img_list = [os.path.splitext(f) for f in os.listdir(root1)]  # every image name, e.g. 'IDRiD_75.jpg'
        # idx: index of the image; img_name: tuple('name', '.extension')
        for idx, img_name in enumerate(img_list):
            # report which image is being processed
            print('predicting for %s: %d / %d' % (name, idx + 1, len(img_list)))
            # read the image
            img = Image.open(os.path.join(root, 'image', img_name[0]+img_name[1])).convert('RGB')
            w_, h_ = img.size  # w_: 4288 ; h_: 2848
            img_resize = img.resize([512, 512], Image.BILINEAR)  # resize to [512, 512]
            # img_transform applies ToTensor and Normalize to the image
            img_var = Variable(img_transform(img_resize).unsqueeze(0), volatile=True).cuda()  # volatile is a no-op in modern PyTorch; torch.no_grad() already disables gradients
            n, c, h, w = img_var.size()  # n:1 c:3 h:512 w:512
After an image is read, each TTA transformer applies transformer.augment_image() to it; the network output model_output is then mapped back through transformer.deaugment_mask() to give deaug_mask, and the de-augmented masks are averaged to form the final prediction.
mask = []
for transformer in transforms:  # custom transforms or e.g. tta.aliases.d4_transform()
    rgb_trans = transformer.augment_image(img_var)
    model_output = net(rgb_trans)
    deaug_mask = transformer.deaugment_mask(model_output)
    mask.append(deaug_mask)
prediction = torch.mean(torch.stack(mask, dim=0), dim=0)
prediction = prediction.sigmoid()
prediction = to_pil(prediction.data.squeeze(0).cpu())     # 512x512
prediction = prediction.resize((w_, h_), Image.BILINEAR)  # back to (4288, 2848)
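The transforms object comes from the ttach test-time-augmentation library; a minimal sketch of how it might be constructed (my assumption, consistent with the comment in the loop above):
import ttach as tta

# flips plus 90-degree rotations; tta.aliases.d4_transform() is a ready-made preset covering the dihedral group
transforms = tta.Compose([
    tta.HorizontalFlip(),
    tta.Rotate90(angles=[0, 90, 180, 270]),
])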
Saving the prediction maps (folder: './saved_model/msnet/model-50epoch/IDRiD'); the code is as follows:
if args['save_results']:
    check_mkdir(os.path.join(ckpt_path, exp_name, args['snapshot']+'epoch', name))
    prediction.save(os.path.join(ckpt_path, exp_name, args['snapshot']+'epoch', name, img_name[0] + '.png'))
The M2SNet model is defined as follows:
class M2SNet(nn.Module):
    # res2net based encoder-decoder
    def __init__(self):
        super(M2SNet, self).__init__()
        # ---- Res2Net Backbone ----
        self.resnet = res2net50_v1b_26w_4s(pretrained=True)
        self.conv_3 = CNN1(64, 3, 1)
        self.conv_5 = CNN1(64, 5, 2)
        self.x5_dem_1 = nn.Sequential(nn.Conv2d(2048, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x4_dem_1 = nn.Sequential(nn.Conv2d(1024, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x3_dem_1 = nn.Sequential(nn.Conv2d(512, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x2_dem_1 = nn.Sequential(nn.Conv2d(256, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x5_x4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x4_x3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x3_x2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x2_x1 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x5_x4_x3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x4_x3_x2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x3_x2_x1 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x5_x4_x3_x2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x4_x3_x2_x1 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x5_dem_4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x5_x4_x3_x2_x1 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.level3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.level2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.level1 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.x5_dem_5 = nn.Sequential(nn.Conv2d(2048, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.output4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.output3 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.output2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.output1 = nn.Sequential(nn.Conv2d(64, 1, kernel_size=3, padding=1))
    def forward(self, x):
        input = x                     # [b,3,512,512]
        x = self.resnet.conv1(x)      # [b,64,256,256]
        x = self.resnet.bn1(x)        # [b,64,256,256]
        x = self.resnet.relu(x)       # [b,64,256,256]
        x1 = self.resnet.maxpool(x)   # [b,64,128,128]
        # ---- low-level features ----
        x2 = self.resnet.layer1(x1)   # [b, 256, 128, 128]
        x3 = self.resnet.layer2(x2)   # [b, 512, 64, 64]
        x4 = self.resnet.layer3(x3)   # [b, 1024, 32, 32]
        x5 = self.resnet.layer4(x4)   # [b, 2048, 16, 16]

        x5_dem_1 = self.x5_dem_1(x5)
        x4_dem_1 = self.x4_dem_1(x4)
        x3_dem_1 = self.x3_dem_1(x3)
        x2_dem_1 = self.x2_dem_1(x2)

        # multi-scale subtraction units, first level
        x5_dem_1_up = F.upsample(x5_dem_1, size=x4.size()[2:], mode='bilinear')  # upsample x5_dem_1 to the h and w of x4
        x5_dem_1_up_map1 = self.conv_3(x5_dem_1_up)  # [b,64,32,32]; self.conv_3 preserves the feature-map size
        x4_dem_1_map1 = self.conv_3(x4_dem_1)        # [b,64,32,32]
        x5_dem_1_up_map2 = self.conv_5(x5_dem_1_up)  # [b,64,32,32] -> [b,64,32,32]
        x4_dem_1_map2 = self.conv_5(x4_dem_1)        # [b,64,32,32]
        x5_4 = self.x5_x4(
            abs(x5_dem_1_up - x4_dem_1) + abs(x5_dem_1_up_map1 - x4_dem_1_map1) + abs(x5_dem_1_up_map2 - x4_dem_1_map2))

        x4_dem_1_up = F.upsample(x4_dem_1, size=x3.size()[2:], mode='bilinear')
        x4_dem_1_up_map1 = self.conv_3(x4_dem_1_up)
        x3_dem_1_map1 = self.conv_3(x3_dem_1)
        x4_dem_1_up_map2 = self.conv_5(x4_dem_1_up)
        x3_dem_1_map2 = self.conv_5(x3_dem_1)
        x4_3 = self.x4_x3(
            abs(x4_dem_1_up - x3_dem_1) + abs(x4_dem_1_up_map1 - x3_dem_1_map1) + abs(x4_dem_1_up_map2 - x3_dem_1_map2))

        x3_dem_1_up = F.upsample(x3_dem_1, size=x2.size()[2:], mode='bilinear')
        x3_dem_1_up_map1 = self.conv_3(x3_dem_1_up)
        x2_dem_1_map1 = self.conv_3(x2_dem_1)
        x3_dem_1_up_map2 = self.conv_5(x3_dem_1_up)
        x2_dem_1_map2 = self.conv_5(x2_dem_1)
        x3_2 = self.x3_x2(
            abs(x3_dem_1_up - x2_dem_1) + abs(x3_dem_1_up_map1 - x2_dem_1_map1) + abs(x3_dem_1_up_map2 - x2_dem_1_map2))

        x2_dem_1_up = F.upsample(x2_dem_1, size=x1.size()[2:], mode='bilinear')
        x2_dem_1_up_map1 = self.conv_3(x2_dem_1_up)
        x1_map1 = self.conv_3(x1)
        x2_dem_1_up_map2 = self.conv_5(x2_dem_1_up)
        x1_map2 = self.conv_5(x1)
        x2_1 = self.x2_x1(
            abs(x2_dem_1_up - x1) + abs(x2_dem_1_up_map1 - x1_map1) + abs(x2_dem_1_up_map2 - x1_map2))

        # second level
        x5_4_up = F.upsample(x5_4, size=x4_3.size()[2:], mode='bilinear')
        x5_4_up_map1 = self.conv_3(x5_4_up)
        x4_3_map1 = self.conv_3(x4_3)
        x5_4_up_map2 = self.conv_5(x5_4_up)
        x4_3_map2 = self.conv_5(x4_3)
        x5_4_3 = self.x5_x4_x3(
            abs(x5_4_up - x4_3) + abs(x5_4_up_map1 - x4_3_map1) + abs(x5_4_up_map2 - x4_3_map2))

        x4_3_up = F.upsample(x4_3, size=x3_2.size()[2:], mode='bilinear')
        x4_3_up_map1 = self.conv_3(x4_3_up)
        x3_2_map1 = self.conv_3(x3_2)
        x4_3_up_map2 = self.conv_5(x4_3_up)
        x3_2_map2 = self.conv_5(x3_2)
        x4_3_2 = self.x4_x3_x2(
            abs(x4_3_up - x3_2) + abs(x4_3_up_map1 - x3_2_map1) + abs(x4_3_up_map2 - x3_2_map2))

        x3_2_up = F.upsample(x3_2, size=x2_1.size()[2:], mode='bilinear')
        x3_2_up_map1 = self.conv_3(x3_2_up)
        x2_1_map1 = self.conv_3(x2_1)
        x3_2_up_map2 = self.conv_5(x3_2_up)
        x2_1_map2 = self.conv_5(x2_1)
        x3_2_1 = self.x3_x2_x1(
            abs(x3_2_up - x2_1) + abs(x3_2_up_map1 - x2_1_map1) + abs(x3_2_up_map2 - x2_1_map2))

        # third level
        x5_4_3_up = F.upsample(x5_4_3, size=x4_3_2.size()[2:], mode='bilinear')
        x5_4_3_up_map1 = self.conv_3(x5_4_3_up)
        x4_3_2_map1 = self.conv_3(x4_3_2)
        x5_4_3_up_map2 = self.conv_5(x5_4_3_up)
        x4_3_2_map2 = self.conv_5(x4_3_2)
        x5_4_3_2 = self.x5_x4_x3_x2(
            abs(x5_4_3_up - x4_3_2) + abs(x5_4_3_up_map1 - x4_3_2_map1) + abs(x5_4_3_up_map2 - x4_3_2_map2))

        x4_3_2_up = F.upsample(x4_3_2, size=x3_2_1.size()[2:], mode='bilinear')
        x4_3_2_up_map1 = self.conv_3(x4_3_2_up)
        x3_2_1_map1 = self.conv_3(x3_2_1)
        x4_3_2_up_map2 = self.conv_5(x4_3_2_up)
        x3_2_1_map2 = self.conv_5(x3_2_1)
        x4_3_2_1 = self.x4_x3_x2_x1(
            abs(x4_3_2_up - x3_2_1) + abs(x4_3_2_up_map1 - x3_2_1_map1) + abs(x4_3_2_up_map2 - x3_2_1_map2))

        # fourth level
        x5_dem_4 = self.x5_dem_4(x5_4_3_2)
        x5_dem_4_up = F.upsample(x5_dem_4, size=x4_3_2_1.size()[2:], mode='bilinear')
        x5_dem_4_up_map1 = self.conv_3(x5_dem_4_up)
        x4_3_2_1_map1 = self.conv_3(x4_3_2_1)
        x5_dem_4_up_map2 = self.conv_5(x5_dem_4_up)
        x4_3_2_1_map2 = self.conv_5(x4_3_2_1)
        x5_4_3_2_1 = self.x5_x4_x3_x2_x1(
            abs(x5_dem_4_up - x4_3_2_1) + abs(x5_dem_4_up_map1 - x4_3_2_1_map1) + abs(x5_dem_4_up_map2 - x4_3_2_1_map2))

        # aggregate each pyramid level
        level4 = x5_4
        level3 = self.level3(x4_3 + x5_4_3)
        level2 = self.level2(x3_2 + x4_3_2 + x5_4_3_2)
        level1 = self.level1(x2_1 + x3_2_1 + x4_3_2_1 + x5_4_3_2_1)

        # top-down decoding
        x5_dem_5 = self.x5_dem_5(x5)
        output4 = self.output4(F.upsample(x5_dem_5, size=level4.size()[2:], mode='bilinear') + level4)
        output3 = self.output3(F.upsample(output4, size=level3.size()[2:], mode='bilinear') + level3)
        output2 = self.output2(F.upsample(output3, size=level2.size()[2:], mode='bilinear') + level2)
        output1 = self.output1(F.upsample(output2, size=level1.size()[2:], mode='bilinear') + level1)

        output = F.upsample(output1, size=input.size()[2:], mode='bilinear')
        if self.training:
            return output
        return output  # the same single-channel map is returned in both modes
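A quick shape check of the model (a sketch; it assumes a CUDA device and that the pretrained Res2Net weights can be downloaded):
net = M2SNet().cuda().eval()
dummy = torch.randn(1, 3, 512, 512).cuda()
with torch.no_grad():
    out = net(dummy)
print(out.shape)  # torch.Size([1, 1, 512, 512]): a one-channel logit map at the input resolution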
Figure: the overall framework of the model.
The multi-scale subtraction unit, excerpted from the code above:
# multi-scale subtraction unit
x5_dem_1_up = F.upsample(x5_dem_1, size=x4.size()[2:], mode='bilinear')  # upsample x5_dem_1 to the h and w of x4
x5_dem_1_up_map1 = self.conv_3(x5_dem_1_up)  # [b,64,32,32]; self.conv_3 preserves the feature-map size
x4_dem_1_map1 = self.conv_3(x4_dem_1)        # [b,64,32,32]
x5_dem_1_up_map2 = self.conv_5(x5_dem_1_up)  # [b,64,32,32] -> [b,64,32,32]
x4_dem_1_map2 = self.conv_5(x4_dem_1)        # [b,64,32,32]
x5_4 = self.x5_x4(
    abs(x5_dem_1_up - x4_dem_1) + abs(x5_dem_1_up_map1 - x4_dem_1_map1) + abs(x5_dem_1_up_map2 - x4_dem_1_map2))
Figure: visualization of the multi-scale subtraction unit.
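For clarity, here is a self-contained sketch of the multi-scale subtraction unit as a single module (my re-packaging of the excerpt above; ConvBNReLU stands in for the repo's CNN1, which I assume is a fixed conv-BN-ReLU block with the given kernel size and padding):
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBNReLU(nn.Module):
    # stand-in for CNN1(channels, kernel_size, padding): conv + BN + ReLU, size-preserving
    def __init__(self, channels, kernel_size, padding):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size, padding=padding, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True))
    def forward(self, x):
        return self.body(x)

class MSU(nn.Module):
    # multi-scale subtraction unit: |A - B| at three receptive fields, then a fusing conv
    def __init__(self, channels=64):
        super().__init__()
        self.conv_3 = ConvBNReLU(channels, 3, 1)
        self.conv_5 = ConvBNReLU(channels, 5, 2)
        self.fuse = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True))
    def forward(self, high, low):
        # bring the deeper (smaller) map up to the shallower one's spatial size
        high = F.interpolate(high, size=low.shape[2:], mode='bilinear', align_corners=False)
        diff = torch.abs(high - low)                             # plain element-wise subtraction
        diff3 = torch.abs(self.conv_3(high) - self.conv_3(low))  # 3x3 receptive field
        diff5 = torch.abs(self.conv_5(high) - self.conv_5(low))  # 5x5 receptive field
        return self.fuse(diff + diff3 + diff5)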