赞
踩
在训练神经网络时,如果epochs设置的过多,导致最终结束时测试集上模型的准确率比较低,而我们却想保存准确率最高时候的模型参数,这就需要用到Early Stopping以及ModelCheckpoint。
EarlyStopping是用于提前停止训练的callbacks,callbacks用于指定在每个epoch开始和结束的时候进行哪种特定操作。简而言之,就是可以达到当测试集上的loss不再减小(即减小的程度小于某个阈值)的时候停止继续训练。
1.将数据分为训练集和测试集
2.每个epoch结束后(或每N个epoch后): 在测试集上获取测试结果,随着epoch的增加,如果在测试集上发现测试误差上升,则停止训练;
3.将停止之后的权重作为网络的最终参数。
这儿就有一个疑惑,在平常模型训练时,会发现模型的loss值有时会出现降低再上升再下降的情况,难道只要再上升的时候就要停止嘛?上升之后再下降有可能会得到更低的loss值,那么如果只要上升就停止的话,就会得不偿失。现实肯定不是这样的!不能根据一两次的连续降低就判断不再提高。一般的做法是,在训练的过程中,记录到目前为止最好的测试集精度,当连续10次epoch(或者更多次)没达到最佳精度时,则可以认为精度不再提高了。
看图直观感受一下:
优点:只运行一次梯度下降,我们就可以找出w的较小值,中间值和较大值。而无需尝试L2正则化超级参数lambda的很多值。
缺点:不能独立地处理以上两个问题,使得要考虑的东西变得复杂
tf.keras.callbacks.EarlyStopping(
monitor="acc",
min_delta=0,
patience=0,
verbose=0,
mode="max",
baseline=None,
restore_best_weights=False,
)
1.monitor: 监控的数据接口,有’acc’,’val_acc’,’loss’,’val_loss’等等。正常情况下如果有验证集,就用’val_acc’或者’val_loss’。
2.mode: 就’auto’, ‘min’, ‘,max’三个可能。如果知道是要上升还是下降,建议设置一下。例如监控的是’acc’,那么就设置为’max’。
3.min_delta:增大或减小的阈值,只有大于这个部分才算作改善(监控的数据不同,变大变小就不确定)。这个值的大小取决于monitor,也反映了你的容忍程度。
4.patience:能够容忍多少个epoch内都没有改善。patience的大小和learning rate直接相关。在learning rate设定的情况下,前期先训练几次观察抖动的epoch number,patience设置的值应当稍大于epoch number。在learning rate变化的情况下,建议要略小于最大的抖动epoch number。
5.baseline:监控数据的基线值,如果在训练过程中,模型训练结果相比于基线值没有什么改善的话,就停止训练。
函数原型:
tf.keras.callbacks.ModelCheckpoint(filepath,
monitor='val_loss',
verbose=0,
save_best_only=False,
save_weights_only=False,
mode='auto',
period=1)
1.filename:字符串,保存模型的路径,filepath可以是格式化的字符串,里面的占位符将会被epoch值和传入on_epoch_end的logs关键字所填入。
例如:filepath = “weights_{epoch:03d}-{val_loss:.4f}.h5”,则会生成对应epoch和测试集loss的多个文件。
2.monitor:需要监视的值,通常为:val_acc 、 val_loss 、 acc 、 loss四种。
3.verbose:信息展示模式,0或1。为1表示输出epoch模型保存信息,默认为0表示不输出该信息。
4.save_best_only:当设置为True时,将只保存在测试集上性能最好的模型。
5.mode:‘auto’,‘min’,‘max’之一,在save_best_only=True时决定性能最佳模型的评判准则,例如,当监测值为val_acc时,模式应为max,当检测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。
6.save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等)。
7.period:CheckPoint之间的间隔的epoch数。
from tensorflow.keras.callbacks import ModelCheckpoint, Callback, EarlyStopping
earlystopper = EarlyStopping(
monitor='loss',
patience=1,
verbose=1,
mode = 'min')
checkpointer = ModelCheckpoint('best_model.h5',
monitor='val_accuracy',
verbose=0,
save_best_only=True,
save_weights_only=True,
mode = 'max')
train_model = model.fit(train_ds,
epochs=epochs,
validation_data=test_ds,
callbacks=[earlystopper, checkpointer]#<-看这儿)
努力加油a啊
参考链接:
https://blog.csdn.net/zwqjoy/article/details/86677030
https://blog.csdn.net/zengNLP/article/details/94589469
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。