学术, 机器学习

当损失函数长时间不再下降时自动停止训练的设置

在机器学习的训练过程中,除了通过多次测试获取经验来设置总的训练轮次,也可以使用额外的代码进行判断,当损失函数长时间不下降时,让它自动停止训练。本篇给出一个代码示例。

需要注意的是:

  • min_delta 和 patience 也是经验参数。如果 min_delta 过大或者 patience 过小,可能会使得程序过早停止训练,模型无法到达最优。如果 min_delta 过小或者 patience 过大,可能不会起到自动停止训练的目的。
  • patience 参数的值取决于 loss_array 的数量。如果 loss_array 只记录每次轮次(epoch)的 loss,那么 patience 可以比较小。如果 loss_array 记录每次迭代次数(iteration)的 loss,那么 patience 应该取大一些。参考:批量训练中迭代次数的计算。个人是推荐使用轮次(epoch)的 loss_array,这样不会受到批量大小(batch size)的影响。

代码示例:

"""
This code is supported by the website: https://www.guanjihuan.com
The newest version of this code is on the web page: https://www.guanjihuan.com/archives/41201
"""

def get_break_signal_from_loss_array(loss_array, patience=100, min_delta=0.001):
    break_signal = 0
    counter = 0
    num = len(loss_array)
    for i0 in range(num):
        if i0 != 0:
            if abs(loss_array[i0]-loss_array[i0-1])<min_delta:
                counter += 1
    if counter >= patience:  # 当损失函数的变化绝对值小于 min_delta 的次数超过 patience 次后,给一个停止信号
        break_signal = 1
    print(counter)  # 查看满足条件的次数
    return break_signal

train_times = 50
for i0 in range(train_times):
    print('Training...')
    loss_array = [10, 3, 1, 0.1, 0.02, 0.003, 0.001, 0.0004, 0.0005, 0.0001, 0.0003]
    break_signal = get_break_signal_from_loss_array(loss_array, patience=4, min_delta=0.001)
    if break_signal == 1:
        break
print('Early stop:', break_signal)

运行结果:

Training...
4
Early stop: 1
413 次浏览

【说明:本站主要是个人的一些笔记和代码分享,内容可能会不定期修改。为了使全网显示的始终是最新版本,这里的文章未经同意请勿转载。引用请注明出处:https://www.guanjihuan.com

发表评论

您的邮箱地址不会被公开。 必填项已用 * 标注

Captcha Code