当前位置:   article > 正文

libtorch-resnet18_class resnet18(nn.module): def __init__(self, num_

class resnet18(nn.module): def __init__(self, num_class): super(resnet18, se

与大家分享一下自己在学习使用libtorch搭建神经网络时学到的一些心得和例子,记录下来供大家参考
首先我们要参考着pytorch版的resnet来搭建,这样我们可以省去不必要的麻烦,上代码:
1、首先是pytorch版残差模块

class ResidualBlock(nn.Module):
    """Residual block: (3x3 conv -> BN -> ReLU -> 3x3 conv -> BN) + skip, then ReLU.

    Args:
        inchannel: channel count of the incoming tensor.
        outchannel: channel count produced by both convolutions.
        stride: stride of the first 3x3 convolution (int or ``(h, w)`` tuple).
        shortcut: optional module applied to the input on the skip path. When
            ``None`` the input is added back unchanged, so its shape must
            already match the output of the main path.
    """

    def __init__(self, inchannel, outchannel, stride=1, shortcut=None):
        super().__init__()
        main_path = [
            nn.Conv2d(inchannel, outchannel, 3, stride, 1),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, 3, 1, 1),
            nn.BatchNorm2d(outchannel),
        ]
        # Attribute names `left` / `right` are kept so state_dict keys match.
        self.left = nn.Sequential(*main_path)
        self.right = shortcut

    def forward(self, x):
        main = self.left(x)
        skip = x if self.right is None else self.right(x)
        return F.relu(main + skip)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

2、libtorch版残差模块
因为是用 C++ 搭建的,所以先创建头文件(.h)
2.1残差模块头文件(声明)

// Helper that packs the usual Conv2d settings into a Conv2dOptions object,
// so each convolution can be declared on a single line.
//   in_planes / out_planes - input / output channel counts
//   kernel_size            - square kernel size
//   stride, padding        - applied to both spatial dimensions
//   groups                 - number of convolution groups
//   with_bias              - whether the convolution has a learnable bias
inline torch::nn::Conv2dOptions conv_options(int64_t in_planes, int64_t out_planes, int64_t kernel_size,
	int64_t stride = 1, int64_t padding = 0, int groups = 1, bool with_bias = true) {
	return torch::nn::Conv2dOptions(in_planes, out_planes, kernel_size)
		.stride(stride)
		.padding(padding)
		.bias(with_bias)
		.groups(groups);
}
// Residual block declaration. One class covers both variants:
// the basic block (is_basic = true, two 3x3 convolutions) and the
// bottleneck block (is_basic = false, 1x1 -> 3x3 -> 1x1).
class Block_ocrImpl : public torch::nn::Module {
public:
    // inplanes   - channels of the block input
    // planes     - base channel count of the block
    // stride_    - stride used by the block's 3x3 convolution
    // downsample_- optional projection applied to the shortcut
    // groups, base_width - grouped-convolution settings
    // is_basic   - selects basic (true) or bottleneck (false) layout
    Block_ocrImpl(int64_t inplanes, int64_t planes, int64_t stride_ = 1,
        torch::nn::Sequential downsample_ = nullptr, int groups = 1,
        int base_width = 64, bool is_basic = true);

    torch::Tensor forward(torch::Tensor x);

    // Shortcut projection (may hold no modules).
    torch::nn::Sequential downsample{ nullptr };

private:
    // Block configuration.
    bool is_basic = true;
    int64_t stride = 1;

    // Main-path layers; conv3/bn3 are only created for the bottleneck variant.
    torch::nn::Conv2d conv1{ nullptr };
    torch::nn::BatchNorm2d bn1{ nullptr };
    torch::nn::Conv2d conv2{ nullptr };
    torch::nn::BatchNorm2d bn2{ nullptr };
    torch::nn::Conv2d conv3{ nullptr };
    torch::nn::BatchNorm2d bn3{ nullptr };
};
TORCH_MODULE(Block_ocr);
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28

2.2残差模块定义
这里用到的卷积辅助函数 conv_options(它并不是重载函数,只是把 Conv2dOptions 的各项设置封装起来,省去以后重复书写的工作)已经写在了 2.1 的头文件里面。

// Residual block definition.
// Builds a basic block (3x3 -> 3x3) when _is_basic is true, otherwise a
// bottleneck block (1x1 -> 3x3 -> 1x1 whose output has planes * 4 channels).
//   inplanes    - channels of the block input
//   planes      - base channel count of the block
//   stride_     - stride of the block's 3x3 convolution
//   downsample_ - optional projection for the shortcut; may be an empty
//                 Sequential or a null holder (the declared default argument)
//   groups, base_width - width = planes * (base_width / 64) * groups
Block_ocrImpl::Block_ocrImpl(int64_t inplanes, int64_t planes, int64_t stride_,
    torch::nn::Sequential downsample_, int groups, int base_width, bool _is_basic)
{
    downsample = downsample_;
    stride = stride_;
    is_basic = _is_basic;
    int width = int(planes * (base_width / 64.)) * groups;

    if (is_basic) {
        // Basic block: strided 3x3 followed by a 3x3, both producing `width` channels.
        conv1 = torch::nn::Conv2d(conv_options(inplanes, width, 3, stride_, 1, groups, false));
        conv2 = torch::nn::Conv2d(conv_options(width, width, 3, 1, 1, groups, false));
    }
    else {
        // Bottleneck: 1x1 reduce -> strided grouped 3x3 -> 1x1 expand to planes * 4.
        conv1 = torch::nn::Conv2d(conv_options(inplanes, width, 1, 1, 0, 1, false));
        conv2 = torch::nn::Conv2d(conv_options(width, width, 3, stride_, 1, groups, false));
        conv3 = torch::nn::Conv2d(conv_options(width, planes * 4, 1, 1, 0, 1, false));
        bn3 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(planes * 4));
    }
    bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width));
    bn2 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width));

    register_module("conv1", conv1);
    register_module("bn1", bn1);
    register_module("conv2", conv2);
    register_module("bn2", bn2);
    if (!is_basic) {
        register_module("conv3", conv3);
        register_module("bn3", bn3);
    }

    // Check the holder first: the default argument is a null Sequential, and
    // calling ->is_empty() on a null holder throws.
    if (downsample && !downsample->is_empty()) {
        register_module("downsample", downsample);
    }
}
// Residual block forward pass: main path + (optionally projected) shortcut, then ReLU.
torch::Tensor Block_ocrImpl::forward(torch::Tensor x) {
    // No clone needed: each layer below returns a new tensor and `x` is only
    // rebound, so `residual` keeps referring to the untouched block input.
    torch::Tensor residual = x;

    x = conv1->forward(x);
    x = bn1->forward(x);
    x = torch::relu(x);

    x = conv2->forward(x);
    x = bn2->forward(x);

    if (!is_basic) {
        x = torch::relu(x);
        x = conv3->forward(x);
        x = bn3->forward(x);
    }

    if (downsample && !downsample->is_empty()) {
        residual = downsample->forward(residual);
    }

    x += residual;
    x = torch::relu(x);

    return x;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59

3、pytorch版resnet主函数

class ResNet18(nn.Module):
    """ResNet-style convolutional feature extractor without a classifier head.

    Args:
        nc: number of channels of the input image.

    ``forward`` returns the output of ``layer4``. The (2, 1) strides of
    layers 2-4 halve only the height, so the width after the stem is kept.
    Working the strides/paddings through for a ``[2, nc, 32, 280]`` input:
    pre -> [2, 64, 8, 70], layer1 -> [2, 128, 8, 70],
    layer2 -> [2, 256, 4, 70], layer3 -> [2, 512, 2, 70],
    layer4 -> [2, 512, 1, 70].
    """

    def __init__(self, nc):
        super().__init__()
        # Stem: 7x7 stride-2 convolution, then 3x3 stride-2 max pooling.
        self.pre = nn.Sequential(
            nn.Conv2d(nc, 64, 7, 2, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, 2, 1),
        )
        # Stacked residual stages; the third argument is block_num
        # (see _make_layer: each stage holds block_num + 1 blocks).
        self.layer1 = self._make_layer(64, 128, 1)
        self.layer2 = self._make_layer(128, 256, 2, stride=(2, 1))
        self.layer3 = self._make_layer(256, 512, 5, stride=(2, 1))
        self.layer4 = self._make_layer(512, 512, 3, stride=(2, 1))

    def _make_layer(self, inchannel, outchannel, block_num, stride=(1, 1)):
        """Build one stage: a projecting block followed by ``block_num`` plain blocks.

        NOTE: the stage therefore contains ``block_num + 1`` residual blocks.
        """
        # 1x1 conv + BN on the skip path so it matches the new channels/stride.
        projection = nn.Sequential(
            nn.Conv2d(inchannel, outchannel, 1, stride),
            nn.BatchNorm2d(outchannel),
        )
        first = ResidualBlock(inchannel, outchannel, stride, projection)
        rest = [ResidualBlock(outchannel, outchannel) for _ in range(block_num)]
        return nn.Sequential(first, *rest)

    def forward(self, x):
        # Data flow: stem, then the four residual stages in order.
        for stage in (self.pre, self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return x

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40

4、libtorch版主函数
和残差模块一样,分为头文件(.h)和源文件(.cpp)
先写头文件,还是仿照pytorch版的来写,这样我们可以避免很多麻烦
4.1主函数头文件(声明)

// ResNet backbone declaration (no fully connected / softmax head).
class ResNet_ocrImpl : public torch::nn::Module {
public:
    // model_type selects the block variant; groups / width_per_group configure
    // grouped convolutions in the residual blocks.
    ResNet_ocrImpl(/*std::vector<int> layers, int num_classes = 1000,*/ std::string model_type = "resnet18",
        int groups = 1, int width_per_group = 64);

    torch::Tensor forward(torch::Tensor x);

    // NOTE(review): no definition of features() appears in this article;
    // calling it will fail to link unless it is defined elsewhere.
    std::vector<torch::Tensor> features(torch::Tensor x);

    // Builds one stage of `blocks` residual blocks; advances `inplanes`.
    torch::nn::Sequential _make_layer(int64_t planes, int64_t blocks, int64_t stride = 1);

private:
    // Architecture settings.
    int expansion = 1;
    bool is_basic = true;
    int64_t inplanes = 64;
    int groups = 1;
    int base_width = 64;

    // Stem.
    torch::nn::Conv2d conv1{ nullptr };
    torch::nn::BatchNorm2d bn1{ nullptr };

    // Residual stages.
    torch::nn::Sequential layer1{ nullptr };
    torch::nn::Sequential layer2{ nullptr };
    torch::nn::Sequential layer3{ nullptr };
    torch::nn::Sequential layer4{ nullptr };
};
TORCH_MODULE(ResNet_ocr);
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

4.2主函数定义

// Build one ResNet stage (modelled on the PyTorch _make_layer).
// The first block may change stride/channels, in which case the shortcut gets
// a 1x1 conv + BN projection. It is followed by `blocks - 1` shape-preserving
// blocks. Updates `inplanes` so the next stage starts from this stage's output
// channel count.
torch::nn::Sequential ResNet_ocrImpl::_make_layer(int64_t planes, int64_t blocks, int64_t stride) {
    // A default-constructed Sequential is an empty container, which Block_ocr
    // treats as "no projection on the shortcut".
    torch::nn::Sequential downsample;
    if (stride != 1 || inplanes != planes * expansion) {
        downsample = torch::nn::Sequential(
            torch::nn::Conv2d(conv_options(inplanes, planes * expansion, 1, stride, 0, 1, false)),
            torch::nn::BatchNorm2d(planes * expansion)
        );
    }

    torch::nn::Sequential layers;
    layers->push_back(Block_ocr(inplanes, planes, stride, downsample, groups, base_width, is_basic));
    inplanes = planes * expansion;
    for (int64_t i = 1; i < blocks; i++) {
        layers->push_back(Block_ocr(inplanes, planes, 1, torch::nn::Sequential(), groups, base_width, is_basic));
    }

    return layers;
}
// ResNet backbone constructor.
//   model_type       - "resnet18" / "resnet34" use basic blocks (expansion 1);
//                      any other value uses bottleneck blocks (expansion 4).
//                      The name also selects the number of blocks per stage.
//   _groups          - groups for the residual blocks' 3x3 convolutions
//   _width_per_group - base width used to compute the blocks' inner width
// The stem convolution expects a single input channel. No fc/softmax head is
// built; forward() returns the layer4 feature map.
ResNet_ocrImpl::ResNet_ocrImpl(std::string model_type, int _groups, int _width_per_group)
{
    if (model_type != "resnet18" && model_type != "resnet34")
    {
        expansion = 4;
        is_basic = false;
    }
    groups = _groups;
    base_width = _width_per_group;

    // Blocks per stage for the standard ResNet variants. Previously every
    // variant was built with {2, 2, 2, 2}, so e.g. "resnet34" silently produced
    // a ResNet-18-depth network. Unrecognised names keep the old {2, 2, 2, 2}.
    std::vector<int64_t> layers{ 2, 2, 2, 2 };
    if (model_type == "resnet34" || model_type == "resnet50") {
        layers = { 3, 4, 6, 3 };
    }
    else if (model_type == "resnet101") {
        layers = { 3, 4, 23, 3 };
    }
    else if (model_type == "resnet152") {
        layers = { 3, 8, 36, 3 };
    }

    conv1 = torch::nn::Conv2d(conv_options(1, 64, 7, 2, 3, 1, false));
    bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(64));
    // _make_layer advances `inplanes`, so the stages must be built in this order.
    layer1 = _make_layer(64, layers[0]);
    layer2 = _make_layer(128, layers[1], 2);
    layer3 = _make_layer(256, layers[2], 2);
    layer4 = _make_layer(512, layers[3], 2);
    register_module("conv1", conv1);
    register_module("bn1", bn1);
    register_module("layer1", layer1);
    register_module("layer2", layer2);
    register_module("layer3", layer3);
    register_module("layer4", layer4);

    // Weight init: Kaiming-normal (fan_out, ReLU) for every conv,
    // weight 1 / bias 0 for every BatchNorm2d.
    for (auto& module : modules(/*include_self=*/false)) {
        if (auto M = dynamic_cast<torch::nn::Conv2dImpl*>(module.get())) {
            torch::nn::init::kaiming_normal_(
                M->weight,
                /*a=*/0,
                torch::kFanOut,
                torch::kReLU);
        }
        else if (auto M = dynamic_cast<torch::nn::BatchNorm2dImpl*>(module.get())) {
            torch::nn::init::constant_(M->weight, 1);
            torch::nn::init::constant_(M->bias, 0);
        }
    }
}

// ResNet forward pass: stem (conv -> BN -> ReLU -> 3x3/2 max-pool), then the
// four residual stages. Returns the layer4 feature map.
torch::Tensor  ResNet_ocrImpl::forward(torch::Tensor x) {
    x = conv1->forward(x);
    x = bn1->forward(x);
    x = torch::relu(x);
    x = torch::max_pool2d(x, 3, 2, 1);

    x = layer1->forward(x);
    x = layer2->forward(x);
    x = layer3->forward(x);
    x = layer4->forward(x);
    return x;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70

以上就是 libtorch 版的 resnet18 网络,完全使用 C++ 搭建。由于我需要把 resnet 和别的网络拼接,所以删掉了 fc 层和 softmax 层,有需要的可以自己加上。这里也是参考一位 GitHub 大神的写法来写的。
科技无罪、知识无罪,我们要做知识的传播者!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/73101?site
推荐阅读
相关标签
  

闽ICP备14008679号