Part of Advances in Neural Information Processing Systems 9 (NIPS 1996)
Genevieve Orr, Todd Leen
We present an algorithm for fast stochastic gradient descent that uses a nonlinear adaptive momentum scheme to optimize the late-time convergence rate. The algorithm makes effective use of curvature information, requires only O(n) storage and computation, and delivers convergence rates close to the theoretical optimum. We demonstrate the technique on linear and large nonlinear backprop networks.
Improving Stochastic Search
Learning algorithms that perform gradient descent on a cost function can be formulated in either stochastic (on-line) or batch form. The stochastic version takes the form
$$w_{t+1} = w_t + \mu_t\, G(w_t, x_t) \qquad (1)$$

where $w_t$ is the current weight estimate, $\mu_t$ is the learning rate, $G$ is minus the instantaneous gradient estimate, and $x_t$ is the input at time $t$.¹ One obtains the corresponding batch-mode learning rule by taking $\mu$ constant and averaging $G$ over all $x$.
Stochastic learning provides several advantages over batch learning. For large datasets the batch average is expensive to compute. Stochastic learning eliminates the averaging. The stochastic update can be regarded as a noisy estimate of the batch update, and this intrinsic noise can reduce the likelihood of becoming trapped in poor local optima [1, 2].
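To make rule (1) and its batch counterpart concrete, here is a minimal sketch on a two-parameter linear least-squares model. The model, the synthetic data, and the names `G`, `batch_step`, `stochastic_run`, and `least_squares` are illustrative assumptions, not taken from the paper; inputs are drawn with replacement, as in footnote 1. With a constant learning rate, the batch iterate settles on the least-squares solution, while the stochastic iterate keeps fluctuating around it.

```python
import random

def G(w, x, y):
    """Minus the instantaneous gradient of 0.5 * (y - w0 - w1*x)**2 w.r.t. (w0, w1)."""
    r = y - (w[0] + w[1] * x)
    return (r, r * x)

def batch_step(w, data, mu):
    """Batch rule: constant mu, G averaged over the whole training set."""
    gs = [G(w, x, y) for x, y in data]
    g0 = sum(g[0] for g in gs) / len(gs)
    g1 = sum(g[1] for g in gs) / len(gs)
    return (w[0] + mu * g0, w[1] + mu * g1)

def stochastic_run(data, mu, steps, seed=0):
    """Rule (1) with constant mu: one example per step, sampled with replacement."""
    rng = random.Random(seed)
    w = (0.0, 0.0)
    for _ in range(steps):
        x, y = rng.choice(data)
        g = G(w, x, y)
        w = (w[0] + mu * g[0], w[1] + mu * g[1])
    return w

def least_squares(data):
    """Closed-form minimizer of the average cost, used as a reference."""
    n = len(data)
    mx = sum(x for x, _ in data) / n
    my = sum(y for _, y in data) / n
    sxy = sum((x - mx) * (y - my) for x, y in data)
    sxx = sum((x - mx) ** 2 for x, _ in data)
    w1 = sxy / sxx
    return (my - w1 * mx, w1)

# Hypothetical data: y = 1 + 2x plus a fixed +/-0.1 residual pattern.
xs = [i / 5 - 1 for i in range(11)]
data = [(x, 1 + 2 * x + 0.1 * (-1) ** i) for i, x in enumerate(xs)]

w_star = least_squares(data)
w_batch = (0.0, 0.0)
for _ in range(200):
    w_batch = batch_step(w_batch, data, mu=0.5)
w_sgd = stochastic_run(data, mu=0.05, steps=20_000)
print("least squares:", w_star)
print("batch:        ", w_batch)
print("stochastic:   ", w_sgd)
```

The leftover fluctuation in `w_sgd` is the intrinsic noise of the stochastic update; it is also the reason the learning rate must eventually be reduced.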
¹ We assume that the inputs are i.i.d. This is achieved by random sampling with replacement from the training data.
Using Curvature Information for Fast Stochastic Search
The noise must be reduced late in training to allow the weights to converge. After the weights settle within the basin of a local optimum $w_*$, learning-rate annealing allows convergence of the weight error $v \equiv w - w_*$. It is well known that the expected squared weight error $E[|v|^2]$ decays at its maximal rate, $\propto 1/t$, with the annealing schedule $\mu_0/t$. Furthermore, to achieve this rate one must have $\mu_0 > \mu_{\mathrm{crit}} = 1/(2\lambda_{\min})$, where $\lambda_{\min}$ is the smallest eigenvalue of the Hessian at $w_*$ [3, 4, 5, and references therein]. Finally, in one dimension the optimal $\mu_0$, which gives the lowest possible value of $E[|v|^2]$, is $\mu_0 = 1/\lambda$, where $\lambda$ is the curvature of the cost at $w_*$. In multiple dimensions the optimal learning rate matrix is $\mu(t) = (1/t)\,\mathcal{H}^{-1}$, where $\mathcal{H}$ is the Hessian at the local optimum. Incorporating this curvature information into stochastic learning is difficult for two reasons. First, the Hessian is not available, since the point of stochastic learning is to avoid averaging over the training data. Second, even if the Hessian were available, optimal learning requires its inverse, which is prohibitively expensive to compute.²
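These annealing results can be checked numerically on a one-dimensional quadratic cost with curvature $\lambda$, assuming zero-mean gradient noise of variance $\sigma^2$ that is independent from step to step (an assumption made for illustration). Under that model $E[v^2]$ obeys the exact recursion $E_{t+1} = (1-\mu_t\lambda)^2 E_t + \mu_t^2\sigma^2$, so no random sampling is needed; the helper name `annealed_sq_error` is hypothetical. For $\mu_0 > \mu_{\mathrm{crit}}$ the scaled error $t\,E[v^2]$ approaches $\mu_0^2\sigma^2/(2\mu_0\lambda - 1)$, which is smallest at $\mu_0 = 1/\lambda$; below $\mu_{\mathrm{crit}}$ it keeps growing. The last lines treat a diagonal Hessian, again assuming noise independent across eigendirections, so the recursions decouple: the matrix schedule $\mathcal{H}^{-1}/t$ gives each direction its own optimal $\mu_0 = 1/\lambda_i$, whereas a single scalar $\mu_0$ large enough to exceed $\mu_{\mathrm{crit}}$ is too large for the stiff direction.

```python
def annealed_sq_error(mu0, lam, sigma2, v0_sq, steps):
    """E[v^2] after `steps` updates v <- v - (mu0/t) * (lam*v + noise), t = 1..steps.

    Assumes a 1-D quadratic cost with curvature lam and zero-mean gradient noise
    of variance sigma2, independent across steps, so the second moment obeys an
    exact recursion and no sampling is needed.
    """
    e = v0_sq
    for t in range(1, steps + 1):
        mu = mu0 / t
        e = (1.0 - mu * lam) ** 2 * e + mu * mu * sigma2
    return e

lam, sigma2, T = 2.0, 1.0, 10_000
for a in (0.25, 1.0, 2.0):  # a = mu0 * lam; the critical value is a = 0.5
    scaled = T * annealed_sq_error(a / lam, lam, sigma2, 1.0, T)
    print(f"mu0*lam = {a}: T * E[v^2] = {scaled:.4f}")

# Diagonal Hessian with eigenvalues lams (assumed noise independent per direction).
lams = (0.5, 5.0)
# mu(t) = H^{-1} / t: each direction gets mu0 = 1 / lam_i.
matrix_rule = sum(annealed_sq_error(1.0 / l, l, sigma2, 1.0, T) for l in lams)
# One scalar mu0 = 1 / lam_min, which satisfies mu0 > mu_crit for every direction.
scalar_rule = sum(annealed_sq_error(1.0 / min(lams), l, sigma2, 1.0, T) for l in lams)
print("H^-1/t schedule:", T * matrix_rule, " scalar mu0 = 1/lam_min:", T * scalar_rule)
```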
The primary result of this paper is that one can construct an algorithm that behaves optimally, i.e. as if one had incorporated the inverse of the full Hessian, without the storage or computational burden. The algorithm, which requires only O(n) storage and computation (n = number of weights in the network), uses an adaptive momentum parameter, extending our earlier work [7] to fully nonlinear problems. We demonstrate the performance on several large backprop networks trained with large datasets.
Implementations of stochastic learning typically use a constant learning rate during the early part of training (what Darken and Moody [4] call the search phase) to obtain exponential convergence towards a local optimum, and then switch to annealed learning (called the converge phase). We use Darken and Moody's adaptive search-then-converge (ASTC) algorithm to determine the point at which to switch to $1/t$ annealing. ASTC was originally conceived as a means to ensure $\mu_0 > \mu_{\mathrm{crit}}$ during the annealed phase, and we compare its performance with adaptive momentum as well. Finally, we provide a comparison with conjugate gradient optimization.
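A fixed-switch version of the search-then-converge idea can be sketched as follows: hold the learning rate at $\mu_0$ until a switch time $\tau$, then anneal as $\mu_0\tau/t$, so the annealing constant (the numerator of the $1/t$ schedule) is $\mu_0\tau$. This is a simplification under stated assumptions: $\tau$ is chosen by hand, ASTC's adaptive rule for picking the switch point is not reproduced, and the function names are hypothetical. The helper `meets_critical_rate` tests whether the annealing constant exceeds $\mu_{\mathrm{crit}} = 1/(2\lambda_{\min})$.

```python
def search_then_converge_rate(t, mu0, tau):
    """Learning rate at step t (t >= 1) for a hypothetical fixed-switch schedule.

    Search phase (t <= tau): constant mu0.
    Converge phase (t > tau): mu0 * tau / t, i.e. 1/t annealing with constant
    mu0 * tau, which keeps the schedule continuous at t = tau.
    """
    if t <= tau:
        return mu0
    return mu0 * tau / t

def meets_critical_rate(mu0, tau, lam_min):
    """True if the annealing constant mu0 * tau exceeds mu_crit = 1 / (2 * lam_min)."""
    return mu0 * tau > 1.0 / (2.0 * lam_min)

mu0, tau, lam_min = 0.05, 200, 0.01
print([search_then_converge_rate(t, mu0, tau) for t in (1, 100, 200, 400, 800)])
print(meets_critical_rate(mu0, tau, lam_min))  # False: constant 10 < mu_crit = 50
```

Raising either $\mu_0$ or $\tau$ is one way to satisfy the critical condition, though overshooting $1/\lambda$ by a wide margin increases the late-time error again.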
1 Momentum in Stochastic Gradient Descent
The adaptive momentum algorithm we propose was suggested by earlier work on convergence rates for annealed learning with constant momentum. In this section we summarize the relevant results of that work.
Extending (1) to include momentum gives the learning rule
$$w_{t+1} = w_t + \mu_t\, G(w_t, x_t) + \beta\,(w_t - w_{t-1}) \qquad (2)$$

where $\beta$ is the momentum parameter, constrained so that $0 < \beta < 1$. Analysis of the dynamics of the expected squared weight error $E[|v|^2]$ with $\mu_t = \mu_0/t$ learning rate annealing [7, 8] shows that at late times, learning proceeds as for the algorithm without momentum, but with a scaled or effective learning rate

$$\mu_{\mathrm{eff}} = \frac{\mu_0}{1 - \beta}$$
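The effective-learning-rate claim can be probed on a one-dimensional quadratic cost with curvature $\lambda$, assuming gradient noise $\xi_t$ that is zero-mean, of variance $\sigma^2$, and independent across steps (an illustrative model, not the paper's experimental setup). Writing $u_t = v_t - v_{t-1}$, rule (2) becomes $u_{t+1} = \beta u_t - \mu_t(\lambda v_t + \xi_t)$ and $v_{t+1} = v_t + u_{t+1}$, so the second moments $E[v^2]$, $E[vu]$, $E[u^2]$ follow an exact linear recursion. If momentum with $\mu_0/t$ annealing behaves like plain annealing with $\mu_{\mathrm{eff}}$, then $t\,E[v^2]$ should approach $\mu_{\mathrm{eff}}^2\sigma^2/(2\mu_{\mathrm{eff}}\lambda - 1)$ whenever $\mu_{\mathrm{eff}}\lambda > 1/2$. The function name `momentum_sq_error` is hypothetical.

```python
def momentum_sq_error(mu0, beta, lam, sigma2, v0, steps):
    """E[v^2] after `steps` updates of rule (2) with mu_t = mu0 / t.

    Assumes a 1-D quadratic cost with curvature lam and gradient noise xi_t that
    is zero-mean, variance sigma2, independent across steps. With u_t = v_t - v_{t-1}:
        v_{t+1} = (1 - mu_t*lam) v_t + beta u_t - mu_t xi_t
        u_{t+1} = beta u_t - mu_t*lam v_t - mu_t xi_t
    so P = E[v^2], Q = E[v u], R = E[u^2] follow an exact recursion.
    """
    P, Q, R = v0 * v0, 0.0, 0.0  # start at rest: v_0 = v_{-1}
    for t in range(1, steps + 1):
        mu = mu0 / t
        k = 1.0 - mu * lam
        noise = mu * mu * sigma2
        # Simultaneous update: every right-hand side uses the old P, Q, R.
        P, Q, R = (
            k * k * P + 2.0 * beta * k * Q + beta * beta * R + noise,
            beta * (1.0 - 2.0 * mu * lam) * Q - mu * lam * k * P + beta * beta * R + noise,
            beta * beta * R - 2.0 * beta * mu * lam * Q + (mu * lam) ** 2 * P + noise,
        )
    return P

lam, sigma2, T = 1.0, 1.0, 100_000
results = {}
for mu0, beta in ((0.5, 0.5), (0.3, 0.8)):
    mu_eff = mu0 / (1.0 - beta)
    # Late-time value for plain 1/t annealing with constant mu_eff (needs mu_eff*lam > 1/2).
    predicted = mu_eff ** 2 * sigma2 / (2.0 * mu_eff * lam - 1.0)
    measured = T * momentum_sq_error(mu0, beta, lam, sigma2, 1.0, T)
    results[(mu0, beta)] = (measured, predicted)
    print(f"mu0={mu0}, beta={beta}: T*E[v^2] = {measured:.4f}, predicted {predicted:.4f}")
```

Iterating the moments rather than sampling noise makes the comparison deterministic, so the late-time agreement can be read off from a single run instead of an average over many.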