赞
踩
方法一:将1通道输入扩展为3通道,利用 Tensor.expand() 方法
# Method one: keep the pretrained 3-channel ResNet-18 stem and broadcast a
# 1-channel input to 3 channels with Tensor.expand().
model = resnet18(pretrained=False)  # backbone feature-extraction network
# strict=False tolerates key mismatches between checkpoint and model; print them
# instead of silently ignoring them, so a wrong/incompatible checkpoint is noticed.
load_result = model.load_state_dict(
    torch.load('./resnet18-5c106cde.pth', map_location='cpu'), strict=False
)
print('missing keys:', load_result.missing_keys)
print('unexpected keys:', load_result.unexpected_keys)
print(model)
par = summary(model, (3, 224, 224), device='cpu')
print(par)
net = RFNet(model, 1, use_bn=True)  # second argument: number of output classes (num_classes)
# Create a genuinely 1-channel input; the original created a 3-channel tensor,
# which made the expand() below a no-op.
input1 = torch.rand((1, 1, 256, 256))  # 1-channel input
# expand() broadcasts the size-1 channel dim to 3 as a view (no copy); -1 keeps a dim.
input1 = input1.expand(-1, 3, -1, -1)  # -> (1, 3, 256, 256)
print(input1.shape)
input2 = torch.rand((1, 1, 256, 256))
output = net(input1, input2)
print(output.shape)
方法二:修改第一层卷积(conv1)为单通道输入,并相应修改预训练权重字典(state_dict)中的参数
import torch
import torch.nn as nn
import torchvision.models as models
from torchsummary import summary

# Method two: replace the stem convolution with a 1-input-channel version and
# adapt the pretrained conv1 weight in the state dict to the new shape.
resnet18 = models.resnet18(pretrained=False)
resnet18.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

pretrained_dict = torch.load('./resnet/resnet18-5c106cde.pth', map_location='cpu')
# torchvision's ResNet-18 checkpoint stores conv1.weight as (64, 3, 7, 7).
# Instead of replacing it with random noise (which throws away the pretrained
# stem), sum over the input-channel axis to get (64, 1, 7, 7). For a grayscale
# image replicated to RGB, sum_c(w_c * x) == (sum_c w_c) * x, so this kernel
# reproduces the pretrained 3-channel response exactly.
pretrained_dict["conv1.weight"] = pretrained_dict["conv1.weight"].sum(dim=1, keepdim=True)
conv1 = pretrained_dict["conv1.weight"]
print(conv1.shape)
resnet18.load_state_dict(pretrained_dict)
par = summary(resnet18, (1, 224, 224), device='cpu')
print(par)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。