当前位置:   article > 正文

【TensorFlow深度学习】模型持久化:保存与加载的最佳实践

【TensorFlow深度学习】模型持久化:保存与加载的最佳实践

在深度学习的实践中,模型的保存与加载是一项至关重要的技能,它不仅能够帮助我们保留珍贵的训练成果,还便于模型的迁移、部署及后续的调优。本文将以《TensorFlow 2.0深度学习算法实战教材》为依据,深入探讨Keras框架下模型持久化的三种主要方式,分别是张量方式、网络方式(HDF5文件)、以及SavedModel方式,并通过实际代码示例展现其应用。

张量方式:轻量级参数保存

当我们拥有模型的源代码,并且希望仅保存模型参数时,张量方式最为合适。这种方法仅需调用Model.save_weights()方法,即可将模型的参数存储为文件,如.ckpt格式。以下代码片段展示了MNIST模型的参数保存与加载流程:

# 保存模型参数
network.save_weights('weights.ckpt')
print('saved weights.')

# 删除网络对象,模拟重新初始化场景
del network

# 重新创建网络结构
network = Sequential([
    layers.Dense(256, activation='relu'),
    layers.Dense(128, activation='relu'),
    layers.Dense(64, activation='relu'),
    layers.Dense(32, activation='relu'),
    layers.Dense(10)
])

# 编译网络
network.compile(optimizer=optimizers.Adam(lr=0.01),
              loss=tf.losses.CategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# 加载模型参数
network.load_weights('weights.ckpt')
print('loaded weights!')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

请注意,张量方式要求网络结构必须完全相同才能成功加载参数,因此适用于结构固定的场景。

网络方式:结构与参数一体化保存

对于那些不想或无法保持模型源代码一致性的场景,可以采用网络方式。通过Model.save()方法,模型的结构和参数会被打包进一个.h5文件中,之后只需调用tf.keras.models.load_model()即可复原整个模型。下面展示了相应的代码示例:

# 保存模型结构与参数
network.save('model.h5')
print('saved total model.')

# 删除网络对象
del network

# 从文件恢复模型
network = tf.keras.models.load_model('model.h5')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

这种方式的优势在于,即使没有原始代码,只要有了.h5文件,就能重建模型,适合模型分享或部署。

SavedModel方式:跨平台部署的首选

SavedModel是TensorFlow针对模型部署推出的标准格式,它不仅包含模型结构和参数,还支持图优化和签名,非常适合生产环境。通过tf.keras.experimental.export_saved_model()方法,模型可以被保存为SavedModel格式。以下是保存与加载的代码示例:

# 保存为SavedModel格式
tf.keras.experimental.export_saved_model(network, 'model-savedmodel')
print('export saved model.')

# 删除网络对象
del network

# 从SavedModel文件加载模型
network = tf.keras.experimental.load_from_saved_model('model-savedmodel')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

SavedModel支持多种平台,如服务器、移动设备甚至是Web,是模型部署时的优选方案。

小结

模型持久化是深度学习项目中不可或缺的一环,正确选择保存与加载方式能够极大提升工作效率。张量方式适合结构确定且需频繁调整参数的场景;网络方式(.h5)便于模型共享与快速复现;而SavedModel则在模型部署、跨平台应用中展现出独特优势。掌握这三种方法,将使你的深度学习之旅更加顺畅。无论是在科研探索还是产品开发中,合理的模型管理都是确保工作连续性和高效迭代的关键。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/670040
推荐阅读
相关标签
  

闽ICP备14008679号