Recurrent Neural Network (RNN)

A recurrent neural network (RNN) is a deep learning network structure that uses information from past inputs to improve the network's performance on current and future inputs. What distinguishes RNNs is that the network contains a hidden state and loops. The looping structure allows the network to store past information in the hidden state and to operate on sequences.

These features make recurrent neural networks well suited for problems with sequential data of varying length, such as natural language processing, signal classification, and video analysis.

Unrolling a single cell of an RNN, showing how information moves through the network for a data sequence. At each time step, the cell combines the input with the hidden state to produce the output, and the updated hidden state is passed to the next time step.

How does the RNN know how to apply past information to the current input? The network has two sets of weights: one for the hidden state vector and one for the inputs. During training, the network learns both sets of weights. At inference time, the output depends on the current input as well as on the hidden state, which summarizes the previous inputs.
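To make the two sets of weights concrete, here is a minimal sketch of a simple (Elman-style) RNN forward pass in plain Python, assuming the common update h_t = tanh(W_x x_t + W_h h_{t-1} + b). The names `rnn_step`, `rnn_forward`, `W_x`, `W_h`, and `b` are illustrative only and are not taken from MATLAB or any other library.

```python
import math

def rnn_step(x_t, h_prev, W_x, W_h, b):
    """One time step of a simple RNN: h_t = tanh(W_x @ x_t + W_h @ h_prev + b)."""
    h_t = []
    for i in range(len(b)):
        s = b[i]
        s += sum(W_x[i][j] * x_t[j] for j in range(len(x_t)))       # weights on the input
        s += sum(W_h[i][k] * h_prev[k] for k in range(len(h_prev)))  # weights on the hidden state
        h_t.append(math.tanh(s))
    return h_t

def rnn_forward(xs, W_x, W_h, b):
    """Run the same cell (same weights) over every element of the sequence."""
    h = [0.0] * len(b)  # initial hidden state
    states = []
    for x_t in xs:
        h = rnn_step(x_t, h, W_x, W_h, b)
        states.append(h)
    return states

# Hypothetical weights: 2 hidden units, 1 input feature.
W_x = [[0.5], [-0.3]]
W_h = [[0.1, 0.2], [0.0, 0.4]]
b = [0.0, 0.0]
states = rnn_forward([[1.0], [0.0], [0.0]], W_x, W_h, b)
```

Although the second and third inputs are zero, the hidden states at those steps are not, because the hidden state still carries a trace of the first input.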

LSTM

In practice, simple RNNs have trouble learning longer-term dependencies. RNNs are commonly trained with backpropagation through time, where they can experience either a "vanishing" or an "exploding" gradient problem. These problems cause the gradients, and therefore the weight updates, to become either very small or very large, which limits the network's ability to learn long-term relationships.
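The effect can be illustrated with a deliberately simplified model, assuming a scalar linear recurrence h_t = w * h_{t-1} + x_t. Backpropagating through T steps multiplies the gradient by w once per step, so it scales like w**T. In a real RNN the repeated factor involves the recurrent weight matrix and the activation derivatives, but the same shrink-or-grow behavior appears. The function name is hypothetical.

```python
def gradient_through_time(w, steps):
    """Gradient d h_T / d h_0 for the scalar linear recurrence h_t = w * h_{t-1} + x_t."""
    grad = 1.0
    for _ in range(steps):
        grad *= w  # chain rule: d h_t / d h_{t-1} = w at every step
    return grad

for w in (0.5, 1.0, 1.5):
    print(w, gradient_through_time(w, 50))
# w = 0.5 vanishes (about 8.9e-16); w = 1.5 explodes (about 6.4e8)
```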

A special type of recurrent neural network that overcomes this issue is the long short-term memory (LSTM) network. LSTM networks use additional gates to control which information in the cell state makes it to the output and to the next hidden state. This allows the network to learn long-term relationships in the data more effectively. LSTMs are one of the most commonly used types of RNN.
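As a sketch of how the gates work, the following implements one step of an LSTM cell with a scalar hidden state, using a common textbook formulation with forget, input, and output gates plus a candidate value. It is not claimed to match MATLAB's internal implementation exactly; the parameter layout (a dict of `(input weight, recurrent weight, bias)` tuples) is a hypothetical choice for readability.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def lstm_step(x, h_prev, c_prev, p):
    """One step of a scalar LSTM cell; p maps gate name -> (w_x, w_h, bias)."""
    def pre(gate):
        w_x, w_h, b = p[gate]
        return w_x * x + w_h * h_prev + b

    f = sigmoid(pre("forget"))       # fraction of the old cell state to keep
    i = sigmoid(pre("input"))        # fraction of the candidate to write
    g = math.tanh(pre("candidate"))  # candidate value to write
    o = sigmoid(pre("output"))       # fraction of the cell state to expose
    c = f * c_prev + i * g           # additive update of the cell state
    h = o * math.tanh(c)             # new hidden state (the cell's output)
    return h, c

# Hypothetical parameters: forget gate near 1, input gate near 0.
params = {
    "forget": (0.0, 0.0, 10.0),
    "input": (0.0, 0.0, -10.0),
    "candidate": (1.0, 0.0, 0.0),
    "output": (0.0, 0.0, 0.0),
}
h, c = 0.0, 1.0
for _ in range(100):
    h, c = lstm_step(0.5, h, c, params)
```

With these settings the cell state stays close to its initial value of 1.0 after 100 steps. Because the cell state is updated additively and scaled only by the forget gate, a learned forget gate near 1 lets information survive many steps, unlike the repeated multiplication in the previous sketch.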

Comparison of RNN (left) and LSTM network (right)

MATLAB® has a full set of features and functionality to train and implement LSTM networks with text, image, signal, and time series data. The next sections explore applications of RNNs and some examples using MATLAB.

Applications of RNNs

Natural Language Processing

Language is naturally sequential, and pieces of text vary in length. This makes RNNs a good fit for problems in this area, because they can learn to contextualize words in a sentence. One example is sentiment analysis, a method for classifying the sentiment (for example, positive or negative) expressed by words and phrases. Machine translation, the use of an algorithm to translate between languages, is another common application. Before text can be fed to the network, words must be converted into numeric sequences. An effective way of doing this is a word embedding layer, which maps words to numeric vectors. The example below uses word embeddings to train a word sentiment classifier and displays the results with the MATLAB wordcloud function.

Sentiment analysis results in MATLAB. The word cloud displays the results of the training process; the trained classifier can then determine the sentiment of new groups of text.
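The conversion from text to numeric sequences described above can be sketched as two lookups: word to integer index, then index to vector. This is a minimal illustration, assuming whitespace tokenization and a randomly initialized table; in a real network (for example, MATLAB's word embedding layer) the table rows are learned during training. All function names here are hypothetical.

```python
import random

def build_vocab(docs):
    """Assign an integer index to each distinct word; index 0 is reserved for unknown words."""
    vocab = {"<unk>": 0}
    for doc in docs:
        for word in doc.lower().split():
            vocab.setdefault(word, len(vocab))
    return vocab

def encode(doc, vocab):
    """Convert a piece of text into a numeric sequence of word indices."""
    return [vocab.get(word, 0) for word in doc.lower().split()]

def make_embedding(vocab_size, dim, seed=0):
    """Random embedding table with one row per word (learned in practice)."""
    rng = random.Random(seed)
    return [[rng.uniform(-0.1, 0.1) for _ in range(dim)] for _ in range(vocab_size)]

docs = ["the service was great", "the food was terrible"]
vocab = build_vocab(docs)
table = make_embedding(len(vocab), dim=4)
seq = encode("the food was great", vocab)  # [1, 5, 3, 4]
vectors = [table[i] for i in seq]          # one 4-dimensional vector per word, fed to the RNN in order
```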

In another classification example, an RNN trained in MATLAB classifies text data to determine the type of manufacturing failure. MATLAB is also used in a machine translation example that trains a network to convert decimal numbers into Roman numerals.

Signal Classification

Signals are another example of naturally sequential data, as they are often collected from sensors over time. Automatically classifying signals is useful because it can reduce the manual effort needed to label large datasets or enable classification in real time. Raw signal data can be fed into deep networks directly or preprocessed to focus on other features, such as frequency components. Feature extraction can greatly improve network performance, as in an example with electrical heart signals (ECG). Below is an example using raw signal data in an RNN.

Classifying sensor data with an LSTM in MATLAB.
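As an alternative to raw samples, the preprocessing route mentioned above can turn a signal into a sequence of frequency-domain feature vectors. The sketch below, assuming non-overlapping fixed-length frames and a naive discrete Fourier transform (fine for illustration, slow for real data), returns one magnitude spectrum per frame; the function names and test signal are hypothetical.

```python
import cmath
import math

def dft_magnitudes(x):
    """Magnitude of each DFT bin from 0 up to the Nyquist bin (naive O(N^2) transform)."""
    n = len(x)
    return [
        abs(sum(x[t] * cmath.exp(-2j * math.pi * k * t / n) for t in range(n)))
        for k in range(n // 2 + 1)
    ]

def frame_features(signal, frame_len):
    """Split a signal into non-overlapping frames and describe each frame by its spectrum.

    The result is a sequence of feature vectors, one per frame, that an RNN can
    consume in place of the raw samples.
    """
    return [
        dft_magnitudes(signal[s:s + frame_len])
        for s in range(0, len(signal) - frame_len + 1, frame_len)
    ]

# Hypothetical test signal: 64 samples of a sine with 4 cycles per 16-sample frame.
signal = [math.sin(2 * math.pi * 4 * t / 16) for t in range(64)]
features = frame_features(signal, 16)  # 4 frames, 9 frequency bins each; the peak is in bin 4
```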

Video Analysis

RNNs work well for video because a video is essentially a sequence of images. As with signals, it helps to perform feature extraction before feeding the sequence into the RNN. In this example, a pretrained GoogLeNet model (a convolutional neural network) extracts features from each frame. You can see the network architecture below.

Basic architecture for classifying video with LSTM.
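The data flow of this architecture (per-frame features, a recurrent layer over the frames, then a classification layer) can be sketched end to end. This is a toy illustration under loud assumptions: `frame_to_features` is a hypothetical stand-in for a pretrained CNN such as GoogLeNet, the recurrent layer is a simplified elementwise RNN rather than an LSTM, and all weights are untrained placeholders.

```python
import math

def frame_to_features(frame):
    """Hypothetical stand-in for a pretrained CNN: summarize a grayscale frame
    (a list of rows of pixel values) as [mean intensity, max intensity]."""
    pixels = [p for row in frame for p in row]
    return [sum(pixels) / len(pixels), max(pixels)]

def rnn_last_state(seq, w_x, w_h):
    """Simplified elementwise recurrent layer; returns only the final hidden state (sequence-to-one)."""
    h = [0.0] * len(seq[0])
    for x in seq:
        h = [math.tanh(w_x * xi + w_h * hi) for xi, hi in zip(x, h)]
    return h

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def classify_video(frames, w_x, w_h, W_out):
    features = [frame_to_features(f) for f in frames]                  # 1) per-frame feature extraction
    h = rnn_last_state(features, w_x, w_h)                              # 2) recurrent layer over the frames
    scores = [sum(w * hi for w, hi in zip(row, h)) for row in W_out]    # 3) fully connected layer
    return softmax(scores)                                              # 4) class probabilities

# Hypothetical 5-frame, 2x2-pixel "video" and placeholder weights for 2 classes.
frames = [[[0.1 * t, 0.2], [0.3, 0.4]] for t in range(5)]
probs = classify_video(frames, 0.8, 0.5, [[1.0, -1.0], [-1.0, 1.0]])
```

Extracting features per frame first keeps the sequence model small: the recurrent layer sees a short vector per frame instead of every pixel.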
