赞
踩
//这里的想法是:拿出预训练权重(字典)的key和value,通过获取我们自己网络中与预训练权重中网络层名称一样的的层,拿到相同个数的网络层,删除不一致的
pre_dict = {k: v for k, v in pre_weights.items() if net.state_dict()[k].numel() == v.numel()}
missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False) //将特征提取层的权重送进网络,这里的strict设置为False后,就不用预训练权重的网络结构和我们自己的网络完全key值一致
而第三种方法,在创建网络时候,更改最后的全连接层节点个数,直接net.load\_state\_dict()方法载入会报错的
net = MobileNetV2()
net.load_state_dict(torch.load(pre_trained_pth), strict=False)
//方法三:
in_channel = net.fc.in_feacture
net.fc = nn.Linear(in_channel, 5) //这里为什么是.fc,道理同下,下面具体讲解了
>
> 这里另外说一下为什么删除预训练权重中全连接层的参数,上方代码的判断语句中必须是“fc”???
> 其实不然,取决于你搭建网络时类变量名称的定义,可以通过来查看每个网络层的名称:
>
>
>
net = MobileNe
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。