赞
踩
应用场景如下:
注意保存路径的设置是./weight/
后面的/
千万不要少,要不然可能会出现保存路径的问题。
在前面模型已经compile
完毕,在训练结束后,保存一下模型的参数。保存的文件夹是save_path
的文件夹,要提前设置好,并且是空的,保存完毕后在save_path
文件夹下会生成以下三个文件
file_path
可以设置为主文件mnist.py
文件夹下的空文件夹,比如上图中的weight
用于保存模型参数。
在使用的时候,将模型结构设置为相同后,直接加载参数:Model.load_weights(file_path)
即可使用,但是注意模型结构要相同且已经经过compile
才可以,要不然会报错。等于只是省略了model.fit
的过程.
保存模型
model.save(‘net_model.h5’)
模型加载
new_model=tf.keras.models.load_model(‘net_model.h5’)
Keras使用HDF5标准提供基本保存格式,出于我们的目的,可以将保存的模型视为单个二进制blob。
保存完整的模型非常有用,使我们可以在TensorFlow.js(HDF5, Saved Model)
中加载它们,然后在Web浏览器中训练和运行它们,或者使用TensorFlow Lite(HDF5, Saved Model)
将它们转换为在移动设备上运行。
所以,我们保存整个模型的时候,保存文件的后缀一般都是.h5
应用场景如下:
# 训练模型
save_path = 'net_model.h5'
if os.path.exists(save_path) == False:
# 优化器
adam_optimizer = tf.keras.optimizers.Adam(learning_rate, )
# 编译模型
model.compile(optimizer=adam_optimizer,
loss=tf.keras.losses.sparse_categorical_crossentropy,
metrics=['acc'])
# 模型开始训练
history = model.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs,
validation_data=(x_test, y_test))
model.save('net_model.h5')
else:
model = tf.keras.models.load_model('net_model.h5')
可以看到在与主文件相同的文件夹下产生了一个以.h5
结尾的文件net_model.h5
,这个就是保存的完整的模型及参数。一旦这个文件存在的话,那么就可以直接加载这个模型及其参数,不需要训练。如下第二个图,可以加载到已经存下来的网络结构,和上轮训练的最后一轮的参数。
在这里补充一下,判断一个文件是否存在的方法 以及 判断一个文件夹是否为空的方法:
os.path.exists(test_file.txt)=False
说明该文件夹不存在if os.path.exists(test_file.txt) == False # 该文件不存在的情况下
print(‘目标文件不存在’)
len(os.listdir(tar_dir)==0
说明改文件夹下为空if len(os.listdir(tar_dir)) == 0: # 目标文件夹内容为空的情况下
print(“目标文件夹为空”)
tf.keras.callbacks.ModelCheckpoint(参数如下)
有时候,我们需要保存训练过程中最好的结果,或者想先暂停训练后续再继续训练,这就需要用到checkpoint保存模型了。
save_best_only=True
保存最好的参数,默认False保存最后一个epoch的参数
save_weights_only=True
只保存参数,默认False 保存整个模型
save_path = './checkpoint'
if len(os.listdir(save_path)) == 0:
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=save_path+'/train.ckpt',
verbose=1,
save_best_only=True,
save_weights_only=True,
period=1
)
history = model.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs,
validation_data=(x_test, y_test),callbacks=[cp_callback])
plt.plot(history.epoch, history.history.get('acc'), label='acc')
plt.plot(history.epoch, history.history.get('val_acc'), label='val_acc')
plt.legend()
plt.show()
else:
model.load_weights(save_path+'/train.ckpt')
在没有文件夹或者文件夹为空(尚未保存模型的时候)开始训练,训练过程如下,会自动保存更新损失值较低的,即模型效果更好的参数如下图,如果loss没下降,是不会更新的。
保存完文件夹下内容如下:对比可发现在save_weights_only=True
d的情况下保存在文件夹里的 三个文件类型 与第一种方式相同。
如果不设置save_weights_only=True,
那么保存的是一整个模型,文件格式如下图第一个所示,用法跟第二个保存完整模型结构类似。加载的时候,直接加载模型即可。加载模型路径是save_path+'train.ckpt'
save_path = './checkpoint/'
if len(os.listdir(save_path)) == 0:
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=save_path+'/train.ckpt',
verbose=1,
save_best_only=True,
period=1
)
history = model.fit(x=x_train, y=y_train, batch_size=batch_size, epochs=num_epochs,
validation_data=(x_test, y_test),callbacks=[cp_callback])
plt.plot(history.epoch, history.history.get('acc'), label='acc')
plt.plot(history.epoch, history.history.get('val_acc'), label='val_acc')
plt.legend()
plt.show()
else:
model = tf.keras.models.load_model(save_path+'train.ckpt')
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。