当前位置:   article > 正文

读取h5文件中模型的权重值_读取h5权重的参数量

读取h5权重的参数量

在keras和tensorflow2中,模型或者模型的权重以h5文件保存。怎么单独读取保存模型或者模型中的权重值呢?这是这篇文章中讨论的问题。

首先我们从网上下载vgg16的模型文件,重命名为‘vgg16’。下载地址为:VGG16权重

首先简单介绍一下h5文件:

h5文件即HDF5文件(Hierarchical Data Format Version 5( HDF5): 层次性数据格式第五版),是一种存储相同类型数值的大数组的机制,适用于可被层次性组织且数据集需要被元数据标记的数据模型,在python中常用的接口模块为 h5py。

HDF5 三大要素:

hdf5 files:能够存储两类数据对象 dataset 和 group 的容器,其操作类似 python 标准的文件操作;File 实例对象本身就是一个组,以 / 为名,是遍历文件的入口。

dataset(array-like):可类比为 Numpy 数组,每个数据集都有一个名字(name)、形状(shape) 和类型(dtype),支持切片操作。

group(folder-like):可以类比为 字典,它是一种像文件夹一样的容器,group 中可以存放 dataset 或者其他的 group,键就是组成员的名称,值就是组成员对象本身(组或者数据集)。HDF5 group在组织数据上像文件的目录,但在操作上HDF5 group为字典

总之,一个HDF5文件是一种存放两类对象的容器:dataset和group。

 Dataset是类似于数组的数据集,而group是类似文件夹一样的容器,存放dataset和其他group。在使用h5py的时候需要牢记一句话:groups类比词典,dataset类比Numpy中的数组。HDF5的dataset虽然与Numpy的数组在接口上很相近,但是支持更多对外透明的存储特征,如数据压缩,误差检测,分块传输。因为h5文件类似python的词典对象,因此我们可以查看所有的键值,示例如下:

  1. import h5py
  2. f = h5py.File('vgg16.h5', 'r')
  3. print('f.keys():', f.keys())
  4. print('f.values():', f.values())
  5. print('f.items():', f.items())
  6. print('f.attrs:', f.attrs)
  7. print('f.attrs.items():', f.attrs.items())
  8. print('f.attrs.keys():', f.attrs.keys())
  9. print("f.attrs['layer_names']:", f.attrs['layer_names'])
  10. print("f['block1_conv1'].attrs.keys():", f['block1_conv1'].attrs.keys())
  11. print("f['block1_conv1'].attrs['weight_names']:", f['block1_conv1'].attrs['weight_names'])
  12. print("f['block1_conv1/block1_conv1_W_1:0']:", f['block1_conv1/block1_conv1_W_1:0'])

输出如下:

  1. f.keys(): <KeysViewHDF5 ['block1_conv1', 'block1_conv2', 'block1_pool', 'block2_conv1', 'block2_conv2', 'block2_pool', 'block3_conv1', 'block3_conv2', 'block3_conv3', 'block3_pool', 'block4_conv1', 'block4_conv2', 'block4_conv3', 'block4_pool', 'block5_conv1', 'block5_conv2', 'block5_conv3', 'block5_pool', 'fc1', 'fc2', 'flatten', 'predictions']>
  2. f.values(): ValuesViewHDF5(<HDF5 file "vgg16.h5" (mode r)>)
  3. f.items(): ItemsViewHDF5(<HDF5 file "vgg16.h5" (mode r)>)
  4. f.attrs: <Attributes of HDF5 object at 140703882956288>
  5. f.attrs.items(): ItemsViewHDF5(<Attributes of HDF5 object at 140703882956288>)
  6. f.attrs.keys(): <KeysViewHDF5 ['layer_names']>
  7. f.attrs['layer_names']: [b'block1_conv1' b'block1_conv2' b'block1_pool' b'block2_conv1'
  8. b'block2_conv2' b'block2_pool' b'block3_conv1' b'block3_conv2'
  9. b'block3_conv3' b'block3_pool' b'block4_conv1' b'block4_conv2'
  10. b'block4_conv3' b'block4_pool' b'block5_conv1' b'block5_conv2'
  11. b'block5_conv3' b'block5_pool' b'flatten' b'fc1' b'fc2' b'predictions']
  12. f['block1_conv1'].attrs.keys(): <KeysViewHDF5 ['weight_names']>
  13. f['block1_conv1'].attrs['weight_names']: [b'block1_conv1_W_1:0' b'block1_conv1_b_1:0']
  14. f['block1_conv1/block1_conv1_W_1:0']: <HDF5 dataset "block1_conv1_W_1:0": shape (3, 3, 3, 64), type "<f4">

h5文件中的权重保存在HDF5 dataset中,怎么读取呢?正如上面所讲,因为h5文件中的dataset是类似于数组的数据集,那我们就可以用读取数组的方法进行读取了!

  1. print(f['block1_conv1/block1_conv1_W_1:0'][:])
  2. print(f['block1_conv1/block1_conv1_W_1:0'].value) #会报错

输出如下:

  1. [[[[ 4.29470569e-01 1.17273867e-01 3.40129584e-02 ... -1.32241577e-01
  2. -5.33475243e-02 7.57738389e-03]
  3. [ 5.50379455e-01 2.08774377e-02 9.88311544e-02 ... -8.48205537e-02
  4. -5.11389151e-02 3.74943428e-02]
  5. [ 4.80015397e-01 -1.72696680e-01 3.75577137e-02 ... -1.27135560e-01
  6. -5.02991639e-02 3.48965675e-02]]
  7. ......
  8. [[-3.50096762e-01 1.38710454e-01 -1.25339806e-01 ... -1.53092295e-01
  9. -1.39917329e-01 -2.65075237e-01]
  10. [-4.85030204e-01 4.23195846e-02 -1.12076312e-01 ... -1.18306056e-01
  11. -1.67058021e-01 -3.22241962e-01]
  12. [-4.18516338e-01 -1.57048807e-01 -1.49133086e-01 ... -1.56839803e-01
  13. -1.42874300e-01 -2.69694626e-01]]]]
  14. Traceback (most recent call last):
  15. File "/home/t.py", line 14, in <module>
  16. print(f['block1_conv1/block1_conv1_W_1:0'].value)
  17. AttributeError: 'Dataset' object has no attribute 'value'

网上的教程有用.value的方法读取dataset中的值的,但是在我这行不通,不知道是不是我的问题。。。所以我只能用[:]像读取数组的值一样读取dataset中的值了。


既然现在我们知道了h5文件中的数据存放形式以及读取方法,那就开始读取vgg16.h5中的权重值吧!先查看一下上代码:

  1. import h5py
  2. f = h5py.File('vgg16.h5','r')
  3. for k, v in f.attrs.items():
  4. print('k:', k)
  5. print('v:', v)
  6. for k1, v1 in f.items():
  7. print('k1:', k1)
  8. print('v1:', v1)

输出:

  1. k: layer_names
  2. v: [b'block1_conv1' b'block1_conv2' b'block1_pool' b'block2_conv1'
  3. b'block2_conv2' b'block2_pool' b'block3_conv1' b'block3_conv2'
  4. b'block3_conv3' b'block3_pool' b'block4_conv1' b'block4_conv2'
  5. b'block4_conv3' b'block4_pool' b'block5_conv1' b'block5_conv2'
  6. b'block5_conv3' b'block5_pool' b'flatten' b'fc1' b'fc2' b'predictions']
  7. k1: block1_conv1
  8. v1: <HDF5 group "/block1_conv1" (2 members)>
  9. k1: block1_conv2
  10. v1: <HDF5 group "/block1_conv2" (2 members)>
  11. k1: block1_pool
  12. v1: <HDF5 group "/block1_pool" (0 members)>
  13. k1: block2_conv1
  14. v1: <HDF5 group "/block2_conv1" (2 members)>
  15. ......(省略)

从上可以看出从f.attrs.items()只能读出模型中各个层或者各个模块的名称,所以要从f.items()中才可以继续读到模型的权重值!

  1. for k1, v1 in f.items():
  2. print('k1:', k1)
  3. print('v1:', v1)
  4. for k2, v2 in v1.items():
  5. print('k2:', k2)
  6. print('v2:', v2)
  7. print('v2:', v2[:].shape)
  8. print('v2:', v2[:])

输出:

  1. k1: block1_conv1
  2. v1: <HDF5 group "/block1_conv1" (2 members)>
  3. k2: block1_conv1_W_1:0
  4. v2: <HDF5 dataset "block1_conv1_W_1:0": shape (3, 3, 3, 64), type "<f4">
  5. v2: (3, 3, 3, 64)
  6. v2: [[[[ 4.29470569e-01 1.17273867e-01 3.40129584e-02 ... -1.32241577e-01
  7. -5.33475243e-02 7.57738389e-03]
  8. [ 5.50379455e-01 2.08774377e-02 9.88311544e-02 ... -8.48205537e-02
  9. -5.11389151e-02 3.74943428e-02]
  10. [ 4.80015397e-01 -1.72696680e-01 3.75577137e-02 ... -1.27135560e-01
  11. -5.02991639e-02 3.48965675e-02]
  12. (。。。省略)

从上可以看出,先遍历字典f.items(),得到h5文件中的group(v1),然后在遍历group字典v1.items(),就可以得到group中的dataset(v2),最后就可以通过v2[:]读取其权重值了!如果在遍历group字典v1.attrs.items(),就可以得到group中的weight_names,如下所示:

  1. for k1, v1 in f.items():
  2. print('k1:', k1)
  3. print('v1:', v1)
  4. for k2, v2 in v1.attrs.items():
  5. print('k2:', k2)
  6. print('v2:', v2)
  7. print('v2:', v2[:].shape)
  8. print('v2:', v2[:])

输出:

  1. k1: block1_conv1
  2. v1: <HDF5 group "/block1_conv1" (2 members)>
  3. k2: weight_names
  4. v2: [b'block1_conv1_W_1:0' b'block1_conv1_b_1:0']
  5. v2: (2,)
  6. v2: [b'block1_conv1_W_1:0' b'block1_conv1_b_1:0']
  7. ......(省略)

所以为了读取权重值,要用.items(),而不是.attrs.items()!因为这样得到的时权重值的名字

但是有时候为了得到权重值的名字和其值,则需要如下代码:

  1. for k1, v1 in f.items():
  2. print('k1:', k1)
  3. print('v1:', v1)
  4. for k2, v2 in v1.attrs.items():
  5. print('k2:', k2)
  6. print('v2:', v2)
  7. for i in v2:
  8. name = k1+'/'+str(i, encoding="utf-8")
  9. print('name:', name)
  10. print('f[name]:', f[name])
  11. print('f[name].shap:', f[name].shape)
  12. print('f[name][:]:', f[name][:])

输出:

  1. k1: block1_conv1
  2. v1: <HDF5 group "/block1_conv1" (2 members)>
  3. k2: weight_names
  4. v2: [b'block1_conv1_W_1:0' b'block1_conv1_b_1:0']
  5. name: block1_conv1/block1_conv1_W_1:0
  6. f[name]: <HDF5 dataset "block1_conv1_W_1:0": shape (3, 3, 3, 64), type "<f4">
  7. f[name].shap: (3, 3, 3, 64)
  8. f[name][:]: [[[[ 4.29470569e-01 1.17273867e-01 3.40129584e-02 ... -1.32241577e-01
  9. -5.33475243e-02 7.57738389e-03]
  10. [ 5.50379455e-01 2.08774377e-02 9.88311544e-02 ... -8.48205537e-02
  11. -5.11389151e-02 3.74943428e-02]
  12. [ 4.80015397e-01 -1.72696680e-01 3.75577137e-02 ... -1.27135560e-01
  13. -5.02991639e-02 3.48965675e-02]]
  14. (。。。。省略)

再介绍另一个可以读取h5文件的deepdish模块,还是用‘vgg16.h'文件为例。代码示例如下:

  1. import deepdish as dd
  2. p = dd.io.load('vgg16.h5')
  3. print(type(p))
  4. for k,v in p.items():
  5. print('k:', k)
  6. print('v:', v)

输出:

  1. <class 'dict'>
  2. k: block1_conv1
  3. v: {'block1_conv1_W_1:0': array([[[[ 4.29470569e-01, 1.17273867e-01, 3.40129584e-02, ...,
  4. -1.32241577e-01, -5.33475243e-02, 7.57738389e-03],
  5. [ 5.50379455e-01, 2.08774377e-02, 9.88311544e-02, ...,
  6. -8.48205537e-02, -5.11389151e-02, 3.74943428e-02],
  7. [ 4.80015397e-01, -1.72696680e-01, 3.75577137e-02, ...,
  8. -1.27135560e-01, -5.02991639e-02, 3.48965675e-02]],
  9. ............}
  10. ......(省略)

从以上代码可以看出通过deepdish.io.load()方法读取’vgg16.h'文件后,就直接j将vgg16.h文件的内容以字典的形式返回!


那返回的字典中的值又是什么类型呢?

  1. import deepdish as dd
  2. p = dd.io.load('vgg16.h5')
  3. print(type(p))
  4. for k,v in p.items():
  5. print('k:', k)
  6. print('type(v):', type(v))

输出:

  1. <class 'dict'>
  2. k: block1_conv1
  3. type(v): <class 'dict'>
  4. k: block1_conv2
  5. type(v): <class 'dict'>
  6. k: block1_pool
  7. type(v): <class 'dict'>
  8. ......(省略)
  9. k: predictions
  10. type(v): <class 'dict'>
  11. k: layer_names
  12. type(v): <class 'numpy.ndarray'>

通过以上代码可以得知,除了最后一个k(layer_names)对应的v的类型不是字典(dict)外,其余全是字典(dict),那么我们还可以通过遍历v这个字典(dict),直接得到每一个权重值!

  1. import deepdish as dd
  2. p = dd.io.load('vgg16.h5')
  3. for k1, v1 in p.items():
  4. print('k1:', k1)
  5. print('v1:', v1)
  6. if k1 != 'layer_names':
  7. for k2, v2 in v1.items():
  8. print('k2:', k2)
  9. print('v2:', v2)
  10. print('name:', k1+'/'+k2)

输出:

  1. k1: block1_conv1
  2. v1: {'block1_conv1_W_1:0': array([[[[ 4.29470569e-01, 1.17273867e-01, 3.40129584e-02, ...,
  3. -1.32241577e-01, -5.33475243e-02, 7.57738389e-03],
  4. [ 5.50379455e-01, 2.08774377e-02, 9.88311544e-02, ...,
  5. -8.48205537e-02, -5.11389151e-02, 3.74943428e-02],
  6. [ 4.80015397e-01, -1.72696680e-01, 3.75577137e-02, ...,
  7. -1.27135560e-01, -5.02991639e-02, 3.48965675e-02]],
  8. [-4.18516338e-01, -1.57048807e-01, -1.49133086e-01, ...,
  9. -1.56839803e-01, -1.42874300e-01, -2.69694626e-01]]]],
  10. dtype=float32), 'block1_conv1_b_1:0': array([ 0.73429835, 0.09340367, 0.06775674, 0.8862966 , 0.25994542,
  11. 0.66426694, -0.01582893, 0.3249065 , 0.68600726, 0.06247932,
  12. 0.58156496, 0.2361475 , 0.69694996, 0.19451167, 0.4858922 ,
  13. 0.09850298, 0.3803252 , 0.66880196, 0.4015123 , 0.90510356,
  14. 0.43166816, 1.302014 , 0.5306885 , 0.48993504], dtype=float32),
  15. ......(省略)

通过以上代码的示例,deepdish模块也可以获取h5文件中模型的权重值以及对应的名字,好像还比使用h5py模块简单!

至此,我们已经可以将h5文件中的权重名称和权重值完全读取出来了!

更多内容可以关注个人做着玩的微信公众号。。。

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/772698
推荐阅读
相关标签
  

闽ICP备14008679号