当前位置:   article > 正文

Fully Convolutional Networks for Semantic Segmentation ———— 全卷积网络FCN代码解读之模型架构_全卷积模型架构

全卷积模型架构

Fully Convolutional Networks for Semantic Segmentation–用于语义分割的全卷积网络



一、数据预处理

  1. 标签处理
  2. 标签编码
  3. 可视化编码过程
  4. 定义预处理类

1.标签处理

利用if函数和file_path(list),连接数据标签路径,并裁剪图片大小

代码如下:

class CamvidDataset(Dataset):
    def __init__(self, file_path=[], crop_size=None):
        if len(file_path) != 2:
            raise ValueError("同时需要图片和标签文件夹的路径,图片路径在前")
            #保证正确读入图片和标签路径,逻辑是判断是否是2个元素,是继续执行,否则提示valueError
        self.img_path = file_path[0]
        self.label_path = file_path[1]
        self.imgs = self.read_file(self.img_path)
        self.labels = self.read_file(self.label_path)
        self.crop_size = crop_size
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

数据处理初始化将图片和标签路径提取出,保持图片路径在前

2.标签编码

利用哈希算法形成一对一或者多对一的映射关系,形成颜色到标签的对应关系。
编码函数: (p[0]*256+p[1])*256+p[2]
**原理:**一个像素点由编码函数转化为整数,将整数作为这个像素点在哈希表中的索引,并查到相对应的类别。

二、模型搭建

1.引入库

代码如下:

import numpy as np
import torch
from torchvision import models
from torch import nn
  • 1
  • 2
  • 3
  • 4

2.模型架构

在这里插入图片描述
输入图像前阶段卷积池化采用VGG网络
代码如下:

class FCN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.stage1 = pretrained_net.features[:7]
        self.stage2 = pretrained_net.features[7:14]
        self.stage3 = pretrained_net.features[14:24]
        self.stage4 = pretrained_net.features[24:34]
        self.stage5 = pretrained_net.features[34:]

        self.scores1 = nn.Conv2d(512, num_classes, 1)
        self.scores2 = nn.Conv2d(512, num_classes, 1)
        self.scores3 = nn.Conv2d(128, num_classes, 1)

        self.conv_trans1 = nn.Conv2d(512, 256, 1)
        self.conv_trans2 = nn.Conv2d(256, num_classes, 1)

        self.upsample_8x = nn.ConvTranspose2d(num_classes, num_classes, 16, 8, 4, bias=False)
        self.upsample_8x.weight.data = bilinear_kernel(num_classes, num_classes, 16)

        self.upsample_2x_1 = nn.ConvTranspose2d(512, 512, 4, 2, 1, bias=False)
        self.upsample_2x_1.weight.data = bilinear_kernel(512, 512, 4)

        self.upsample_2x_2 = nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False)
        self.upsample_2x_2.weight.data = bilinear_kernel(256, 256, 4)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

FCN-8s网络搭建

    def forward(self, x):
        s1 = self.stage1(x)
        s2 = self.stage2(s1)
        s3 = self.stage3(s2)
        s4 = self.stage4(s3)
        s5 = self.stage5(s4)

        scores1 = self.scores1(s5)
        s5 = self.upsample_2x_1(s5)
        add1 = s5 + s4

        scores2 = self.scores2(add1)

        add1 = self.conv_trans1(add1)
        add1 = self.upsample_2x_2(add1)
        add2 = add1 + s3

        output = self.conv_trans2(add2)
        output = self.upsample_8x(output)
        return output
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/72320
推荐阅读
相关标签
  

闽ICP备14008679号