How Recurrent Neural Networks Work | LSTM | Deep Learning

While most deep-learning problems can be handled by artificial neural networks and convolutional neural networks, some areas need a different approach. To be more precise, these are problems involving a relationship between past and present, e.g., text generation, audio analysis, stock price prediction, etc.

Suppose you ask Google Assistant, "Who is Tom Cruise? Where does he live?" and it tells you, "He is an American actor and lives in Beverly Hills." You must be thinking, "What's great about that?" We humans understand the contextual meaning of words, but that is not the case with computers. So in the above example, when we asked, "Where does he live?", how does the assistant know we are still talking about Tom Cruise? To tackle these kinds of problems, we need recurrent neural networks, because they can remember past events and make a meaningful relationship between past and present information.


So what is a Recurrent Neural Network?

It is a type of neural network where the output of the previous state is fed as input to the current state. This gives RNNs a form of memory that ANNs and CNNs lack. Hence we can use them for any kind of sequential data.

[Figure: the standard unrolled representation of a recurrent neural network across timestamps]

The above diagram is a standard representation of RNNs in deep learning, and we are going to stick to it for the sake of convenience. At time t, the RNN gets input from two sources: first from the data, and second from the previous timestamp. So whatever output we get is routed back along with the current data. This structure gives RNNs a short-term memory, which is why we find them used so much in deep learning.

Let's understand it with an example.

Suppose I asked you orally for the sum of five integers. It would take you a fraction of a second to answer. You were able to do that because, first, you know how to add, and second, you have a short-term memory that allows you to remember the numbers for a while.

[Figure: feeding five integers to an RNN one timestamp at a time to compute their sum]

Similarly, we have five integers, and we want to find their sum. So we feed these numbers to the RNN at different timestamps, as you can see in the picture. Since RNNs have memory, they remember the output they got at a particular instant, and we get the desired result. It is a rudimentary example, aimed only at conveying the basic idea behind RNNs.
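To make the idea concrete, here is a toy sketch (my own illustration, not from the post) in which the hidden state is literally the running total, so the "memory" is just the sum accumulated so far; a trained RNN would instead use learned weights and a squashing function.

```python
# Toy "RNN" whose hidden state is the running total of the inputs.
# This only illustrates the recurrence h_t = h_{t-1} + x_t; a real RNN
# would use learned weight matrices and a nonlinearity.
def running_sum_rnn(sequence):
    h = 0.0                                   # hidden state (the memory)
    for t, x in enumerate(sequence):
        h = h + x                             # previous state fed back with the new input
        print(f"t={t}, input={x}, hidden state={h}")
    return h

total = running_sum_rnn([3, 7, 2, 9, 4])      # five example integers
print("Final output:", total)
```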

Mechanism of Recurrent Neural Networks

A recurrent neural network has a recurrent core cell: it takes some input and feeds it to the cell, and the cell maintains a hidden state. Every time the cell gets a new input vector, this hidden state gets updated. The hidden state is then fed back to the cell along with the next input vector.

[Figure: mathematical view of an RNN cell with weight matrices Wxh, Whh, and Wyh]

The above figure represents the mathematical idea behind the recurrent neural network. So here:

  1. Xt is the input at time t. Xt gets multiplied by the weight matrix Wxh.
  2. ht-1 is the hidden state at timestamp t-1. It gets multiplied by the weight matrix Whh. This weight matrix is the same across the network, i.e., shared over all timestamps.
  3. To get the next hidden state (ht), we squash the sum of the two incoming inputs using a tanh (hyperbolic tangent) function: ht = tanh(Wxh·Xt + Whh·ht-1).
  4. To get the output at each timestamp, we multiply ht by another weight matrix Wyh: yt = Wyh·ht. A minimal sketch of these steps follows the list.
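Here is a minimal NumPy sketch of a single RNN step putting those four points together; the sizes, random weights, and variable names (Wxh, Whh, Wyh) are illustrative assumptions, not code from the post.

```python
import numpy as np

# Minimal vanilla RNN step: h_t = tanh(Wxh @ x_t + Whh @ h_prev), y_t = Wyh @ h_t
input_size, hidden_size, output_size = 4, 3, 2
rng = np.random.default_rng(0)
Wxh = rng.standard_normal((hidden_size, input_size)) * 0.1    # input  -> hidden
Whh = rng.standard_normal((hidden_size, hidden_size)) * 0.1   # hidden -> hidden (shared across timestamps)
Wyh = rng.standard_normal((output_size, hidden_size)) * 0.1   # hidden -> output

def rnn_step(x_t, h_prev):
    h_t = np.tanh(Wxh @ x_t + Whh @ h_prev)   # squash the two incoming inputs with tanh
    y_t = Wyh @ h_t                           # output at this timestamp
    return h_t, y_t

h = np.zeros(hidden_size)
for x_t in rng.standard_normal((5, input_size)):  # a dummy sequence of 5 timestamps
    h, y = rnn_step(x_t, h)
print("last hidden state:", h)
print("last output:", y)
```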

Types of Recurrent Neural Networks.

There are many variations of recurrent neural networks, depending on their usage. The basic math and idea are the same across the different variants.

[Figure: types of recurrent neural networks]

Training of a Recurrent Neural Network.


[Figure: training a character-level language model on the word "FAME"]

To understand the training of a recurrent neural network, let's take the example of character-level language modeling, where we try to predict the next character of a string. During training, words of different lengths are fed to the network. Here we take the simple example of the word "FAME."
  • First, we apply one-hot encoding to convert the characters into a numeric format.
  • The one-hot encoded vector is passed to the hidden layer along with the attached weights.
  • The hidden layer squashes the incoming values using the tanh (hyperbolic tangent) function. The output of the hidden layer then diverges: first, it goes on to the next hidden state, and second, it goes out as an output along with some attached weights.
  • The weighted output is then passed through the softmax function to get the probabilities of the next character in the sequence.
  • Recurrent neural networks also use backpropagation to learn from the data. I have talked about backpropagation in this post.
This process is repeated over the training set to train these weights. The above example is a simplified version of the complex training process; I have tried to keep it as simple as possible.
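As a rough sketch of that forward pass (the vocabulary, layer sizes, and random weights are assumptions made for illustration; in practice the weights are learned by backpropagation):

```python
import numpy as np

# One forward pass of a character-level model on the word "FAME".
vocab = sorted(set("FAME"))                    # ['A', 'E', 'F', 'M']
char_to_ix = {ch: i for i, ch in enumerate(vocab)}
vocab_size, hidden_size = len(vocab), 8

rng = np.random.default_rng(1)
Wxh = rng.standard_normal((hidden_size, vocab_size)) * 0.1
Whh = rng.standard_normal((hidden_size, hidden_size)) * 0.1
Why = rng.standard_normal((vocab_size, hidden_size)) * 0.1

def one_hot(ch):
    v = np.zeros(vocab_size)
    v[char_to_ix[ch]] = 1.0
    return v

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

h = np.zeros(hidden_size)
for ch in "FAM":                               # predict the character that follows each input
    h = np.tanh(Wxh @ one_hot(ch) + Whh @ h)   # hidden layer squashes the incoming values
    p = softmax(Why @ h)                       # probabilities of the next character
    print(ch, "->", dict(zip(vocab, np.round(p, 2))))
```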

So how do RNNs backpropagate?


[Figure: loss calculated over a batch of data points during backpropagation]

The idea is quite similar to backpropagation in a fully connected neural network. We train the model on batches of data instead of feeding all the data at once, because that keeps the computation inexpensive. Above is an example of the loss calculated on a neural network. Here I have taken a batch of 5 data points, and the loss is optimized using backpropagation for these data points. The weights are then carried over to the next batch of data points, and the loss is optimized again using backpropagation.

[Figure: backpropagation of the loss through the RNN cells, with the red line showing the gradient flow]

The above figure represents how the backpropagation of information takes place at the cell level. Here, the red line indicates that weights are changed across the network to minimize the loss. Two problems arise in this complex network, known as the vanishing gradient and the exploding gradient.
In gradient-based methods, the network learns by analyzing how a change in the parameters (weights) affects the output. If the contribution of a parameter change is small, it keeps shrinking as it backpropagates through the network (vanishing gradient). If the contribution is large, it keeps growing, or exploding, as it propagates through the network (exploding gradient).
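A small numeric illustration of the effect (a toy of my own, not the actual gradient computation): during backpropagation through time the gradient is repeatedly multiplied by roughly the same recurrent factor, so a factor below one drives it towards zero, while a factor above one blows it up.

```python
# Repeated multiplication by the same factor, as happens (roughly) to the
# gradient at each timestamp of backpropagation through time.
def gradient_after_steps(factor, steps=30):
    grad = 1.0
    for _ in range(steps):
        grad *= factor
    return grad

print("factor 0.5 over 30 steps:", gradient_after_steps(0.5))  # vanishing gradient (~9e-10)
print("factor 1.5 over 30 steps:", gradient_after_steps(1.5))  # exploding gradient (~1.9e+5)
```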
You might have come across this motivational picture in your LinkedIn, Quora, or Instagram feed. It illustrates the problems of the vanishing gradient and the exploding gradient well.

[Figure: motivational picture illustrating the vanishing and exploding gradients]

Here is a project on recurrent neural networks where I built a model that predicts the stock price of MRF.




What is a Long Short-Term Memory (LSTM) Network?

Traditional RNNs are good at retaining information over short sequences of data, like the Tom Cruise example. For long enough sequences, they fail to carry the crucial information forward. This is the vanishing-gradient problem we talked about earlier. To deal with it, we have LSTM networks. An LSTM is a type of recurrent neural network that solves the problem by introducing several gates. These gates give the LSTM better control over what to forget and what to retain. In more technical terms, these gates control the gradient flow to deal with the vanishing gradient.

[Figure: a layman's view of the LSTM network]

The above picture represents a layman's approach to understanding the LSTM. We are going to start by building a basic intuition and then jump into the nitty-gritty of the network. In simple RNNs, we had only a tanh function, but in an LSTM, we introduce four gates and a cell state to deal with long-term dependencies (the vanishing gradient). So now we are feeding three things to the network instead of two:
  1. The input at time t, or the data at that timestamp.
  2. The hidden state, or you can say the short-term memory carrier.
  3. The cell state, or we can say the long-term memory carrier.
So these three inputs go through the four gates, and the network outputs three things (a shape-only sketch follows this list):
  1. The output at time t, or the prediction at that timestamp.
  2. The hidden state, or short-term memory, updated based on the computation done in these four gates.
  3. The cell state, or long-term memory, updated based on the computation done in those gates.
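Here is that interface in code form: three things in, three things out. The gate computations are stubbed with placeholders and are filled in one by one in the sections below; the function name and sizes are my own illustration.

```python
import numpy as np

def lstm_step(x_t, h_prev, c_prev):
    """x_t: input at time t; h_prev: short-term memory; c_prev: long-term memory."""
    # ... the forget, input, update, and output gates are computed here
    #     (see the per-gate sketches in the following sections) ...
    c_t = c_prev           # placeholder: updated long-term memory (cell state)
    h_t = np.tanh(c_t)     # placeholder: updated short-term memory (hidden state)
    y_t = h_t              # placeholder: output / prediction at time t
    return y_t, h_t, c_t

y_t, h_t, c_t = lstm_step(np.zeros(4), np.zeros(3), np.zeros(3))
```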

So in this way, LSTM networks take care of both long-term and short-term memory. It is as if the network is taking notes on which information should go into long-term memory and which into short-term memory. Below we have a more technical representation of the LSTM network. Don't worry about the symbols and lines; we will talk about each of them thoroughly.

[Figure: the LSTM cell with its gates and cell state]

What is the Forget Gate?

First, we have the forget gate, which decides what information should not be passed on, i.e., what the network should forget. In the formula, we see that it takes the hidden state (ht-1) and the data (Xt) at timestamp t and passes them through a sigmoid function to output a value between zero and one. So in a way, this gate is just a classifier: it classifies information as something to forget or something to keep. Here, zero means "completely get rid of the information," while one means "keep it." So this gate controls how much we want to forget from the previous cell state.
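In code, a hedged sketch of that formula (the weight names W_f and b_f and the concatenation of ht-1 and Xt follow the common LSTM formulation, not anything specific to this post) looks like this:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# f_t = sigmoid(W_f . [h_{t-1}, x_t] + b_f): one value between 0 and 1 per cell unit
hidden_size, input_size = 3, 2
rng = np.random.default_rng(2)
W_f = rng.standard_normal((hidden_size, hidden_size + input_size)) * 0.1
b_f = np.zeros(hidden_size)

h_prev = rng.standard_normal(hidden_size)   # previous hidden state
x_t = rng.standard_normal(input_size)       # current input
f_t = sigmoid(W_f @ np.concatenate([h_prev, x_t]) + b_f)
print("forget gate:", f_t)                  # ~0 means "forget", ~1 means "keep"
```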


[Figure: the forget gate and its equation]

Let's understand it with an example. If you ask Google Assistant, "Who is Tom Cruise, and where does he live?" and, after getting the answer, ask again, "Who is Megan Fox, and where does she live?", in the latter case we will get information about Megan Fox. This is possible because, as soon as the new subject and its gender came in, the network had to forget the previous one.

What is the Input Gate?

Once we know what information to forget, we decide what to remember from the new data. This gate has two parts. First, we have a classifier, or sigmoid function, that decides which values to update. Second, we have a tanh function that creates a vector of new candidate values to accommodate these changes.
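A corresponding sketch of the two parts (again using the common formulation; W_i, W_c, b_i, and b_c are illustrative names):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# i_t  = sigmoid(W_i . [h_{t-1}, x_t] + b_i)   -> which values to update
# c~_t = tanh(W_c . [h_{t-1}, x_t] + b_c)      -> candidate new values
hidden_size, input_size = 3, 2
rng = np.random.default_rng(3)
W_i = rng.standard_normal((hidden_size, hidden_size + input_size)) * 0.1
W_c = rng.standard_normal((hidden_size, hidden_size + input_size)) * 0.1
b_i, b_c = np.zeros(hidden_size), np.zeros(hidden_size)

h_prev, x_t = rng.standard_normal(hidden_size), rng.standard_normal(input_size)
concat = np.concatenate([h_prev, x_t])

i_t = sigmoid(W_i @ concat + b_i)        # input gate (classifier part)
c_tilde = np.tanh(W_c @ concat + b_c)    # candidate values (tanh part)
print("input gate:", i_t)
print("candidates:", c_tilde)
```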


[Figure: the input gate and its equations]

Taking the previous example: the forget gate decided to forget Tom Cruise as soon as we introduced Megan Fox. The LSTM network then learned that "he" is no longer required and that the new candidate is "she," so the focus is on Megan Fox now.


What is the Update Gate?

So far, we have only decided what to forget and what to input. After this, we update the cell state, i.e., we finally update all the information we have about the data. Based on the present data, the LSTM network deletes unnecessary information from the previous cell state (long-term memory) using the forget gate. Next, it takes the useful information from the present data, selected by the input gate, and adds that to the memory. The updated information is carried forward to the next cell as the cell state (Ct) at time t. That is all the equation is trying to tell us.
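That update is just an element-wise combination of the quantities computed so far; a small sketch with made-up example values:

```python
import numpy as np

# C_t = f_t * C_{t-1} + i_t * c~_t
f_t     = np.array([0.9, 0.1, 0.5])    # forget gate: how much old memory to keep
i_t     = np.array([0.2, 0.8, 0.4])    # input gate: how much of the candidates to add
c_tilde = np.array([0.7, -0.3, 0.1])   # candidate values from the input gate
c_prev  = np.array([1.0, 0.5, -0.2])   # previous cell state (long-term memory)

c_t = f_t * c_prev + i_t * c_tilde     # updated cell state carried to the next timestamp
print("updated cell state:", c_t)
```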


[Figure: updating the cell state]

So once we have introduced a new subject to the assistant, we are no longer required to call it by name. The assistant will understand even if we use a pronoun. So in the previous example, we can ask questions like:
  • Where does she live?
  • What is her profession?
  • Whom did she marry?
and the assistant will know that we are still talking about Megan Fox.

What is the Output Gate?

In this gate, we decide what to pass to the next hidden state and what the model should output. The model doesn't need to produce an output at every timestamp; it depends on the kind of problem we are dealing with. For example, in the case of character-level language modeling, we want an output at each timestamp.

[Figure: the output gate]

So here two things are happening. First, we pass the previous hidden state and the current input through a sigmoid layer. Second, we squash the current cell state using tanh. In the next step, we multiply these two things to get the next hidden state, which is also the output for the current timestamp.
So in the celebrity example, once we have the update, we need to carry it until we get new information. As long as we don't introduce a new celebrity's name, the assistant will keep thinking that we are talking about Megan Fox.
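A hedged sketch of the output-gate step described above, again in the common LSTM notation (W_o and b_o are illustrative names; c_t is the cell state produced by the update step):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# o_t = sigmoid(W_o . [h_{t-1}, x_t] + b_o);  h_t = o_t * tanh(C_t)
hidden_size, input_size = 3, 2
rng = np.random.default_rng(4)
W_o = rng.standard_normal((hidden_size, hidden_size + input_size)) * 0.1
b_o = np.zeros(hidden_size)

h_prev, x_t = rng.standard_normal(hidden_size), rng.standard_normal(input_size)
c_t = np.array([0.9, -0.4, 0.2])                           # cell state after the update step

o_t = sigmoid(W_o @ np.concatenate([h_prev, x_t]) + b_o)   # sigmoid layer on h_{t-1} and x_t
h_t = o_t * np.tanh(c_t)                                   # next hidden state / output at time t
print("next hidden state:", h_t)
```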

There are many variations of the LSTM network, and this blog post talks about them. Here is the original paper by Hochreiter & Schmidhuber (1997). This blog post by Andrej Karpathy talks about the implementation of RNNs.

"That is all for this post, and if you have reached this far, I would like to thank you for investing your valuable time reading this. If you have any doubts or suggestions, please do let me know in the comment section."
