转载https://blog.csdn.net/qq_32172681/article/details/94627366
训练过程中的本质就是在最小化损失,在定义损失之后,接下来就是训练网络参数了,优化器可以让神经网络更快收敛到最小值。
本文介绍几种 tensorflow 常用的优化器函数。
1、GradientDescentOptimizer
梯度下降算法需要用到全部样本,训练速度比较慢。
tf.train.GradientDescentOptimizer( learning_rate, use_locking=False, name="GradientDescent" )
2、AdagradOptimizer
自适应学习率,加入一个正则化项 ,对学习率进行约束,前期学习率小的时候,正则化项大,能够放大梯度;后期,学习率大的时候,正则化项大,可以减少梯度,适合处理稀疏数据。缺点:依赖于全局学习率。
tf.train.AdagradOptimizer( learning_rate, # 学习率 initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value, # 累计初始值 use_locking=False, name="Adagrad" )
3、AdadeltaOptimizer
Adadelta和Adagrad一样,都是自适应学习率,它是对Adagrad的改进,在计算上有所区别,Adagrad累加梯度的平方,依赖于全部学习率,而Adadelta加固定大小的值,不依赖于全局学习率。
tf.train.AdadeltaOptimizer( learning_rate, # 学习率 rho=FLAGS.adadelta_rho, # 衰减率 epsilon=FLAGS.opt_epsilon, # 用于更好的调节梯度更新的常量 use_locking=False, # 若为True,锁住更新操作 name="Adadelta" # 操作名 )
4、RMSPropOptimizer
它是Adagrad的改进、Adadelta的变体,仍然依赖于全局学习率,效果位于两者之间,对于RNN效果较好。
tf.train.RMSPropOptimizer( learning_rate, decay=FLAGS.rmsprop_decay, # 梯度的系数 momentum=FLAGS.rmsprop_momentum, # 动量 epsilon=FLAGS.opt_epsilon, # 用于更好的调节梯度更新的常量 use_locking=False, centered=False, # 如果为True,则通过梯度的估计方差对梯度进行归一化 name="RMSProp" )
5、MomentumOptimizer
就像物理上的动量一样,梯度大的时候,动量大,梯度小的时候,动量也会变小,能够更加平稳、快速地冲向局部最小点。
tf.train.MomentumOptimizer( learning_rate, # 学习率 momentum=FLAGS.momentum, # 动量 use_locking=False, name='Momentum', use_nesterov=False # 若为True,则使用Nesterov动量 )
6、AdamOptimizer
可以看作是带有Momentum的RMSProp,可以将学习率控制在一定范围内,参数较平稳。
tf.train.AdamOptimizer( learning_rate, # 学习率 beta1=FLAGS.adam_beta1, # 一阶矩估计衰减率 beta2=FLAGS.adam_beta2, # 二阶矩估计衰减率 epsilon=FLAGS.opt_epsilon, # 用于更好的调节梯度更新的常量 use_locking=False, name="Adam" )
7、FtrlOptimizer
FTRL 就是正则项为0的SGD算法。
tf.train.FtrlOptimizer( learning_rate, # 学习率 learning_rate_power=FLAGS.ftrl_learning_rate_power, # 控制训练期间学习率衰减方式 initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value, # 累计器初始值 l1_regularization_strength=FLAGS.ftrl_l1, # L1正则化系数 l2_regularization_strength=FLAGS.ftrl_l2, # L2正则化系数 use_locking=False, name="Ftrl", accum_name=None, linear_name=None, l2_shrinkage_regularization_strength=0.0 # 惩罚项 )