Deep Learning For Sequential Data – Part IV: Training Recurrent Neural Networks

In the previous blog post, we learnt how Recurrent Neural Networks (RNNs) can be used to build deep learning models for sequential data. Building a deep learning model involves many steps, and training is one of the most important. We need to be able to train a model robustly before we can use it for inference: the training process should be trackable, and it should converge in a reasonable amount of time. So how do we train RNNs? Can we just use the regular techniques that are used for feedforward neural networks?

How are feedforward neural networks trained?

When we train a deep neural network, we repeatedly pass data through the network and iteratively update its weights and biases until we get a good model. To start training, we need a training dataset, which contains input data and the associated labels. At the end of every cycle through the network, we compare the network's predicted output with the actual label. This difference is quantified by a cost function (also called a loss function), which measures the error of the network. The weights and biases are then updated to reduce this error before the next cycle. This process is repeated until the error is small enough.
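The training cycle described above can be sketched in a few lines of Python. This is a minimal illustration that assumes a single linear neuron `y = w*x + b`, uses mean squared error as the cost function and trains on a hypothetical dataset generated from `y = 2x + 1`. The gradients are written out by hand here. In a deeper network, computing them is the job of backpropagation.

```python
import random

def train_linear_neuron(data, lr=0.1, epochs=500, seed=0):
    """Fit y ≈ w*x + b by gradient descent on mean squared error (illustrative sketch)."""
    rng = random.Random(seed)
    w, b = rng.uniform(-1, 1), 0.0  # random initial weight, zero initial bias
    n = len(data)
    history = []
    for _ in range(epochs):
        # Forward pass: predictions and the cost for this cycle.
        preds = [w * x + b for x, _ in data]
        cost = sum((p - y) ** 2 for p, (_, y) in zip(preds, data)) / n
        history.append(cost)
        # Gradients of the cost with respect to w and b.
        grad_w = sum(2 * (p - y) * x for p, (x, y) in zip(preds, data)) / n
        grad_b = sum(2 * (p - y) for p, (_, y) in zip(preds, data)) / n
        # Update step: move each parameter against its gradient to reduce the cost.
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b, history

# Hypothetical training data: inputs in [-1, 1], labels from y = 2x + 1.
data = [(x / 10, 2 * (x / 10) + 1) for x in range(-10, 11)]
w, b, history = train_linear_neuron(data)
print(round(w, 3), round(b, 3), history[0], history[-1])
```

The recorded `history` shows the cost shrinking from one cycle to the next, which is the "trackable" behaviour we want from a training run.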

Now how do we update the weights and biases? A neural network can have millions of them, so how do we update all of them after each pass? We need a mathematical way to update the parameters so that the overall error decreases in the next iteration. This is where backpropagation comes into the picture.

What is backpropagation?

At the heart of training deep neural networks is a technique called backpropagation. It determines how we update the parameters after each iteration. Backpropagation computes how the cost changes with respect to each weight and bias, and an optimization method such as gradient descent uses those values to adjust the parameters. Ever since backpropagation became widely adopted, deep learning has expanded very rapidly. It guides us in changing the weights and biases of a network, step by step, toward values that reduce the cost.

Using this technique, we can relate how changes in the parameters affect the cost function. At its core, backpropagation is a set of equations, derived from the chain rule of calculus, for computing the rate of change (the gradient) of the cost function with respect to the weights and biases. The gradients for any layer can be calculated from quantities already computed for the next layer. After each forward pass through the network, we start at the last layer (the output layer) and compute the gradients layer by layer, moving backwards. This is why it’s called backpropagation! The gradients tell us in which direction to adjust the weights and biases to reduce the error.
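To make the backward pass concrete, here is a minimal sketch in Python. It assumes a hypothetical network with one input, two tanh hidden units and a linear output, trained on the cost `0.5*(y_hat - y)**2`. The output layer's gradients are computed first. The hidden layer's gradients then reuse the output layer's error term `delta2` and weights `W2`. The names `forward`, `backprop` and the parameter values are illustrative, not taken from any library.

```python
import math

def forward(params, x):
    """Forward pass: 1 input -> 2 tanh hidden units -> 1 linear output."""
    W1, b1, W2, b2 = params["W1"], params["b1"], params["W2"], params["b2"]
    z1 = [W1[j] * x + b1[j] for j in range(2)]         # hidden pre-activations
    a1 = [math.tanh(z) for z in z1]                     # hidden activations
    y_hat = sum(W2[j] * a1[j] for j in range(2)) + b2   # network output
    return z1, a1, y_hat

def backprop(params, x, y):
    """Return the cost 0.5*(y_hat - y)^2 and its gradient for every parameter."""
    _, a1, y_hat = forward(params, x)
    cost = 0.5 * (y_hat - y) ** 2
    # Output layer first: error term dC/dy_hat.
    delta2 = y_hat - y
    grads = {"W2": [delta2 * a1[j] for j in range(2)], "b2": delta2}
    # Hidden layer: reuse delta2 and the output weights W2, i.e. propagate backwards.
    # The factor (1 - a1^2) is the derivative of tanh.
    delta1 = [delta2 * params["W2"][j] * (1 - a1[j] ** 2) for j in range(2)]
    grads["W1"] = [delta1[j] * x for j in range(2)]
    grads["b1"] = delta1[:]
    return cost, grads

params = {"W1": [0.5, -0.3], "b1": [0.1, 0.2], "W2": [0.7, -1.2], "b2": 0.05}
cost, grads = backprop(params, x=0.8, y=1.0)
print(cost, grads)
```

A common way to check an implementation like this is to compare each gradient with a finite-difference estimate. You nudge one weight up and down slightly, rerun `forward`, and measure how much the cost changes.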

Can we use backpropagation for RNNs?

Not directly! The reason is that we run into cyclical dependencies if we try to apply backpropagation to an RNN as it is. In a feedforward neural network, the gradients for each layer are computed entirely from quantities belonging to the next layer. In an RNN we cannot do that, because the output of a recurrent neuron at one time step is fed back into the same neuron at the next time step. This forms a cyclical relationship, so there is no “last layer” to start from and no well-defined order in which to compute the gradients.
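To see the self-dependency concretely, consider a minimal scalar RNN. This is an illustrative assumption, not a model from this series. It has hidden state $h_t$, input $x_t$, recurrent weight $w$ and input weight $u$. Differentiating the hidden state with respect to the recurrent weight gives:

```latex
\begin{aligned}
h_t &= \tanh\left(w\,h_{t-1} + u\,x_t\right) \\
\frac{\partial h_t}{\partial w} &= \left(1 - h_t^2\right)\left(h_{t-1} + w\,\frac{\partial h_{t-1}}{\partial w}\right)
\end{aligned}
```

The derivative at step $t$ needs the derivative at step $t-1$, which in turn needs step $t-2$, and so on. Viewed as a single neuron looping back on itself, the computation has no natural starting point. Indexing it by time step turns the loop into a finite chain.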

Having said that, backpropagation is essential for training a neural network, so people came up with a nice trick: we can transform our RNN into something that can be trained with backpropagation. To do this, we make it look like a feedforward neural network, a process called unrolling. We take the recurrent network structure and replicate it once for each step in the sequence, with every copy sharing the same weights. Once the RNN is unrolled, we can apply backpropagation to it. The overall technique is called “backpropagation through time” (BPTT), and it has had a huge impact on the study of RNNs. In the next blog post, we will learn more about how it’s implemented in practice.
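Here is a minimal sketch of unrolling and backpropagation through time in Python. It assumes a hypothetical scalar RNN `h_t = tanh(w*h_{t-1} + u*x_t)` whose hidden state is compared directly with a target at every step. The loss is assumed to be the sum of `0.5*(h_t - y_t)**2`, and the input sequence and targets are made up. The forward loop builds the unrolled copies. The backward loop walks them from the last time step to the first, summing the gradients for `w` and `u` because every copy shares them. The function name `bptt` is illustrative.

```python
import math

def bptt(w, u, xs, ys, h0=0.0):
    """Backpropagation through time for h_t = tanh(w*h_{t-1} + u*x_t).

    Loss (an assumption for this sketch): sum over t of 0.5*(h_t - y_t)^2.
    Returns the loss and its gradients with respect to the shared weights w and u.
    """
    # Unroll: run the recurrence forward and keep every hidden state (hs[0] = h0).
    hs = [h0]
    for x in xs:
        hs.append(math.tanh(w * hs[-1] + u * x))
    loss = sum(0.5 * (h - y) ** 2 for h, y in zip(hs[1:], ys))

    grad_w = grad_u = 0.0
    dh_next = 0.0  # gradient arriving at h_t from later time steps
    # Walk the unrolled copies backwards, from the last time step to the first.
    for t in range(len(xs), 0, -1):
        dh = (hs[t] - ys[t - 1]) + dh_next  # local loss term + gradient from the future
        da = dh * (1 - hs[t] ** 2)          # back through the tanh nonlinearity
        # Every copy shares the same w and u, so their gradients are summed.
        grad_w += da * hs[t - 1]
        grad_u += da * xs[t - 1]
        dh_next = da * w                    # pass the gradient on to h_{t-1}
    return loss, grad_w, grad_u

xs = [0.5, -0.1, 0.9, 0.3]  # hypothetical input sequence
ys = [0.2, 0.0, 0.6, 0.4]   # hypothetical targets
loss, gw, gu = bptt(w=0.8, u=0.5, xs=xs, ys=ys)
print(loss, gw, gu)
```

The variable `dh_next` is what connects the unrolled copies. It carries the gradient from step `t` back to step `t-1`, which is how the recursive derivative of the recurrence gets evaluated in practice.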
