赞
踩
在迁移学习中,要加载预训练模型,如果是torch内置的一些模型网上有很多的方法很简单,但是当加载自己训练完成的模型以后,如何解决最后连接层输入维度不一致的问题,看了好几个帖子都不成功,最后看了一下,既然是将参数param转换成dict,不如直接将不需要的删掉即可。
于是可以用如下方法
pretrained_params = torch.load(path)
net = YOUR_OWN_MODEL()
state_dict = pretrained_params.state_dict()
del state_dict[xxx] #需要被删掉的参数
net.load_state_dict(state_dict, strict=False)
简单有效!
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。