编程开源技术交流,分享技术与知识

网站首页 > 开源技术 正文

【Python深度学习系列】Callbacks使用-可视化、早停、保存与恢复

wxchong 2024-07-22 22:26:46 开源技术 73 ℃ 0 评论

这是我的第320篇原创文章。

一、引言

keras.callbacks

回调函数是一个函数的合集,会在训练的阶段中所使用。你可以使用回调函数来查看训练模型的内在状态和统计。你可以传递一个列表的回调函数(作为 callbacks 关键字参数)到 Sequential 或 Model 类型的 .fit() 方法。在训练时,相应的回调函数的方法就会被在各自的阶段被调用

这里有两个关键的点:

(1)状态和统计:其实就是我们希望模型在训练过程中需要从过程中获取什么信息,比如我的损失loss,准确率accuracy等信息就是训练过程中的状态与统计信息;再比如我希望每一个epoch结束之后打印一些相应的自定义提示信息,这也是状态信息。

(2)各自的阶段:模型的训练一般是分为多少个epoch,然后每一个epoch又分为多少个batch,所以这个阶段可以是在每一个epoch之后执行回调函数,也可以是在每一个batch之后执行回调函数。

虽然我们称之为回调“函数”,但事实上Keras的回调函数是一个类,回调函数只是习惯性称呼

keras.callbacks.Callback()

这是回调函数的抽象类,定义新的回调函数必须继承自该类。

系统预定义的回调函数:

二、实现过程

2.1History

History(训练可视化

keras.callbacks.History()

该回调函数在Keras模型上会被自动调用History对象即为fit方法的返回值,可以使用history中的存储的acc和loss数据对训练过程进行可视化画图。

示例代码:

class PrintDot(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs):
        if epoch % 100 == 0:
            print('')
        print('.', end='')

EPOCHS = 1000
model = build_model()
history = model.fit(normed_train_data, train_labels,epochs=EPOCHS, validation_split=0.2, verbose=0,callbacks=[PrintDot()])
hist = pd.DataFrame(history.history)
hist['epoch'] = history.epoch
print('\n', hist.tail())
hist = pd.DataFrame(history.history)
hist['epoch'] = history.epoch


plt.figure()
plt.xlabel('Epoch')
plt.ylabel('Mean Abs Error [MPG]')
plt.plot(hist['epoch'], hist['mae'],
         label='Train Error')
plt.plot(hist['epoch'], hist['val_mae'],
         label='Val Error')
plt.ylim([0, 5])
plt.legend()


plt.figure()
plt.xlabel('Epoch')
plt.ylabel('Mean Square Error [$MPG^2$]')
plt.plot(hist['epoch'], hist['mse'],
         label='Train Error')
plt.plot(hist['epoch'], hist['val_mse'],
         label='Val Error')
plt.ylim([0, 20])
plt.legend()
plt.show()

通过history绘制训练过程损失函数的变化:

定义新的回调函数PrintDot,继承keras.callbacks.Callback,传递给fit函数中的callbacks,实现每个完成的时期打印一个点来显示训练进度:

hist:

2.2 EarlyStopping

EarlyStopping

keras.callbacks.EarlyStopping(monitor=’val_loss’, patience=0, verbose=0, mode=’auto’)

当监测值不再改善时,该回调函数将中止训练。

定义回调函数EarlyStopping,传递给fit函数中的callbacks,实现训练早停,代码示例:

model = build_model()
early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)
history = model.fit(normed_train_data, train_labels, epochs=EPOCHS,validation_split=0.2, verbose=0, callbacks=[early_stop, PrintDot()])

2.3 ModelCheckpoint

ModelCheckpoint

keras.callbacks.ModelCheckpoint(
filepath,
monitor='val_loss',
verbose=0,
save_best_only=True,
save_weights_only=False,
mode='auto',
period=1
)

该回调函数将在每个epoch后保存模型到filepath。

定义回调函数ModelCheckpoint,传递给fit函数中的callbacks,实现模型的保存与恢复,代码示例:代码示例:

filepath = "model_{epoch:02d}-{val_mse:.2f}.h5"
checkpoint = keras.callbacks.ModelCheckpoint(
    filepath=filepath,
    monitor='val_loss',
    save_best_only=True,
    verbose=1,
    save_weights_only=True,
    period=3
)
history = model.fit(normed_train_data, train_labels, epochs=EPOCHS,validation_split=0.2, verbose=0, callbacks=[checkpoint, PrintDot()])

结果:

作者简介: 读研期间发表6篇SCI数据算法相关论文,目前在某研究院从事数据算法相关研究工作,结合自身科研实践经历持续分享关于Python、数据分析、特征工程、机器学习、深度学习、人工智能系列基础知识与案例。关注gzh:数据杂坛,获取数据和源码学习更多内容。

原文链接:

【Python深度学习系列】Keras回调函数Callbacks使用详解-训练过程可视化、早停、保存恢复(案例+源码)

Tags:

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表