Press "Enter" to skip to content

轻松理解Keras回调

本站内容均来自兴趣收集,如不慎侵害的您的相关权益,请留言告知,我们将尽快删除.谢谢.

 

随着计算机处理能力的提高,人工智能模型的训练时间并没有缩短,主要是人们对模型精确度要求越来越高。为了提升模型精度,人们设计出越来越复杂的深度神经网络模型,喂入越来越海量的数据,导致训练模型也耗时越来越长。这就如同PC产业,虽然CPU遵从摩尔定律,速度越来越快,但由于软件复杂度的提升,我们并没有感觉计算机运行速度有显着提升,反而陷入需要不断升级电脑硬件的怪圈。

 

不知道大家有没有这种经历,准备数据,选择好模型,启动训练,训练了一天之后,却发现效果不理想。这个时候怎幺办?通常调整几个超参数,重新训练,这样折腾几个来回,可能一个星期,甚至一个月的时间就过去了。如果缺少反馈,训练深度学习模型就如同开车没有刹车一样。

 

这个时候,就需要了解训练中的内部状态以及模型的一些信息,在Keras框架中,回调就能起这样的作用。在本文中,我将介绍如何使用Keras回调(如ModelCheckpoint和EarlyStopping)监控和改进深度学习模型。

 

什幺是回调

 

Keras文档给出的定义为:

 

回调是在训练过程的特定阶段调用的一组函数,可以使用回调来获取训练期间内部状态和模型统计信息的视图。

 

你可以传递一个回调列表,同时获取多种训练期间的内部状态,keras框架将在训练的各个阶段回调相关方法。如果你希望在每个训练的epoch自动执行某些任务,比如保存模型检查点(checkpoint),或者希望控制训练过程,比如达到一定的准确度时停止训练,可以定义回调来做到。

 

keras内置的回调很多,我们也可以自行实现回调类,下面先深入探讨一些比较常用的回调函数,然后再谈谈如何自定义回调。

 

EarlyStopping

 

从字面上理解, EarlyStopping 就是提前终止训练,主要是为了防止过拟合。过拟合是机器学习从业者的噩梦,简单说,就是在训练数据集上精度很高,但在测试数据集上精度很低。解决过拟合有多种手段,有时还需要多种手段并用,其中一种方法是尽早终止训练过程。 EarlyStopping 函数有好几种度量参数,通过修改这些参数,可以控制合适的时机停止训练过程。下面是一些相关度量参数:

 

monitor:

监控的度量指标,比如:

acc, val_acc, loss和val_loss等

 

min_delta:

监控值的最小变化。

例如,min_delta = 1表示如果监视值的绝对值变化小于1,则将停止训练过程

 

patience:

没有改善的epoch数,如果过了数个epoch之后结果没有改善,训练将停止

 

restore_best_weights:

如果要在停止后保存最佳权重,请将此参数设置为True

 

下面的代码示例将定义一个跟踪val_loss值的EarlyStopping函数,如果在3个epoch后val_loss没有变化,则停止训练,并在训练停止后保存最佳权重:

 

from keras.callbacks import EarlyStopping
earlystop = EarlyStopping(monitor = 'val_loss',
                          min_delta = 0,
                          patience = 3,
                          verbose = 1,
                          restore_best_weights = True)

 

ModelCheckpoint

 

此回调用于在训练周期中保存模型检查点。保存检查点的作用在于保存训练中间的模型,下次在训练时,可以加载模型,而无需重新训练,减少训练时间。它有以一些相关参数:

 

filepath:

要保存模型的文件路径

 

monitor:

监控的度量指标,比如:

acc, val_acc, loss和val_loss等

 

save_best_only:

如果您不想最新的最佳模型被覆盖,请将此值设置为True

 

save_weights_only: 如果设为True,将只保存模型权重

 

mode:

auto,min或max。

例如,如果监控的度量指标是val_loss,并且想要最小化它,则设置mode =’min’。

 

period:

检查点之间的间隔(epoch数)。

 

示例:

 

from keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint(filepath,
                             monitor='val_loss',
                             mode='min',
                             save_best_only=True,
                             verbose=1)

 

LearningRateScheduler

 

在深度学习中,学习率的选择也是一件让人头疼的事情,值选择小了,可能会收敛缓慢,值选大了,可能会导致震荡,无法到达局部最优点。后来专家们设计出一种自适应的学习率,比如在训练开始阶段,选择比较大的学习率值,加速收敛,训练一段时间之后,选择小的学习率值,防止震荡。 LearningRateScheduler 用于定义学习率的变化策略,参数如下:

 

schedule:

一个函数,以epoch数(整数,从0开始计数)和当前学习速率,作为输入,返回一个新的学习速率作为输出(浮点数)。

 

verbose:

0:

静默模式,1:

详细输出信息。

 

示例代码:

 

from keras.callbacks import LearningRateScheduler
scheduler = LearningRateScheduler(lambda x: 1. / (1. + x), verbose=0)

 

TensorBoard

 

TensorBoard是TensorFlow提供的可视化工具。

 

该回调写入可用于TensorBoard的日志,通过TensorBoard,可视化训练和测试度量的动态图形,以及模型中不同图层的激活直方图。

 

我们可以从命令行启动TensorBoard:

 

tensorboard --logdir = / full_path_to_your_logs

 

该回调的参数比较多,大部分情况下我们只用log_dir这个参数指定log存放的目录,其它参数并不需要了解,使用默认值即可:

 

from keras.callbacks import TensorBoard
tensorboard = TensorBoard(log_dir="logs/{}".format(time()))

 

自定义回调

 

创建自定义回调非常容易,通过扩展基类keras.callbacks.Callback来实现。回调可以通过类属性self.model访问其关联的模型。

 

下面是一个简单的示例,在训练期间保存每个epoch的损失列表:

 

class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.losses = []
    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))
 = Sequential()
model.add(Dense(10, input_dim=784, kernel_initializer='uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
history = LossHistory()
model.fit(x_train, y_train, batch_size=128, epochs=20, verbose=0, callbacks=[history])
print(history.losses)

 

输出结果:

 

[0.66047596406559383, 0.3547245744908703, ..., 0.25953155204159617, 0.25901699725311789]

 

小结

 

限于篇幅原因,本文只是介绍了Keras中常用的回调,通过这些示例,想必你已经理解了Keras中的回调,如果你希望详细了解keras中更多的内置回调,可以访问keras文档:

 

https://keras.io/callbacks/

 

参考:

 

 

Keras Callbacks Explained In Three Minutes

 

Usage of callbacks

 

Monitor progress of your Keras based neural network using Tensorboard

Be First to Comment

发表评论

邮箱地址不会被公开。 必填项已用*标注