赞
踩
前阵子做了个以图搜图特征编码模型啊。(详情看上一篇)
但是由于图库数据较大(上亿数据),所以2048维的特征编码存储量太大,一个特征8KB,用户并发起来服务器也够呛,而且java那边相似度计算也慢。
由于图库里面的图形都比较简单,老大觉得512够用了,要我修改网络输出到512维的特征编码。
但是模型网络那边提供的预训练模型,Resnet50只有输出层是2048维的。我们也不想换Resnet18(可能会较大的降低精度)。
因此我只能够再Resnet50的预训练权重包上面下手。
首先,我先看我的网络结构。(在这个位置:cirtorch/networks/imageretrievalnet.py)
根据我的网络初始化设置,我用的预训练包是Resnet50_w,也就是再Resnet50后面加了一层全连接网络。输出1 * 2048的特征编码。(如果可视化的看网络结构可以用https://netron.app 如下图所示。)
因此我只要修改全连接层的输入输出就行,输入保持不变2048(因为Resnet最后一层输出是 2048),输出改为512。
if whitening:
whiten = nn.Linear(2048, dim, bias=True)
# 这里再上面设置了dim = 512
修改完我们的网络结构后,我们就要用这个网络加载权重,可是我们的权重包就跟上图一样,最后一层搭配不上。那我们需要怎么加载网络呢。
如下图所示,我们希望我们的模型权重最后一层长这样。
那就好办了。只需要再加载模型后,将weights的第一维砍掉3/4,同样bias也砍掉3/4。就可以完成加载权重了。如下代码。
temp_state = torch.load('weights/resnet50_dim2048.pth.tar')
temp_state['state_dict']['whiten.weight'] = temp_state['state_dict']['whiten.weight'][0::4, ::]
temp_state['state_dict']['whiten.bias'] = temp_state['state_dict']['whiten.bias'][0::4]
修改完后,要记得用新的预训练模型重新进行训练。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。