当前位置:   article > 正文

ResNeXt网络详解并使用pytorch搭建模型_resnetxt

resnetxt

1.ResNeXt网络详解

网络创新点:
(1)更新了block

组卷积(Group Convolution)

普通卷积
在这里插入图片描述
假设高和宽为k的输入特征矩阵channel=Cin,使用n个channel=Cin的卷积核进行处理,得到channel=n的输出特征矩阵。

此时需要的参数个数为输入特征矩阵的高x宽xchannelx卷积核个数
输入特征矩阵的高 x 宽 x channel x 卷积核个数

组卷积
在这里插入图片描述
在上图中,假设输入特征矩阵channel=Cin,尺寸为k x k,

将输入特征矩阵的channel划分成g个组,将这g个组分别进行卷积操作

如果使用n/g个卷积核,就能得到g个组的channel=n/g的输出特征矩阵,

通过GConv所需要的参数为=
其中Cin/g是由于将输入特征矩阵的channel划分成了g组,所以对每一个组的每一个卷积核个数为Cin/g;n/g是由于为了得到channel=n的特征矩阵,对于每个组,就需要使用n/g个卷积核;g是由于划分了g个组。

GConv相比较普通卷积,所需要的参数减少了1/g。

如果分组个数g、输入特征矩阵的channel与输出特征矩阵的channel一致时,此时相当于对输入特征矩阵的每一个channel分配了一个channel为1的卷积核进行卷积,此时就是DWConv

DW卷积MobileNetv1、v2网络详解

ResNetXt的block

在这里插入图片描述
上图三个block在数学计算上完全等价

在( c )中,首先通过1x1的卷积层将输入特征矩阵的channel从256降维到128;再通过3x3的32组group卷积对其进行处理;再通过1x1的卷积层进行将特征矩阵的channel从128升维到256;最后主分支与捷径分支的输出进行相加得到最终输出。

因此可以将ResNett网络中的Block替换成使用组卷积的Block

ResNetXt的网络结构

在这里插入图片描述
图中ResNeXt-50(32 x 4d)中的32对应group数、4d对应组卷积中每一个组所采用的卷积核个数(4x32=128)。

为什么group数要设置为32:原论文作者得出结论group数越大,错误率越低。

且block层数小于3,组卷积block的就没有意义。

2.使用pytorch搭建模型

(1)model.py

只需要在深层次的ResNet网络中进行修改即可:

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,                     #定义初始函数及残差结构所需要使用的一系列层结构
                 groups=1, width_per_group=64):                                                #resnext多传入groups和width_per_group
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups                            #计算resnet和rennext网络第一个卷积层和第二个卷积层采用的卷积核个数

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,                     #对于resnet,不传入groups和width_per_group两个参数,out_channels=width=out_channels
                               kernel_size=1, stride=1, bias=False)  # squeeze channels        #对于resnext,传入这两个参数,width等于两倍的resnet的out_channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,           
                               kernel_size=3, stride=stride, bias=False, padding=1)            #步长为2,因此这里步长根据传入的stride调整
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,     #卷积核个数为四倍的前一层卷积核个数
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

实例化ResNetXt-50

def resnext50_32x4d(num_classes=1000, include_top=True):                   #
    # https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth     #权重下载地址
    groups = 32
    width_per_group = 4
    return ResNet(Bottleneck, [3, 4, 6, 3],                                #调用ResNet类,与之前ResNet相同,但多了两个参数
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,                                           #多传入了这两个参数
                  width_per_group=width_per_group)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

(2)train.py

如果要使用ResNeXt模型,要做如下更改
导入模块的更改

from model import resnet34
更改为:from model import resnext50_32x4d
  • 1
  • 2

实例化模型的更改

    net = resnext50_32x4d()
  • 1

预训练权重更改

model_weight_path = "./resnext50_32x4d.pth" 
  • 1

不使用预训练权重

    # for param in net.parameters():          #将这两行代码隐去
    #     param.requires_grad = False
  • 1
  • 2

使用预训练权重就不隐去,且将网络的最后一层全连接层替换成自己的全连接层

    net.fc = nn.Linear(in_channel, 5)          #花分类只有5个类别,重新赋值全连接层
  • 1

训练权重保存路径修改

    save_path = './resNeXt50.pth'
  • 1

(3)标题数据集

同:ResNet网络详解并使用pytorch搭建模型、并基于迁移学习训练

(4)predict.py

导入模块的更改

from model import resnet34
更改为:from model import resnext50_32x4d
  • 1
  • 2

实例化模型的更改

model = resnext50_32x4d(num_classes=5).to(device)
  • 1

预训练权重更改

weight_path = "./resNeXt50.pth" 
  • 1

(4)batch_predict.py

批量预测图片脚本

# load image
    # 指向需要遍历预测的图像文件夹
    imgs_root = "/data/imgs"
    assert os.path.exists(imgs_root), f"file: '{imgs_root}' dose not exist."
    # 读取指定文件夹下所有jpg图像路径
    img_path_list = [os.path.join(imgs_root, i) for i in os.listdir(imgs_root) if i.endswith(".jpg")]

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), f"file: '{json_path}' dose not exist."

    json_file = open(json_path, "r")
    class_indict = json.load(json_file)

    # create model
    model = resnet34(num_classes=5).to(device)

    # load model weights
    weights_path = "./resNeXt50.pth"
    assert os.path.exists(weights_path), f"file: '{weights_path}' dose not exist."
    model.load_state_dict(torch.load(weights_path, map_location=device))

    # prediction
    model.eval()
    batch_size = 8  # 每次预测时将多少张图片打包成一个batch
    with torch.no_grad():
        for ids in range(0, len(img_path_list) // batch_size):
            img_list = []
            for img_path in img_path_list[ids * batch_size: (ids + 1) * batch_size]:
                assert os.path.exists(img_path), f"file: '{img_path}' dose not exist."
                img = Image.open(img_path)
                img = data_transform(img)
                img_list.append(img)

            # batch img
            # 将img_list列表中的所有图像打包成一个batch
            batch_img = torch.stack(img_list, dim=0)
            # predict class
            output = model(batch_img.to(device)).cpu()
            predict = torch.softmax(output, dim=1)
            probs, classes = torch.max(predict, dim=1)

            for idx, (pro, cla) in enumerate(zip(probs, classes)):
                print("image: {}  class: {}  prob: {:.3}".format(img_path_list[ids * batch_size + idx],
                                                                 class_indict[str(cla.numpy())],
                                                                 pro.numpy()))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/373349?site
推荐阅读
相关标签
  

闽ICP备14008679号