当前位置:   article > 正文

libtorch c++ 使用预训练权重(以resnet为例)_libtorch 自定义模型并加载权重

libtorch 自定义模型并加载权重

任务: 识别猫咪。

目录

1. 直接使用

1.1 获取预训练权重

 1.2 libtorch直接使用pt权重

2. 间接使用

2.1 BasicBlock

2.2 实现ResNet

2.3 BottleNeck


1. 直接使用

1.1 获取预训练权重

比如直接使用Pytorch版的预训练权重。先把权重保存下来,并打印分类类别(方便后面对比)

  # Export a pretrained torchvision ResNet-18 as a TorchScript file and print the
  # ImageNet class index it predicts for one sample image.
  import numpy as np
  import torch
  import torchvision.models as models
  from PIL import Image

  # Input: load the sample image and scale the pixel values to [0, 1].
  image = Image.open("E:\\code\\c++\\libtorch_models\\data\\cat.jpg")
  # convert("RGB") guarantees exactly three channels (drops alpha, expands grayscale).
  # Image.LANCZOS is the filter the removed Image.ANTIALIAS alias used to name.
  image = image.convert("RGB").resize((224, 224), Image.LANCZOS)
  image = np.asarray(image) / 255.0
  image = torch.Tensor(image).unsqueeze_(dim=0)  # (b,h,w,c)
  image = image.permute((0, 3, 1, 2)).float()  # (b,h,w,c) -> (b,c,h,w)

  # Model: trace the network in eval mode so it can be saved as TorchScript.
  model = models.resnet18(pretrained=True)
  model = model.eval()
  resnet = torch.jit.trace(model, torch.rand(1, 3, 224, 224))

  # Inference: no gradients are needed to read off the predicted class.
  with torch.no_grad():
      output = resnet(image)
  max_index = torch.max(output, 1)[1].item()
  print(max_index)  # class index among the 1000 ImageNet classes
  resnet.save('resnet.pt')

运行后会把模型保存为resnet.pt,并打印分类索引号283,对应的类别是猫。注意:后文C++代码加载的是weights目录下的resnet18.pt,需将保存的文件重命名并放到该目录(或相应修改加载路径)。

 1.2 libtorch直接使用pt权重

使用接口torch::jit::load 即可载入权重并获取resnet18模型。

然后再使用std::vector<torch::jit::IValue>传送数据到模型中,即可得到类别。

打印结果是283,和前面PyTorch版的结果一致。

  1. #include <iostream>
  2. #include <opencv.hpp>
  3. int main()
  4. {
  5. // load weights and model.
  6. auto resnet18 = torch::jit::load("E:\\code\\c++\\libtorch_models\\weights\\resnet18.pt");
  7. assert(module != nullptr);
  8. resnet18.to(torch::kCUDA);
  9. resnet18.eval();
  10. // pre
  11. cv::Mat image = cv::imread("E:\\code\\c++\\libtorch_models\\data\\cat.jpg");
  12. cv::resize(image, image, cv::Size(224, 224));
  13. torch::Tensor tensor_image = torch::from_blob(image.data, {224, 224,3 }, torch::kByte);
  14. tensor_image = torch::unsqueeze(tensor_image, 0).permute({ 0,3,1,2 }).to(torch::kCUDA).to(torch::kFloat).div(255.0); // (b,h,w,c) -> (b,c,h,w)
  15. std::cout << tensor_image.options() << std::endl;
  16. std::vector<torch::jit::IValue> inputs;
  17. inputs.push_back(tensor_image);
  18. // infer
  19. auto output = resnet18.forward(inputs).toTensor();
  20. auto max_result = output.max(1, true);
  21. auto max_index = std::get<1>(max_result).item<float>();
  22. std::cout << max_index << std::endl;
  23. return 0;
  24. }

2. 间接使用

间接使用是指基于libtorch c++ 复现一遍resnet网络,再利用前面得到的权重,初始化模型。输出结果依然是283.

  1. #include <iostream>
  2. #include "resnet.h" // libtorch实现的resnet
  3. #include <opencv.hpp>
  4. int main()
  5. {
  6. // load weights and model.
  7. ResNet resnet = resnet18(1000); // orig net
  8. torch::load(resnet, "E:\\code\\c++\\libtorch_models\\weights\\resnet18.pt"); // load weights.
  9. assert(resnet != nullptr);
  10. resnet->to(torch::kCUDA);
  11. resnet->eval();
  12. // pre
  13. cv::Mat image = cv::imread("E:\\code\\c++\\libtorch_models\\data\\cat.jpg");
  14. cv::resize(image, image, cv::Size(224, 224));
  15. torch::Tensor tensor_image = torch::from_blob(image.data, { 224, 224,3 }, torch::kByte);
  16. tensor_image = torch::unsqueeze(tensor_image, 0).permute({ 0,3,1,2 }).to(torch::kCUDA).to(torch::kFloat).div(255.0); // (b,h,w,c) -> (b,c,h,w)
  17. std::cout << tensor_image.options() << std::endl;
  18. // infer
  19. auto output = resnet->forward(tensor_image);
  20. auto max_result = output.max(1, true);
  21. auto max_index = std::get<1>(max_result).item<float>();
  22. std::cout << max_index << std::endl;
  23. return 0;
  24. }

 接下来介绍resnet详细实现过程。

2.1 BasicBlock

先实现resnet最小单元BasicBlock,该单元是两次卷积组成的残差块。结构如下。

两种形式,如果第一个卷积stride等于2进行下采样,则跳层连接也需要下采样,维度才能一致,再进行对应相加。

  1. // resnet18 and resnet34
  2. class BasicBlockImpl : public torch::nn::Module {
  3. public:
  4. BasicBlockImpl(int64_t in_channels, int64_t out_channels, int64_t stride, torch::nn::Sequential downsample);
  5. torch::Tensor forward(torch::Tensor x);
  6. public:
  7. torch::nn::Sequential downsample{ nullptr };
  8. private:
  9. torch::nn::Conv2d conv1{ nullptr };
  10. torch::nn::BatchNorm2d bn1{ nullptr };
  11. torch::nn::Conv2d conv2{ nullptr };
  12. torch::nn::BatchNorm2d bn2{ nullptr };
  13. };
  14. TORCH_MODULE(BasicBlock);
  15. // other resnet using BottleNeck
  16. class BottleNeckImpl : public torch::nn::Module {
  17. public:
  18. BottleNeckImpl(int64_t in_channels, int64_t out_channels, int64_t stride,
  19. torch::nn::Sequential downsample, int groups, int base_width);
  20. torch::Tensor forward(torch::Tensor x);
  21. public:
  22. torch::nn::Sequential downsample{ nullptr };
  23. private:
  24. torch::nn::Conv2d conv1{ nullptr };
  25. torch::nn::BatchNorm2d bn1{ nullptr };
  26. torch::nn::Conv2d conv2{ nullptr };
  27. torch::nn::BatchNorm2d bn2{ nullptr };
  28. torch::nn::Conv2d conv3{ nullptr };
  29. torch::nn::BatchNorm2d bn3{ nullptr };
  30. };
  31. TORCH_MODULE(BottleNeck);
  1. // conv3x3+bn+relu, conv3x3+bn,
  2. // downsample: 用来对原始输入进行下采样.
  3. // stride: 控制是否下采样,stride=2则是下采样,且downsample将用于对原始输入进行下采样.
  4. BasicBlockImpl::BasicBlockImpl(int64_t in_channels, int64_t out_channels, int64_t stride, torch::nn::Sequential downsample) {
  5. this->downsample = downsample;
  6. conv1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, out_channels, 3).stride(stride).padding(1).bias(false));
  7. bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels));
  8. conv2 = torch::nn::Conv2d(torch::nn::Conv2dOptions(out_channels, out_channels, 3).stride(1).padding(1).bias(false));
  9. bn2 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels));
  10. register_module("conv1", conv1);
  11. register_module("bn1", bn1);
  12. register_module("conv2", conv2);
  13. register_module("bn2", bn2);
  14. if (!downsample->is_empty()) {
  15. register_module("downsample", downsample);
  16. }
  17. }
  18. torch::Tensor BasicBlockImpl::forward(torch::Tensor x) {
  19. torch::Tensor identity = x.clone();
  20. x = conv1->forward(x); // scale/2. or keep scale unchange.
  21. x = bn1->forward(x);
  22. x = torch::relu(x);
  23. x = conv2->forward(x);
  24. x = bn2->forward(x);
  25. // 加入x的维度减半,则原始输入必须也减半。
  26. if (!downsample->is_empty()) identity = downsample->forward(identity);
  27. x += identity;
  28. x = torch::relu(x);
  29. return x;
  30. }

2.2 实现ResNet

这里以resnet18为例。网络结构如下。

简单一句话:使用残差块进行多次卷积,最后接一个全连接层进行分类。

注意上图中的layer1到layer4是由BasicBlock0和BasicBlock1两种残差块组成。实现如下。

  1. // out_channels: 每一个block输出的通道数。
  2. // blocks: 每个layer包含的blocks数.
  3. torch::nn::Sequential ResNetImpl::_make_layer(int64_t out_channels, int64_t blocks, int64_t stride) {
  4. // 1, downsampe: stride or channel
  5. torch::nn::Sequential downsample;
  6. if (stride != 1 || this->in_channels != out_channels * expansion) { // 步长等于2,或者输入通道不等于输出通道,则都是接conv操作,改变输入x的维度
  7. downsample = torch::nn::Sequential(
  8. torch::nn::Conv2d(torch::nn::Conv2dOptions(this->in_channels, out_channels * this->expansion, 1).stride(stride).padding(0).groups(1).bias(false)),
  9. torch::nn::BatchNorm2d(out_channels * this->expansion)
  10. );
  11. }
  12. // 2, layers: first is downsample and others are conv with 1 stride.
  13. torch::nn::Sequential layers;
  14. if (this->is_basic) {
  15. layers->push_back(BasicBlock(this->in_channels, out_channels, stride, downsample)); // 控制是否下采样
  16. this->in_channels = out_channels; // 更新输入通道,以备下次使用
  17. for (int64_t i = 1; i < blocks; i++) { // 剩余的block都是in_channels == out_channels. and stride = 1.
  18. layers->push_back(BasicBlock(this->in_channels, this->in_channels, 1, torch::nn::Sequential())); // 追加多个conv3x3,且不改变维度
  19. }
  20. }
  21. else {
  22. layers->push_back(BottleNeck(this->in_channels, out_channels, stride, downsample, this->groups, this->base_width));
  23. this->in_channels = out_channels * this->expansion; // 更新输入通道,以备下次使用
  24. for (int64_t i = 1; i < blocks; i++) { // 剩余的block都是in_channels == out_channels. and stride = 1.
  25. layers->push_back(BottleNeck(this->in_channels, this->in_channels, 1, torch::nn::Sequential(), this->groups, this->base_width));
  26. }
  27. }
  28. return layers;
  29. }

resnet实现。 

  1. class ResNetImpl : public torch::nn::Module {
  2. public:
  3. ResNetImpl(std::vector<int> layers, int num_classes, std::string model_type,
  4. int groups, int width_per_group);
  5. torch::Tensor forward(torch::Tensor x);
  6. public:
  7. torch::nn::Sequential _make_layer(int64_t in_channels, int64_t blocks, int64_t stride = 1);
  8. private:
  9. int expansion = 1; // 通道扩大倍数,resnet50会用到
  10. bool is_basic = true; // 是BasicBlock,还是BottleNeck
  11. int in_channels = 64; // 记录输入通道数
  12. int groups = 1, base_width = 64;
  13. torch::nn::Conv2d conv1{ nullptr };
  14. torch::nn::BatchNorm2d bn1{ nullptr };
  15. torch::nn::Sequential layer1{ nullptr };
  16. torch::nn::Sequential layer2{ nullptr };
  17. torch::nn::Sequential layer3{ nullptr };
  18. torch::nn::Sequential layer4{ nullptr };
  19. torch::nn::Linear fc{ nullptr };
  20. };
  21. TORCH_MODULE(ResNet);
  1. // layers: resnet18: { 2, 2, 2, 2 }, resnet34: { 3, 4, 6, 3 }, resnet50: { 3, 4, 6, 3 };
  2. ResNetImpl::ResNetImpl(std::vector<int> layers, int num_classes = 1000, std::string model_type = "resnet18", int groups = 1, int width_per_group = 64) {
  3. if (model_type != "resnet18" && model_type != "resnet34") // 即不使用BasicBlock,使用BottleNeck
  4. {
  5. this->expansion = 4;
  6. is_basic = false;
  7. }
  8. this->groups = groups; // 1
  9. this->base_width = base_width; // 64
  10. this->conv1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 64, 7).stride(2).padding(3).groups(1).bias(false)); // scale/2
  11. this->bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(64));
  12. this->layer1 = torch::nn::Sequential(_make_layer(64, layers[0])); // stride=1, scale and channels unchange
  13. this->layer2 = torch::nn::Sequential(_make_layer(128, layers[1], 2)); // stride=2, scale/2. channels double
  14. this->layer3 = torch::nn::Sequential(_make_layer(256, layers[2], 2)); // stride=2, scale/2. channels double
  15. this->layer4 = torch::nn::Sequential(_make_layer(512, layers[3], 2)); // stride=2, scale/2. channels double
  16. this->fc = torch::nn::Linear(512 * this->expansion, num_classes);
  17. register_module("conv1", conv1);
  18. register_module("bn1", bn1);
  19. register_module("layer1", layer1);
  20. register_module("layer2", layer2);
  21. register_module("layer3", layer3);
  22. register_module("layer4", layer4);
  23. register_module("fc", fc);
  24. }
  25. torch::Tensor ResNetImpl::forward(torch::Tensor x) {
  26. // 1,先是两次下采样. (b,3,224,224) -> (b,64,56,56)
  27. x = conv1->forward(x); // (b,3,224,224)->(b,64,112,112)
  28. x = bn1->forward(x);
  29. x = torch::relu(x); // feature 1
  30. x = torch::max_pool2d(x, 3, 2, 1); // k=3,s=2,p=1. (b,64,112,112)->(b,64,56,56)
  31. x = layer1->forward(x); // feature 2. (b,64,56,56)
  32. x = layer2->forward(x); // feature 3. (b,128,28,28)
  33. x = layer3->forward(x); // feature 4. (b,256,14,14)
  34. x = layer4->forward(x); // feature 5. (b,512,7,7)
  35. x = torch::adaptive_avg_pool2d(x, {1, 1}); // (b,512,1,1)
  36. //x = torch::avg_pool2d(x, 7, 1); // (b,512,1,1)
  37. x = x.view({ x.sizes()[0], -1 }); // (b,512)
  38. x = fc->forward(x); // (b,1000)
  39. return torch::log_softmax(x, 1); // score (负无穷,0]
  40. }

 创建resnet18和resnet34。其中layers中的数字代表当前layer中包含的BasicBlock个数。

  1. // 创建不同resnet分类网络的函数
  2. ResNet resnet18(int64_t num_classes) {
  3. std::vector<int> layers = { 2, 2, 2, 2 };
  4. ResNet model(layers, num_classes, "resnet18");
  5. return model;
  6. }
  7. ResNet resnet34(int64_t num_classes) {
  8. std::vector<int> layers = { 3, 4, 6, 3 };
  9. ResNet model(layers, num_classes, "resnet34");
  10. return model;
  11. }

2.3 BottleNeck

resnet系列框架是一样的,不同点是组件有差异。 

resnet18和resnet34都是用BasicBlock组件,而resnet50及以上则使用BottleNeck结构。如下所示。

BottleNeck有三种形式:

(1)BottleNeck0: stride=1, only 4*channels;

(2)BottleNeck1: stride=1, only 4*channels;

(3)BottleNeck2: stride=2, 4*channels and scales/2

  1. // other resnet using BottleNeck
  2. class BottleNeckImpl : public torch::nn::Module {
  3. public:
  4. BottleNeckImpl(int64_t in_channels, int64_t out_channels, int64_t stride,
  5. torch::nn::Sequential downsample, int groups, int base_width);
  6. torch::Tensor forward(torch::Tensor x);
  7. public:
  8. torch::nn::Sequential downsample{ nullptr };
  9. private:
  10. torch::nn::Conv2d conv1{ nullptr };
  11. torch::nn::BatchNorm2d bn1{ nullptr };
  12. torch::nn::Conv2d conv2{ nullptr };
  13. torch::nn::BatchNorm2d bn2{ nullptr };
  14. torch::nn::Conv2d conv3{ nullptr };
  15. torch::nn::BatchNorm2d bn3{ nullptr };
  16. };
  17. TORCH_MODULE(BottleNeck);
  1. // stride: 控制是否下采样,stride=2则是下采样,且downsample将用于对原始输入进行下采样.
  2. // conv1x1+bn+relu, conv3x3+bn+relu, conv1x1+bn+relu
  3. BottleNeckImpl::BottleNeckImpl(int64_t in_channels, int64_t out_channels, int64_t stride,
  4. torch::nn::Sequential downsample, int groups, int base_width) {
  5. this->downsample = downsample;
  6. // 64 * (64 / 64) / 1 = 64, 128 * (64 / 64) / 1 = 128, 128 * (64 / 64) / 2 = 64.
  7. int width = int(out_channels * (base_width / 64.)) * groups; // 64 * (64/64) / 1. 当前的输出通道数
  8. // 1x1 conv
  9. conv1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, width, 1).stride(1).padding(0).groups(1).bias(false));
  10. bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width));
  11. // 3x3 conv
  12. conv2 = torch::nn::Conv2d(torch::nn::Conv2dOptions(width, width, 3).stride(stride).padding(1).groups(groups).bias(false));
  13. bn2 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width));
  14. // 1x1 conv
  15. conv3 = torch::nn::Conv2d(torch::nn::Conv2dOptions(width, out_channels * 4, 1).stride(1).padding(0).groups(1).bias(false));
  16. bn3 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels * 4));
  17. register_module("conv1", conv1);
  18. register_module("bn1", bn1);
  19. register_module("conv2", conv2);
  20. register_module("bn2", bn2);
  21. register_module("conv3", conv3);
  22. register_module("bn3", bn3);
  23. if (!downsample->is_empty()) {
  24. register_module("downsample", downsample);
  25. }
  26. }
  27. torch::Tensor BottleNeckImpl::forward(torch::Tensor x) {
  28. torch::Tensor identity = x.clone();
  29. // conv1x1+bn+relu
  30. x = conv1->forward(x);
  31. x = bn1->forward(x);
  32. x = torch::relu(x);
  33. // conv3x3+bn+relu
  34. x = conv2->forward(x); // if stride==2, scale/2
  35. x = bn2->forward(x);
  36. x = torch::relu(x);
  37. // conv1x1+bn+relu
  38. x = conv3->forward(x); // double channels
  39. x = bn3->forward(x);
  40. if (!downsample->is_empty()) identity = downsample->forward(identity);
  41. x += identity;
  42. x = torch::relu(x);
  43. return x;
  44. }

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/451119
推荐阅读
相关标签
  

闽ICP备14008679号