In the previous blog post, we learnt why we cannot use regular backpropagation to train a Recurrent Neural Network (RNN). We discussed how we can use backpropagation through time to train an RNN. The next step is to understand how exactly the RNN can be trained. Does the unrolling strategy work in practice? If we can just unroll an RNN and make it into a feedforward neural network, then what’s so special about the RNN in the first place? Let’s see how we tackle these issues.
What’s the problem?
When we unroll an RNN, we end up with a gigantic feedforward network. Training even a regular-sized feedforward neural network is a time-consuming task. How are we going to train the enormously deep network that we obtain from unrolling the RNN? When we train a neural network, we typically use stochastic gradient descent as the optimization technique. This is a procedure that iteratively moves the network’s parameters closer to their optimal values. One of the main problems with this technique is that it can get stuck in local optima. We also face the problem of vanishing (or exploding) gradients. As the error is propagated back through many layers, the gradients get multiplied over and over, so they either shrink towards zero or grow uncontrollably. Either way, they become useless for updating the parameters.
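To get a feel for vanishing and exploding gradients, here is a minimal sketch in Python. It assumes a heavily simplified picture in which the gradient is multiplied by the same factor at every unrolled step. The function name backprop_factor and the numbers 0.5, 1.5 and 50 are made up for illustration. A real network multiplies by matrices and activation derivatives, but the effect is similar.

```python
def backprop_factor(weight, steps):
    """Multiply a starting gradient of 1.0 by the same factor once per unrolled step."""
    grad = 1.0
    for _ in range(steps):
        grad *= weight
    return grad


# Hypothetical numbers: 50 unrolled time steps.
small = backprop_factor(0.5, 50)  # factor below 1: the gradient vanishes
large = backprop_factor(1.5, 50)  # factor above 1: the gradient explodes

print(small, large)
```

A factor only slightly below or above 1 is enough to make the gradient tiny or huge once the chain gets long. That is exactly the situation we are in after unrolling an RNN over many time steps.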
Dealing with temporal dependencies
One of the main reasons we use sequential learning models is that we want them to connect previous information with the current event. For example, we want to find out what patterns in the past led to the event that’s happening right now. We want the RNN to be able to do this and do it really well.
The events that happen right now tend to depend on the past. But how far into the past are we supposed to look? Let’s take the same sentence that was discussed in the first blog post of this series. If I ask you to predict the blank word in the sentence "The brown dog has four ____", you will do it very easily. You can see that the answer is “legs”. Let’s consider another sentence: "I’m attending the Robotics and Vision Conference this year, so I will be in ____ for a week." To predict the word here, you need more context, such as an earlier sentence that mentions where the conference is being held.
In the first case, the time gap between the event and the relevant information is small. RNNs can learn if the time gap is small. But if the time gap is large, as we saw in the second case, vanilla RNNs tend to falter. The reason is that it gets progressively more difficult for vanilla RNNs to connect that information to the event as the gap grows.
How do we solve the long term dependency problem?
To solve this problem, we need a special type of RNN that can handle long term dependencies. This is where Long Short Term Memory (abbreviated as LSTM) networks come into the picture. The good thing about LSTMs is that they are specifically designed to handle the long term dependency issue. By their very nature, they can remember information for long periods of time. This is very important because we don’t have to explicitly build something extra to make sure the longer dependencies are taken care of. LSTMs do it on their own!
So how do they do that? The trick is in the transformation that happens after unrolling. In RNNs, we had a chain of units connected together after unrolling. These units have a really simple structure. In LSTMs, these units are designed in a special way to make sure that the LSTM retains memory for a long time. The underlying mathematics is actually very nice and not that difficult to understand. If you get a chance, you should look into it! In this blog post, we will try to understand LSTMs without mathematical equations and complicated diagrams.
How does LSTM work?
One of the central ideas of LSTMs is the cell state. Each cell has a state and it runs through the whole chain in the unrolled network. The goal of this state is just to let information flow across. One thing to note is that the LSTM has the ability to add information to, or remove information from, the cell state. This happens through “gates”. These gates are loosely similar to the logic gates that appear in digital electronics, but they are not strictly on or off. Each gate outputs a value between 0 and 1 that controls how much information is allowed through.
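A gate can be sketched in a few lines of Python. This is only an illustration under simple assumptions: the cell state is a plain list of numbers, the gate scores are given by hand, and the names sigmoid and apply_gate are made up for this example. In a real LSTM, the scores are computed from the current input and the previous output using learned weights.

```python
import math


def sigmoid(x):
    """Squash any number into the range between 0 and 1."""
    return 1.0 / (1.0 + math.exp(-x))


def apply_gate(gate_scores, values):
    """Scale each value by its gate: near 1 lets it through, near 0 blocks it."""
    return [sigmoid(s) * v for s, v in zip(gate_scores, values)]


cell_state = [2.0, -1.0, 0.5]
gate_scores = [10.0, -10.0, 0.0]  # hypothetical: open, closed, half open

gated = apply_gate(gate_scores, cell_state)
print(gated)
```

The first value passes through almost untouched, the second is almost completely blocked, and the third is cut in half.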
In each step, the first thing the LSTM needs to decide is what portion of the cell state needs to be maintained. The remaining portion will be discarded. This is decided jointly by the current input and the previous output. The next step is to decide what new information is going to be stored in the cell state. This step is actually divided into two parts. The first part deals with the update, which means we need to decide which values in the cell state need to be updated. The second part deals with new information, which means we need to decide what new values will be added to the cell state. Once we perform all these computations, it’s time to update the cell state.
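Here is a rough sketch of the cell state update for an LSTM with a single unit, so that everything is a plain number. The function name update_cell_state, the weight names, and the two weight settings are all made up for illustration. A trained network would learn its weights rather than have them set by hand.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def update_cell_state(c_prev, h_prev, x, w):
    """One cell state update, driven by the current input x and previous output h_prev."""
    keep = sigmoid(w["f_x"] * x + w["f_h"] * h_prev + w["f_b"])    # portion to maintain
    update = sigmoid(w["i_x"] * x + w["i_h"] * h_prev + w["i_b"])  # which values to update
    new = math.tanh(w["g_x"] * x + w["g_h"] * h_prev + w["g_b"])   # new candidate values
    return keep * c_prev + update * new


zeros = {k: 0.0 for k in
         ["f_x", "f_h", "f_b", "i_x", "i_h", "i_b", "g_x", "g_h", "g_b"]}

# Hypothetical setting 1: keep everything, add nothing.
keep_weights = dict(zeros, f_b=20.0, i_b=-20.0)
# Hypothetical setting 2: discard everything, write the new value.
overwrite_weights = dict(zeros, f_b=-20.0, i_b=20.0, g_x=1.0)

kept = update_cell_state(3.0, 0.0, 0.5, keep_weights)
overwritten = update_cell_state(3.0, 0.0, 0.5, overwrite_weights)
print(kept, overwritten)
```

With the first setting, the old cell state of 3.0 survives almost unchanged, which is how an LSTM can carry information across many steps. With the second setting, the old state is thrown away and replaced by a value derived from the current input.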
Once the cell state is taken care of, we need to decide what the output of the current step is. Each such step corresponds to one unit of the unrolled network. The output from this unit will just be a filtered version of the cell state. What exactly is this filter? The filter decides which values of the cell state will be pushed to the output. The cell state values undergo a tanh (hyperbolic tangent) transformation, which squashes them into the range between -1 and 1, before going through this filter. This output has a direct impact on the first step in the next unit, i.e. deciding what portion of the cell state needs to be maintained. You see how it all comes together nicely in the end? That’s the reason LSTMs are becoming increasingly popular.
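The output step can be sketched in the same spirit. Again, this assumes a single-unit LSTM with plain numbers, and the function name lstm_output, the parameter names, and the values are made up for illustration.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def lstm_output(c_new, h_prev, x, w_x, w_h, b):
    """Filter the tanh-transformed cell state to produce the output of this step."""
    output_filter = sigmoid(w_x * x + w_h * h_prev + b)  # between 0 and 1
    return output_filter * math.tanh(c_new)


# Hypothetical settings: a filter that is wide open, and one that is shut.
h_open = lstm_output(2.0, 0.0, 1.0, w_x=0.0, w_h=0.0, b=20.0)
h_closed = lstm_output(2.0, 0.0, 1.0, w_x=0.0, w_h=0.0, b=-20.0)
print(h_open, h_closed)
```

When the filter is open, the output is essentially the tanh of the cell state. When it is closed, the unit outputs almost nothing, even though the cell state still holds its value for later steps.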