Vanishing/Exploding gradients in RNN

Nitin Sharma

A basic Recurrent Neural Network (RNN) is a specialized type of artificial neural network designed to effectively process sequences of data, which is common in various fields such as natural language processing, time series analysis, and speech recognition. Unlike traditional feedforward neural networks, where information moves in one direction from input to output, RNNs incorporate connections that allow certain neurons to loop back onto themselves. This unique architecture enables RNNs to maintain a form of memory, which is crucial for understanding the context and dependencies in sequential data.

In an RNN, the input at each time step is not only processed independently but is also combined with the hidden state derived from the previous time step. The hidden state serves as a form of memory that captures relevant information from prior inputs in the sequence. This dynamic updating of the hidden state allows the network to incorporate both the current input and the context of previous inputs, effectively enabling the learning of temporal dependencies. Consequently, RNNs can adaptively handle sequences of varying lengths, making them particularly advantageous for tasks where the input size is not fixed, such as in natural language sentences or time-varying signals.
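
To make the recurrence concrete, here is a minimal sketch of a single RNN step in plain NumPy. It is an illustration rather than a reference implementation; the names `rnn_step`, `W_x`, `W_h`, and `b` are chosen for clarity and are not taken from any particular framework.

```python
import numpy as np

def rnn_step(x_t, h_prev, W_x, W_h, b):
    """One step of a basic RNN: combine the current input with the
    hidden state from the previous time step, then squash with tanh."""
    return np.tanh(W_x @ x_t + W_h @ h_prev + b)

# Tiny example: 3-dimensional inputs, 2-dimensional hidden state
rng = np.random.default_rng(0)
W_x = rng.normal(size=(2, 3))   # input-to-hidden weights
W_h = rng.normal(size=(2, 2))   # hidden-to-hidden (feedback) weights
b = np.zeros(2)

h = np.zeros(2)                        # initial hidden state
for x_t in rng.normal(size=(4, 3)):    # a sequence of 4 inputs
    h = rnn_step(x_t, h, W_x, W_h, b)  # hidden state carries context forward
print(h)
```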

Basic Recurrent Neural Networks (RNNs) often encounter significant challenges related to the phenomena of vanishing and exploding gradients. The vanishing gradient problem arises when gradients become progressively smaller as they are propagated backward through the network during training, leading to inadequate updates to the weights of earlier layers. This makes it difficult for the network to learn long-term dependencies from the input sequences. Conversely, the exploding gradient problem occurs when gradients grow excessively large, causing sudden and erratic changes in the weights, which can destabilize the learning process. Both of these issues can severely hinder the performance of RNNs and limit their ability to effectively model sequential data. In this article, we will explore how backpropagation can lead to vanishing and exploding gradients. We will begin by examining a simple RNN architecture, which includes a feedback loop along with its associated weights and biases. A very simple RNN is shown below.

(Image-1: a very simple RNN with a feedback loop and its weights and biases)

To illustrate this, we will start with this basic RNN design and walk through the backpropagation calculation.

For this scenario, we will use SSR (Sum of Squared Residuals) as the cost function. The sum of squared residuals serves as a cost function in various statistical models. It measures the discrepancy between observed values and the values predicted by the model. It is calculated by taking the difference between each observed value and its corresponding predicted value (the residual), squaring each of those differences to eliminate negative values, and then summing all the squared differences together. The goal is to minimize this sum, which indicates that the model's predictions are closely aligned with the actual data. It can be defined as

$$SSR= \sum_{i=1}^{m} (Observed_i-Predicted_i)^2$$
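
As a quick illustration, here is SSR computed directly from the formula above in NumPy; the observed and predicted values are made up.

```python
import numpy as np

def ssr(observed, predicted):
    """Sum of squared residuals: sum over i of (Observed_i - Predicted_i)^2."""
    residuals = np.asarray(observed) - np.asarray(predicted)
    return np.sum(residuals ** 2)

print(ssr([1.0, 2.0, 3.0], [0.9, 2.2, 2.7]))  # 0.01 + 0.04 + 0.09 ≈ 0.14
```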

Let's ignore the feedback loop for now and just calculate the derivative of SSR with respect to \(W_1\), considering only Input3.

Applying the chain rule, we can say that

\(\begin{flalign*} & \frac{dSSR}{dW_1}= \frac{dSSR}{dPredicted} * \frac{dPredicted}{dW_1} \qquad \cdots (1) &\\ \end{flalign*}\)

First, we calculate the derivative of SSR with respect to the predicted value (output):

\(\begin{flalign*} & \frac{dSSR}{dPredicted}= \frac{d \sum_{i=1}^{m} (Observed_i-Predicted_i)^2}{dPredicted} &\\ \end{flalign*}\)

Applying the chain rule:

\(\begin{flalign*} & \frac{dSSR}{dPredicted}= \sum_{i=1}^{m} 2*(Observed_i-Predicted_i) * (-1) &\\ \end{flalign*}\)

\(\begin{flalign*} & \frac{dSSR}{dPredicted}= \sum_{i=1}^{m} -2*(Observed_i-Predicted_i) &\\ \end{flalign*}\)

Now let's calculate the derivative of the predicted value with respect to \(W_1\):

\(\begin{flalign*} & \frac{dPredicted}{dW_1}= \frac{d(W_1*Input3)}{dW_1}=Input3 &\\ \end{flalign*}\)

So now Equation 1 can be summarized as

\(\begin{flalign*} & \frac{dSSR}{dW_1}= \sum_{i=1}^{m} -2*(Observed_i-Predicted_i) * Input3 &\\ \end{flalign*}\)
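
As a sanity check, the sketch below compares this analytic gradient with a finite-difference estimate for a single training example with Predicted = W1 * Input3 (still ignoring the feedback loop); the numeric values are made up for illustration.

```python
# Analytic gradient vs. finite differences for Predicted = W1 * Input3
observed, input3, w1 = 2.0, 0.8, 0.5   # made-up values

def ssr_of_w1(w):
    predicted = w * input3
    return (observed - predicted) ** 2

analytic = -2 * (observed - w1 * input3) * input3
eps = 1e-6
numeric = (ssr_of_w1(w1 + eps) - ssr_of_w1(w1 - eps)) / (2 * eps)
print(analytic, numeric)  # both ≈ -2.56
```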

Now let's calculate the derivative when we unroll the RNN to include the previous input as feedback, as shown below.

When we unroll the RNN once, the predicted value is the previous input (Input2) multiplied by \(W_1\) and then by \(W_2\), plus Input3 multiplied by \(W_1\):

\(Predicted=(Input2 * W_1 * W_2)+(W_1*Input3)\)

\(\begin{flalign*} & \frac{dPredicted}{dW_1}= \frac{d[(Input2 * W_1*W_2)+(Input3*W_1)]}{dW_1}=(Input2*W_2)+Input3 &\\ \end{flalign*}\)

If we consider one more previous input (Input1), as in Image-1, then the predicted value changes to

\(Predicted=[(Input1 * W_1 * W_2)+(W_1*Input2)]*W_2+(Input3*W_1)\)

Expanding this:

\(Predicted=(Input1 * W_1 * W_2^2)+(Input2*W_1*W_2)+(Input3*W_1)\)
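
Before differentiating, here is a quick numerical check that this expansion matches stepping through the unrolled network one input at a time; this simplified model has no activation function or bias, matching the algebra above, and the input and weight values are made up.

```python
# Check: stepping through the unrolled RNN matches the expanded formula.
input1, input2, input3 = 0.6, -1.2, 0.9   # made-up values
w1, w2 = 0.5, 0.7

# Step-by-step unrolling (no activation or bias, matching the algebra above)
h = input1 * w1
h = h * w2 + input2 * w1
predicted_unrolled = h * w2 + input3 * w1

# Expanded closed form
predicted_expanded = input1 * w1 * w2**2 + input2 * w1 * w2 + input3 * w1
print(predicted_unrolled, predicted_expanded)  # both ≈ 0.177
```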

\(\begin{flalign*} & \frac{dPredicted}{dW_1}= \frac{d\{[(Input1 * W_1 * W_2)+(W_1*Input2)]*W_2+(Input3*W_1)\}}{dW_1} &\\ \end{flalign*}\)

\(=(Input1*W_2^2)+(Input2*W_2)+Input3\)

Now let's substitute this back into the derivative of SSR with respect to \(W_1\):

\(\begin{flalign*} & \frac{dSSR}{dW_1}= \sum_{i=1}^{m} -2*(Observed_i-Predicted_i) * ((Input1*W_2^2)+(Input2*W_2)+Input3) &\\ \end{flalign*}\)

We see a pattern: the power of \(W_2\) rises by one for each additional time we unroll the RNN to include an earlier input. For example, after two unrolls the oldest input appears as \(Input1*W_2^2\).
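
This pattern can be written compactly: with T unrolled inputs, \(\frac{dPredicted}{dW_1}=\sum_{k=1}^{T} Input_k * W_2^{T-k}\), so the oldest input carries the highest power of \(W_2\). A small sketch with made-up values:

```python
# dPredicted/dW1 when the RNN is unrolled over T inputs:
# Input_k contributes Input_k * W2**(T - k), so the oldest input has the highest power.
inputs = [0.6, -1.2, 0.9]   # Input1, Input2, Input3 (made-up values)
w2 = 0.7
T = len(inputs)

grad = sum(x * w2 ** (T - k) for k, x in enumerate(inputs, start=1))
explicit = inputs[0] * w2**2 + inputs[1] * w2 + inputs[2]
print(grad, explicit)       # identical: 0.294 - 0.84 + 0.9 = 0.354
```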

Vanishing and Exploding Gradients

Let's say we unroll the RNN many times to include many previous values, far more than shown in Image-1.

If the weight \(W_2\) is between -1 and 1, then the term \(Input1*W_2^{n}\) (where n is the number of unrolled steps) in \(\frac{dSSR}{dW_1}\) becomes very small. This is the vanishing gradient problem: the contribution of earlier values to the gradient effectively disappears.

If the weight \(W_2\) is less than -1 or greater than 1, then that same term in \(\frac{dSSR}{dW_1}\) grows extremely large. This is the exploding gradient problem: the contribution of earlier values to the gradient becomes excessively high.
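
The effect is easy to see numerically: the coefficient of the oldest input is \(W_2\) raised to the number of unrolled steps, so it either collapses toward zero or blows up as the sequence gets longer. The values 0.9 and 1.1 below are just examples.

```python
# The coefficient of the oldest input after k unrolled steps is W2**k.
for w2 in (0.9, 1.1):
    print(f"W2 = {w2}:", [w2 ** k for k in (1, 10, 50, 100)])
# W2 = 0.9: 0.9, ~0.35, ~0.005, ~0.00003  -> vanishing
# W2 = 1.1: 1.1, ~2.59, ~117,   ~13781    -> exploding
```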

The fundamental concept revolves around the inherent limitations of a basic Recurrent Neural Network (RNN) concerning the temporal dependencies it can effectively manage. Specifically, an RNN can only unroll for a limited number of time steps before the influence of older data points on the training process becomes problematic. When the sequence length exceeds this optimal range, the older inputs may either lose their significance—resulting in diminishing returns on their contribution to learning—or exert an overwhelming influence, thereby skewing the model's predictions and learning dynamics. This imbalance can hinder the model’s ability to retain relevant information over long sequences, ultimately affecting its performance on tasks that involve longer temporal dependencies.

Long Short-Term Memory networks (LSTM)

Long Short-Term Memory networks, commonly known as LSTMs, are a specialized type of recurrent neural network (RNN) designed to overcome the significant challenges of vanishing and exploding gradients that often occur in traditional RNNs during training. Vanishing gradients can make it difficult for the network to learn long-range dependencies in sequences, as the gradients used to update the model's weights become excessively small, effectively freezing the learning of earlier layers. Conversely, exploding gradients can lead to numerical instability and erratic updates, causing the model to diverge.

LSTMs address these issues through a unique architectural design featuring memory cells and three distinct gates: the input gate, the forget gate, and the output gate. The input gate regulates the flow of new information into the memory cell, the forget gate decides what information to discard from the memory cell, and the output gate controls the information that is sent out of the cell. This gating mechanism enables LSTMs to maintain information over extended sequences, allowing them to learn complex patterns and relationships in data, making them particularly effective for tasks such as language modeling, speech recognition, and time series forecasting.
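
As a preview, here is a minimal, framework-free sketch of a single LSTM step with the three gates described above; the parameter names (`W`, `U`, `b` per gate) are illustrative assumptions, and a production implementation in a deep learning library would differ in details.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, U, b):
    """One LSTM step. W, U, b are dicts of per-gate parameters:
    'i' input gate, 'f' forget gate, 'o' output gate, 'g' candidate memory."""
    i = sigmoid(W['i'] @ x_t + U['i'] @ h_prev + b['i'])   # input gate
    f = sigmoid(W['f'] @ x_t + U['f'] @ h_prev + b['f'])   # forget gate
    o = sigmoid(W['o'] @ x_t + U['o'] @ h_prev + b['o'])   # output gate
    g = np.tanh(W['g'] @ x_t + U['g'] @ h_prev + b['g'])   # candidate memory
    c = f * c_prev + i * g        # memory cell: keep some old info, add some new
    h = o * np.tanh(c)            # hidden state exposed to the next time step
    return h, c

# Tiny usage example with random parameters
rng = np.random.default_rng(1)
dim_x, dim_h = 3, 2
W = {k: rng.normal(size=(dim_h, dim_x)) for k in 'ifog'}
U = {k: rng.normal(size=(dim_h, dim_h)) for k in 'ifog'}
b = {k: np.zeros(dim_h) for k in 'ifog'}
h, c = lstm_step(rng.normal(size=dim_x), np.zeros(dim_h), np.zeros(dim_h), W, U, b)
print(h, c)
```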

We will cover LSTM in another article.
