这是我的第320篇原创文章。
一、引言
keras.callbacks
回调函数是一个函数的合集,会在训练的阶段中所使用。你可以使用回调函数来查看训练模型的内在状态和统计。你可以传递一个列表的回调函数(作为 callbacks 关键字参数)到 Sequential 或 Model 类型的 .fit() 方法。在训练时,相应的回调函数的方法就会被在各自的阶段被调用。
这里有两个关键的点:
(1)状态和统计:其实就是我们希望模型在训练过程中需要从过程中获取什么信息,比如我的损失loss,准确率accuracy等信息就是训练过程中的状态与统计信息;再比如我希望每一个epoch结束之后打印一些相应的自定义提示信息,这也是状态信息。
(2)各自的阶段:模型的训练一般是分为多少个epoch,然后每一个epoch又分为多少个batch,所以这个阶段可以是在每一个epoch之后执行回调函数,也可以是在每一个batch之后执行回调函数。
虽然我们称之为回调“函数”,但事实上Keras的回调函数是一个类,回调函数只是习惯性称呼
keras.callbacks.Callback()
这是回调函数的抽象类,定义新的回调函数必须继承自该类。
系统预定义的回调函数:
二、实现过程
2.1History
History(训练可视化)
keras.callbacks.History()
该回调函数在Keras模型上会被自动调用,History对象即为fit方法的返回值,可以使用history中的存储的acc和loss数据对训练过程进行可视化画图。
示例代码:
class PrintDot(keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs):
if epoch % 100 == 0:
print('')
print('.', end='')
EPOCHS = 1000
model = build_model()
history = model.fit(normed_train_data, train_labels,epochs=EPOCHS, validation_split=0.2, verbose=0,callbacks=[PrintDot()])
hist = pd.DataFrame(history.history)
hist['epoch'] = history.epoch
print('\n', hist.tail())
hist = pd.DataFrame(history.history)
hist['epoch'] = history.epoch
plt.figure()
plt.xlabel('Epoch')
plt.ylabel('Mean Abs Error [MPG]')
plt.plot(hist['epoch'], hist['mae'],
label='Train Error')
plt.plot(hist['epoch'], hist['val_mae'],
label='Val Error')
plt.ylim([0, 5])
plt.legend()
plt.figure()
plt.xlabel('Epoch')
plt.ylabel('Mean Square Error [$MPG^2$]')
plt.plot(hist['epoch'], hist['mse'],
label='Train Error')
plt.plot(hist['epoch'], hist['val_mse'],
label='Val Error')
plt.ylim([0, 20])
plt.legend()
plt.show()
通过history绘制训练过程损失函数的变化:
定义新的回调函数PrintDot,继承keras.callbacks.Callback,传递给fit函数中的callbacks,实现每个完成的时期打印一个点来显示训练进度:
hist:
2.2 EarlyStopping
EarlyStopping
keras.callbacks.EarlyStopping(monitor=’val_loss’, patience=0, verbose=0, mode=’auto’)
当监测值不再改善时,该回调函数将中止训练。
定义回调函数EarlyStopping,传递给fit函数中的callbacks,实现训练早停,代码示例:
model = build_model()
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)
history = model.fit(normed_train_data, train_labels, epochs=EPOCHS,validation_split=0.2, verbose=0, callbacks=[early_stop, PrintDot()])
2.3 ModelCheckpoint
ModelCheckpoint
keras.callbacks.ModelCheckpoint(
filepath,
monitor='val_loss',
verbose=0,
save_best_only=True,
save_weights_only=False,
mode='auto',
period=1
)
该回调函数将在每个epoch后保存模型到filepath。
定义回调函数ModelCheckpoint,传递给fit函数中的callbacks,实现模型的保存与恢复,代码示例:代码示例:
filepath = "model_{epoch:02d}-{val_mse:.2f}.h5"
checkpoint = keras.callbacks.ModelCheckpoint(
filepath=filepath,
monitor='val_loss',
save_best_only=True,
verbose=1,
save_weights_only=True,
period=3
)
history = model.fit(normed_train_data, train_labels, epochs=EPOCHS,validation_split=0.2, verbose=0, callbacks=[checkpoint, PrintDot()])
结果:
作者简介: 读研期间发表6篇SCI数据算法相关论文,目前在某研究院从事数据算法相关研究工作,结合自身科研实践经历持续分享关于Python、数据分析、特征工程、机器学习、深度学习、人工智能系列基础知识与案例。关注gzh:数据杂坛,获取数据和源码学习更多内容。
原文链接:
本文暂时没有评论,来添加一个吧(●'◡'●)