当前位置:   article > 正文

Tensorflow2 加载别人模型权重失败时的一种解决方法:每一层分别设置值_valueerror: layer count mismatch when loading weig

valueerror: layer count mismatch when loading weights from file. model expec

有时自己建立模型导入别人模型参数时会因为版本等各个问题导致导入失败,这时可以一层一层的给网络赋值权重。

 本人在加载某个模型的权重通过load_weights方式加载时报错,但是我拥有别人已加载完权重的model。所以试了下手动一层一层的加载权重,结果还是可行的,就是比较麻烦。

tensorflow2的模型权重可以以set_weights方式进行设置权重。比如要改以下自己定义的模型每层加载权重:

  1. from tensorflow.keras import Sequential, layers, Model
  2. import numpy as np
  3. class SetWeights(Model):
  4. def __init__(self):
  5. super(SetWeights, self).__init__()
  6. self.c = layers.Conv2D(filters = 32,kernel_size=(3,3),strides=(1,1),padding='same')
  7. self.bn = layers.BatchNormalization()
  8. self.ac = layers.LeakyReLU(0.1)
  9. def call(self, inputs, training=None, mask=None):
  10. x = self.ac(self.bn(self.c(inputs)))
  11. return x

首先初始化模型,查看可训练参数结构,如下:

  1. model = SetWeights()
  2. model.build((None,32,32,3))
  3. # 查看所有可训练的参数
  4. # 以列表的方式保存,长度4,包含卷积层的w,b,BN的gamma,beta
  5. model.trainable_variables

通过以下方法设置权重

  1. # model.属性.set_weights方法
  2. # 需要注意的是传入的参数是一个list
  3. model.c.set_weights()

比如我们实例化两个model (model1相当于自己建立的模型,model2相当于拥有的别人已加载好的模型,因为结构顺序或者版本等原因,不能save_weights再load_weights加载别人的权重)

将model2的参数设置到model1中方法:  

  1. model1 = SetWeights()
  2. model1.build((None,32,32,3))
  3. model2 = SetWeights()
  4. model2.build((None,32,32,3))
  5. # 卷积层的权重设置(w、b)
  6. # 传入的list顺序是kernel、bias
  7. model1.c.set_weights([model2.trainable_variables[0].numpy(),model2.trainable_variables[1].numpy()])
  8. # BN层参数设置
  9. # 传入的list顺序是gamma、beta、moving_mean、moving_variance
  10. # 获得BN层每个参数长度
  11. bn_len = model2.trainable_variables[2].shape[0]
  12. model1.bn.set_weights([model2.trainable_variables[2].numpy(),model2.trainable_variables[3].numpy(),np.zeros((bn_len),dtype=np.float32),np.ones((bn_len),dtype=np.float32)])
  13. # 这样model1里的权重已经变成model2中的了

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/117901
推荐阅读
相关标签
  

闽ICP备14008679号