当前位置:   article > 正文

pytorch修改resnet18 输入通道_resnet 如何将输入的3通道改为1通道

resnet 如何将输入的3通道改为1通道

方法一:扩张1通道为3通道,利用torch.expand()方法

    model = resnet18(pretrained=False) # 主干提取网络
    model.load_state_dict(torch.load('./resnet18-5c106cde.pth'), strict=False)
    print(model)
    par = summary(model, (3, 224, 224), device='cpu')
    print(par)

    net = RFNet( model, 1, use_bn=True) # 输出类别 num_classes
    # print(model)
    input1 = torch.rand((1,3,256,256)) # 输入通道为1
    input1 = input1.expand(1,3,256,256) # 扩展为3通道

    print(input1.shape)
    input2 = torch.rand((1,1,256, 256))
    output = net(input1,input2)
    print(output.shape)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

方法二:修改字典参数

import torchvision.models as models
import torch
import torch.nn as nn
from torchsummary import summary

resnet18 = models.resnet18(pretrained=False)
resnet18.conv1= nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,bias=False)

# print(resnet18)
pretrained_dict = torch.load('./resnet/resnet18-5c106cde.pth')
# for k, v in pretrained_dict.items():
#     print(k)

x = torch.rand(64, 1, 7, 7)
pretrained_dict["conv1.weight"] = x

conv1 = pretrained_dict["conv1.weight"]
print(conv1.shape)
resnet18.load_state_dict(pretrained_dict)

# print(resnet18)
par = summary(resnet18, (1, 224, 224),device='cpu')
print(par)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号