Why are LSTMs better than Elman RNNs?
Recurrent neural networks (RNNs) are useful for modelling sequential data of variable length, for example the words in a speech or stock price movements over a period of time. In this post, we assume that the reader is familiar with two of the most popular RNNs, the Elman RNN and the Long Short-Term Memory network (LSTM).
RNNs become harder to train as the length of the input sequence increases, which is why it is common to train them with truncated backpropagation through time.
The aim of this post is to explain why LSTMs can handle longer sequences than Elman RNNs.
General RNN
RNNs take the general form \(h_{t+1} = f(x_{t},h_{t})\). The hidden state \(h\) at each time step \(t\) is a function of the hidden state and the input \(x\) at the previous time step.
By differentiating the expression¹, we realize that the gradient of the loss at \(h_{t}\) is the product of the gradient of the loss at \(h_{t+1}\) and the partial derivative of \(f\) with respect to \(h_{t}\).
\[\delta h_t = \delta h_{t+1} \cdot f'(x_{t},h_{t})\]

If \(f' > 1\) on average, we end up with the exploding gradients problem, since repeated multiplication by a number above \(1\) tends to infinity.
If \(f' < 1\) on average, we end up with the vanishing gradients problem, since repeated multiplication by a positive number smaller than \(1\) tends to zero.
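To make this concrete, here is a minimal toy sketch (not a real RNN; the scalar stand-in for \(f'\) and all constants are invented for illustration) that repeatedly multiplies a gradient by a local derivative whose average is just below or just above \(1\).

```python
import numpy as np

# Toy illustration only: treat f' as a scalar whose average magnitude is just
# below or just above 1, and backpropagate a gradient of 1.0 through 100 steps.
# The distribution and constants are made up for illustration.
rng = np.random.default_rng(0)

def backprop_toy(num_steps, mean_slope):
    grad = 1.0                                  # stands in for dL/dh_T
    for _ in range(num_steps):
        f_prime = rng.normal(mean_slope, 0.05)  # local derivative f' at this step
        grad *= f_prime                         # delta h_t = delta h_{t+1} * f'
    return grad

for mean_slope in (0.9, 1.1):
    print(mean_slope, f"{abs(backprop_toy(100, mean_slope)):.2e}")
```

Run as-is, the gradient magnitude should shrink towards zero in the first setting and grow by several orders of magnitude in the second.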
Elman RNNs
Elman RNNs instantiate \(f\) with the following equation.
\[h_{t+1} = \sigma_h (W_{x} x_{t} + W_{h} h_{t} + b)\]

Applying the chain rule, we obtain \(f' = W_h^T Diag(\sigma_h')\), where \(Diag\) maps a vector into its diagonal matrix form.
The value of \(f'\) depends on both the weight matrix \(W_h\) and the choice of activation function \(\sigma_h\).
When \(\sigma_h\) is the identity, the gradient explodes whenever the largest absolute eigenvalue of \(W_h\) is above \(1\) and vanishes whenever it is below \(1\).
This makes it difficult to learn long range dependencies, since neither infinite nor zero gradients are helpful.
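This effect is easy to check numerically. Below is a minimal sketch (random weights, \(\sigma_h\) taken to be the identity as above, so \(f' = W_h^T\); the dimensions and sequence length are arbitrary choices) that backpropagates a gradient through \(T\) Elman steps for different spectral radii of \(W_h\).

```python
import numpy as np

# A minimal sketch with sigma_h taken to be the identity, so that f' = W_h^T.
# Dimensions, weights, and sequence length are arbitrary illustrative choices,
# not a trained model.
rng = np.random.default_rng(1)
d, T = 16, 60

def elman_grad_norm(radius):
    W_h = rng.normal(size=(d, d))
    W_h *= radius / max(abs(np.linalg.eigvals(W_h)))  # rescale to the target spectral radius
    grad = np.ones(d)                                  # stands in for dL/dh_T
    for _ in range(T):
        grad = W_h.T @ grad                            # delta h_t = W_h^T delta h_{t+1}
    return np.linalg.norm(grad)

for radius in (0.9, 1.0, 1.1):
    print(radius, f"{elman_grad_norm(radius):.2e}")
```

With the spectral radius below \(1\) the gradient norm should shrink rapidly over the 60 steps, and above \(1\) it should blow up.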
LSTMs
LSTMs instantiate \(f\) with the following set of equations.
\[f_t = \sigma_g (W_{fx} x_t + W_{fh} h_{t-1} + b_f)\]
\[i_t = \sigma_g (W_{ix} x_t + W_{ih} h_{t-1} + b_i)\]
\[o_t = \sigma_g (W_{ox} x_t + W_{oh} h_{t-1} + b_o)\]
\[a_t = \sigma_c (W_{cx} x_t + W_{ch} h_{t-1} + b_c)\]
\[c_t = f_t \circ c_{t-1} + i_t \circ a_t\]
\[h_t = o_t \circ \sigma_h (c_t)\]

Note: we are overloading \(f\) with both the forget gate \(f_t\) and the recurrent function \(f(x_t, c_{t-1})\). Here, the hidden (cell) state is \(c_t\) and not \(h_t\), which is the output state.
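For reference, here is a minimal NumPy sketch of a single LSTM step following the equations above. The shapes, the \(0.1\) weight scale, and the random parameters are placeholder assumptions; real implementations typically fuse the four gate matrices into a single matrix multiply.

```python
import numpy as np

# One LSTM step following the equations above. Parameters are random
# placeholders for illustration, not a trained model.
rng = np.random.default_rng(2)
d_x, d_h = 4, 8

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# One input matrix, recurrent matrix, and bias per gate/candidate.
params = {k: (rng.normal(size=(d_h, d_x)) * 0.1,
              rng.normal(size=(d_h, d_h)) * 0.1,
              np.zeros(d_h))
          for k in ("f", "i", "o", "c")}

def lstm_step(x_t, h_prev, c_prev):
    gate = lambda k, act: act(params[k][0] @ x_t + params[k][1] @ h_prev + params[k][2])
    f_t = gate("f", sigmoid)          # forget gate
    i_t = gate("i", sigmoid)          # input gate
    o_t = gate("o", sigmoid)          # output gate
    a_t = gate("c", np.tanh)          # candidate cell update
    c_t = f_t * c_prev + i_t * a_t    # new cell state
    h_t = o_t * np.tanh(c_t)          # new output state
    return h_t, c_t

h, c = np.zeros(d_h), np.zeros(d_h)
for x in rng.normal(size=(5, d_x)):   # run over a short random input sequence
    h, c = lstm_step(x, h, c)
print(h.round(3))
```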
We can write \(f'\) as a sum of the following components using the chain rule again.
\[\begin{eqnarray} f' &=& \frac{\partial c_t}{\partial c_{t-1}} \nonumber \\ &=& Diag(f_t) + Diag(c_{t-1}) \frac{\partial f_t}{\partial c_{t-1}} + Diag(i_t) \frac{\partial a_t}{\partial c_{t-1}} + Diag(a_t) \frac{\partial i_t}{\partial c_{t-1}} \nonumber \end{eqnarray}\]

The latter three terms in the sum depend on partial derivatives with respect to \(c_{t-1}\). Upon expanding those terms, we find that each is a (long) chain of weight matrices and derivatives of activation functions, as in the Elman RNN case.
\(\sigma_g\) is usually the sigmoid function, while \(\sigma_c\) and \(\sigma_h\) are usually the \(\tanh\) function. The derivative of the sigmoid is upper bounded by \(\frac{1}{4}\), while the derivative of \(\tanh\) is at most \(1\). This makes it likely that these three terms are small.
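These bounds are easy to verify numerically; a quick sketch:

```python
import numpy as np

# Quick numerical check: sigmoid'(z) = sigmoid(z)(1 - sigmoid(z)) peaks at 1/4
# and tanh'(z) = 1 - tanh(z)^2 peaks at 1, both at z = 0.
z = np.linspace(-10, 10, 10001)
sig = 1.0 / (1.0 + np.exp(-z))
print(np.max(sig * (1 - sig)))      # ~0.25
print(np.max(1 - np.tanh(z) ** 2))  # ~1.0
```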
But crucially, the first term, \(Diag(f_t)\), is just the activation of the forget gate at that time step. This value is mostly independent of the eigenstructure of a weight matrix and of the derivative of an activation function.
We can therefore largely delay the vanishing/exploding gradients problem by keeping \(f_t\) close to \(1\), even when the other three terms vanish (so long as they do not explode); this can be achieved with careful weight initialization.²
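To see the effect, here is a minimal sketch that keeps only the \(Diag(f_t)\) term (i.e. assumes the other three terms vanish, as discussed) and compares how a cell-state gradient survives 200 steps when the forget gate stays near \(1\) versus around \(0.5\). The gate values are invented for illustration.

```python
import numpy as np

# Keep only the Diag(f_t) term of dc_t/dc_{t-1} (assume the other three terms
# vanish) and backpropagate a cell-state gradient through 200 steps for two
# regimes of the forget gate. Gate values are made up for illustration.
rng = np.random.default_rng(3)
d, T = 8, 200
grad_open, grad_closed = np.ones(d), np.ones(d)  # stand-ins for dL/dc_T

for _ in range(T):
    f_open = rng.uniform(0.99, 1.0, size=d)      # forget gate kept close to 1
    f_closed = rng.uniform(0.4, 0.6, size=d)     # forget gate around 0.5
    grad_open *= f_open                          # delta c_{t-1} ~ Diag(f_t) delta c_t
    grad_closed *= f_closed

print(np.linalg.norm(grad_open))    # remains at a usable magnitude
print(np.linalg.norm(grad_closed))  # effectively vanishes
```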
Hence, the forget gate is the main reason why LSTMs are better than Elman RNNs at processing long input sequences.
Thanks to Jenny Chen, Kui Tang and Jaan Altosaar for reading a draft of this post.
1. To see this, it helps to rewrite it in a different notation: \(\frac{\partial L}{\partial h_{t}} = \frac{\partial L}{\partial h_{t+1}} \cdot \frac{\partial h_{t+1}}{\partial h_t}\)
2. A Simple Way to Initialize Recurrent Networks of Rectified Linear Units. Quoc V. Le, Navdeep Jaitly, Geoffrey E. Hinton.