
Predict and Update Network State in Simulink

This example shows how to predict responses for a trained recurrent neural network in Simulink® by using the Stateful Predict block. This example uses a pretrained long short-term memory (LSTM) network.

Load Pretrained Network

Load JapaneseVowelsNet, a pretrained long short-term memory (LSTM) network trained on the Japanese Vowels data set as described in [1] and [2]. This network was trained on the sequences sorted by sequence length with a mini-batch size of 27.

load JapaneseVowelsNet
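Sorting training sequences by length, as described for JapaneseVowelsNet, keeps sequences of similar length in the same mini-batch so that less padding is needed. The following is a minimal sketch of that sort, assuming a small hypothetical cell array `XDemo` of random 12-by-T sequences. It is not the code that was used to train the network.

```matlab
% Hypothetical training set: 5 sequences with 12 features and varying numbers of time steps
seqLengths = [20 7 15 26 9];
XDemo = arrayfun(@(T) rand(12,T), seqLengths, 'UniformOutput', false);

% The number of time steps is the number of columns of each 12-by-T sequence
numSteps = cellfun(@(x) size(x,2), XDemo);

% Reorder the sequences from shortest to longest
[sortedSteps, idx] = sort(numSteps);
XSorted = XDemo(idx);
disp(sortedSteps)
```

Consecutive mini-batches of 27 sequences taken from `XSorted` would then contain sequences of similar length.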

View the network architecture.

net.Layers

ans = 

  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes
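Read from top to bottom, the layer array maps each 12-dimensional time step to nine class scores. The following summarizes the roles of the layers as equations, with the sizes taken from the layer array. The internal LSTM gate equations are abbreviated as LSTM, and the symbols are notation introduced here, not names used by the network object.

```latex
\begin{aligned}
(h_t, c_t) &= \mathrm{LSTM}(x_t, h_{t-1}, c_{t-1}), \qquad x_t \in \mathbb{R}^{12},\; h_t, c_t \in \mathbb{R}^{100} \\
z_t &= W h_t + b, \qquad W \in \mathbb{R}^{9 \times 100},\; b \in \mathbb{R}^{9} \\
p_{t,k} &= \frac{e^{z_{t,k}}}{\sum_{j=1}^{9} e^{z_{t,j}}}, \qquad k = 1, \dots, 9
\end{aligned}
```

The hidden state $h_t$ and cell state $c_t$ carried from one time step to the next form the network state that the Stateful Predict block updates.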

Load Test Data

Load the Japanese Vowels test data. XTest is a cell array containing 370 sequences of varying length, each with 12 features. YTest is a categorical vector of labels "1","2",...,"9", which correspond to the nine speakers. Select the 94th test sequence as the input to the model and find its number of time steps.

[XTest,YTest] = japaneseVowelsTestData;
X = XTest{94};              % 12-by-T matrix: features by time steps
numTimeSteps = size(X,2);   % number of time steps T in the sequence

Simulink Model for Predicting Responses

The Simulink model for predicting responses contains a Stateful Predict block to predict the scores and MATLAB Function blocks to load the input data sequence over the time steps.
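To make "update the network state" concrete, the following base-MATLAB sketch processes one time step at a time and carries the LSTM hidden state `h` and cell state `c` forward, using the same sizes as JapaneseVowelsNet (12 inputs, 100 hidden units, 9 classes). The weights are random placeholders, the stacked gate order is an assumption, and this is an illustration of the idea, not how the Stateful Predict block is implemented internally.

```matlab
% Sizes from the JapaneseVowelsNet layer array; data and weights are hypothetical
numFeatures = 12; numHidden = 100; numClasses = 9;
numStepsDemo = 20;
XDemo = rand(numFeatures, numStepsDemo);     % hypothetical 12-by-T input sequence

W   = 0.1*randn(4*numHidden, numFeatures);   % input weights (assumed order: input, forget, candidate, output)
R   = 0.1*randn(4*numHidden, numHidden);     % recurrent weights
b   = zeros(4*numHidden, 1);                 % LSTM biases
Wfc = 0.1*randn(numClasses, numHidden);      % fully connected weights
bfc = zeros(numClasses, 1);                  % fully connected biases

sigm = @(v) 1 ./ (1 + exp(-v));              % gate activation
h = zeros(numHidden, 1);                     % initial hidden state
c = zeros(numHidden, 1);                     % initial cell state
scoresDemo = zeros(numClasses, numStepsDemo);

for t = 1:numStepsDemo
    gates = W*XDemo(:,t) + R*h + b;
    ig = sigm(gates(1:numHidden));                     % input gate
    fg = sigm(gates(numHidden+1:2*numHidden));         % forget gate
    cc = tanh(gates(2*numHidden+1:3*numHidden));       % cell candidate
    og = sigm(gates(3*numHidden+1:end));               % output gate
    c = fg.*c + ig.*cc;                                % update the cell state
    h = og.*tanh(c);                                   % update the hidden state
    z = Wfc*h + bfc;                                   % fully connected layer
    scoresDemo(:,t) = exp(z - max(z)) / sum(exp(z - max(z)));   % softmax scores for step t
end
```

Because `h` and `c` persist across loop iterations, the scores at step t depend on all earlier time steps, which is why the prediction scores can change from one time step to the next.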


Configure Model for Simulation

Specify the network file for the Stateful Predict block and set the simulation mode of the model.

load_system('StatefulPredictExample');   % load the model so that set_param can find it
set_param('StatefulPredictExample/Stateful Predict','NetworkFilePath','JapaneseVowelsNet.mat');
set_param('StatefulPredictExample','SimulationMode','Normal');

Run the Simulation

Run the simulation to compute responses for the JapaneseVowelsNet network. The prediction scores are returned to the MATLAB® workspace in the simulation output out, as out.yPred.

out = sim('StatefulPredictExample');

Plot the prediction scores. The plot shows how the prediction scores change between time steps.

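% Keep the first numTimeSteps samples; squeeze gives a numClasses-by-numTimeSteps matrix
scores = squeeze(out.yPred.Data(:,:,1:numTimeSteps));

classNames = string(net.Layers(end).Classes);
lines = plot(scores');   % transpose so that each class is plotted as a separate line
xlim([1 numTimeSteps])
legend("Class " + classNames,'Location','northwest')
xlabel("Time Step")
title("Prediction Scores Over Time Steps")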
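The indexing and squeeze call in the plotting code assume that the logged scores form a 3-D array with time along the third dimension. A minimal sketch with a hypothetical 1-by-9-by-30 array `loggedData` (the actual logged size depends on the model's logging settings) shows the resulting shape:

```matlab
% Hypothetical logged data: 9 class scores per sample along dimension 2,
% samples along dimension 3, with more samples than the sequence has time steps
loggedData = rand(1, 9, 30);
nSteps = 20;

% Keep the first nSteps samples and drop the singleton first dimension
trimmedScores = squeeze(loggedData(:,:,1:nSteps));
disp(size(trimmedScores))   % 9 rows (classes) by 20 columns (time steps)
```

Transposing this matrix before calling plot gives one column, and therefore one line, per class.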

Highlight the prediction scores over time steps for the correct class.

trueLabel = YTest(94);
lines(double(trueLabel)).LineWidth = 3;   % the category code of the label matches the index of its line
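Selecting a plotted line with a categorical label relies on the label's integer category code. A small sketch with a hypothetical label `trueLabelDemo`, assuming the categories are "1" through "9" in the same order as the network classes:

```matlab
% Hypothetical label whose categories are "1" through "9", in that order
trueLabelDemo = categorical("3", string(1:9));

% double returns the position of the label's category in the category list,
% which here equals the index of that class's line in the plot
classIdx = double(trueLabelDemo);
disp(classIdx)   % 3
```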

Display the final time step prediction in a bar chart.

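figure; bar(scores(:,end))   % the last column of scores holds the final time step
title("Final Prediction Scores"); xlabel("Class"); ylabel("Score")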
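The predicted class for the whole sequence is the class with the highest score at the final time step. A minimal sketch, using hypothetical scores `finalScores` and class names `classNamesDemo` rather than results from this model:

```matlab
% Hypothetical final-step scores for nine classes (they sum to 1, like softmax output)
finalScores = [0.01 0.02 0.05 0.70 0.08 0.04 0.03 0.05 0.02]';
classNamesDemo = string(1:9);

% Pick the class with the largest final score
[maxScore, predIdx] = max(finalScores);
predictedClass = classNamesDemo(predIdx);
fprintf("Predicted class %s with score %.2f\n", predictedClass, maxScore)
```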


[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset.
