学术, 机器学习

常见的梯度下降算法

反向传播是一个计算梯度的过程,它通过链式法则将损失函数关于网络参数的梯度从输出层向输入层传播。梯度下降则是一种基于梯度信息的优化算法,它通过沿着梯度的反方向调整参数值,以减小损失函数的值。

这里列出常见的几个梯度下降算法,按算法的复杂度排列为:SGD < SGDM < NAG < AdaGrad < RMSProp < AdaDelta < Adam < Nadam。目前最常用的是 Adam。

一、SGD

随机梯度下降 SGD(Stochastic Gradient Descent)为基本的梯度下降算法,每次迭代使用单个样本的梯度来更新参数。

PyTorch实现:torch.optim.SGD(params)

PyTorch文档:https://pytorch.org/docs/stable/generated/torch.optim.SGD.html

二、SGDM

带动量的随机梯度下降 SGDM(SGD with Momentum)。在 SGD 的基础上增加了一个动量项,该项考虑了上次的梯度更新的方向和速度,这样可以帮助加速收敛,并且有助于跳出局部最优解。直观的描述是:类似于物理学上的小球滚落过程,有一个动量或者惯性。

PyTorch实现:torch.optim.SGD(params, momentum)

PyTorch文档:同上。

三、NAG

NAG(Nesterov Accelerated Gradient)是一种改进的梯度下降算法。主要思想是在更新参数之前,首先利用动量来预测下一步的位置,然后在这个预测位置处计算梯度,这种做法可以减少动量带来的过度振荡,更有效地朝着最优解方向移动,从而加速收敛。直观的描述是:在小球滚落的过程中,可以提前预知前方情况,这样如果遇到了上升坡面,小球可以在之前提前减速。

PyTorch实现:torch.optim.SGD(params, momentum, nesterov=True)

PyTorch文档:同上。

四、AdaGrad

AdaGrad(Adaptive Gradient) 为自适应学习率的算法,根据每个参数的历史梯度调整学习率。随着训练的进行,学习率会逐渐变小,可能导致学习速度减慢。

PyTorch实现:torch.optim.Adagrad(params)

PyTorch文档:https://pytorch.org/docs/stable/generated/torch.optim.Adagrad.html

五、RMSProp

RMSProp(Root Mean Square Propagation) 也是一种自适应学习率算法,但相比 Adagrad,RMSprop 使用了一个衰减系数来限制历史梯度的影响,修正了 Adagrad 学习率下降过快的问题。

PyTorch实现:torch.optim.RMSprop(params)

PyTorch文档:https://pytorch.org/docs/stable/generated/torch.optim.RMSprop.html

六、AdaDelta

AdaDelta(Adaptive Delta)是 RMSProp 的一种改进版本,主要解决了 RMSProp 中需要手动设置初始学习率的问题。RMSProp 需要手动设置一个初始学习率,而 AdaDelta 完全去掉了学习率的概念,它使用了一个类似于 RMSProp 的梯度平方的指数加权平均来调整学习率,但不需要设置初始学习率。

PyTorch实现:torch.optim.Adadelta(params)

PyTorch文档:https://pytorch.org/docs/stable/generated/torch.optim.Adadelta.html

七、Adam

Adam(Adaptive Moment Estimation)结合了以上的动量和自适应学习率的优势,是一种流行的优化算法,通常表现出很好的性能,对于大多数问题都是一个良好的默认选择。

PyTorch实现:torch.optim.Adam(params)

PyTorch文档:https://pytorch.org/docs/stable/generated/torch.optim.Adam.html

八、Nadam

Nadam 是一种优化算法,是 Nesterov Accelerated Gradient(NAG)和 Adam 的结合体。

PyTorch实现:torch.optim.NAdam(params)

PyTorch文档:https://pytorch.org/docs/stable/generated/torch.optim.NAdam.html

附:两张动图

482 次浏览

【说明:本站主要是个人的一些笔记和代码分享,内容可能会不定期修改。为了使全网显示的始终是最新版本,这里的文章未经同意请勿转载。引用请注明出处:https://www.guanjihuan.com

发表评论

您的邮箱地址不会被公开。 必填项已用 * 标注

Captcha Code