赞
踩
在深度学习的实践中,模型的保存与加载是一项至关重要的技能,它不仅能够帮助我们保留珍贵的训练成果,还便于模型的迁移、部署及后续的调优。本文将以《TensorFlow 2.0深度学习算法实战教材》为依据,深入探讨Keras框架下模型持久化的三种主要方式,分别是张量方式、网络方式(HDF5文件)、以及SavedModel方式,并通过实际代码示例展现其应用。
当我们拥有模型的源代码,并且希望仅保存模型参数时,张量方式最为合适。这种方法仅需调用Model.save_weights()
方法,即可将模型的参数存储为文件,如.ckpt
格式。以下代码片段展示了MNIST模型的参数保存与加载流程:
# 保存模型参数
network.save_weights('weights.ckpt')
print('saved weights.')
# 删除网络对象,模拟重新初始化场景
del network
# 重新创建网络结构
network = Sequential([
layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(32, activation='relu'),
layers.Dense(10)
])
# 编译网络
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
# 加载模型参数
network.load_weights('weights.ckpt')
print('loaded weights!')
请注意,张量方式要求网络结构必须完全相同才能成功加载参数,因此适用于结构固定的场景。
对于那些不想或无法保持模型源代码一致性的场景,可以采用网络方式。通过Model.save()
方法,模型的结构和参数会被打包进一个.h5
文件中,之后只需调用tf.keras.models.load_model()
即可复原整个模型。下面展示了相应的代码示例:
# 保存模型结构与参数
network.save('model.h5')
print('saved total model.')
# 删除网络对象
del network
# 从文件恢复模型
network = tf.keras.models.load_model('model.h5')
这种方式的优势在于,即使没有原始代码,只要有了.h5
文件,就能重建模型,适合模型分享或部署。
SavedModel是TensorFlow针对模型部署推出的标准格式,它不仅包含模型结构和参数,还支持图优化和签名,非常适合生产环境。通过tf.keras.experimental.export_saved_model()
方法,模型可以被保存为SavedModel格式。以下是保存与加载的代码示例:
# 保存为SavedModel格式
tf.keras.experimental.export_saved_model(network, 'model-savedmodel')
print('export saved model.')
# 删除网络对象
del network
# 从SavedModel文件加载模型
network = tf.keras.experimental.load_from_saved_model('model-savedmodel')
SavedModel支持多种平台,如服务器、移动设备甚至是Web,是模型部署时的优选方案。
模型持久化是深度学习项目中不可或缺的一环,正确选择保存与加载方式能够极大提升工作效率。张量方式适合结构确定且需频繁调整参数的场景;网络方式(.h5
)便于模型共享与快速复现;而SavedModel则在模型部署、跨平台应用中展现出独特优势。掌握这三种方法,将使你的深度学习之旅更加顺畅。无论是在科研探索还是产品开发中,合理的模型管理都是确保工作连续性和高效迭代的关键。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。