
Predict and Update Network State in Simulink

This example shows how to predict responses for a trained recurrent neural network in Simulink® by using the Stateful Predict block. This example uses a pretrained long short-term memory (LSTM) network.

Load Pretrained Network

Load JapaneseVowelsNet, a pretrained LSTM network trained on the Japanese Vowels data set as described in [1] and [2]. The network was trained on sequences sorted by sequence length with a mini-batch size of 27.
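Sorting the training sequences by length groups sequences of similar length into the same mini-batch, which reduces the padding added during training. The following is a minimal sketch of that sorting step in plain MATLAB. It uses a small hypothetical cell array `seqs` of 12-by-T sequences in place of the real training data:

```matlab
% Hypothetical stand-in for training data: each cell is a 12-by-T sequence
seqs = {rand(12,20), rand(12,7), rand(12,15), rand(12,9)};

% Number of time steps (columns) in each sequence
seqLengths = cellfun(@(s) size(s,2), seqs);

% Reorder the sequences from shortest to longest
[seqLengths, order] = sort(seqLengths);   % seqLengths is now [7 9 15 20]
seqsSorted = seqs(order);
```

When the sorted sequences are split into mini-batches (27 sequences each in this example), each batch is padded only to the length of its longest member.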

load JapaneseVowelsNet

View the network architecture.

net.Layers
ans = 

  5x1 Layer array with layers:

     1   'sequenceinput'   Sequence Input          Sequence input with 12 dimensions
     2   'lstm'            LSTM                    LSTM with 100 hidden units
     3   'fc'              Fully Connected         9 fully connected layer
     4   'softmax'         Softmax                 softmax
     5   'classoutput'     Classification Output   crossentropyex with '1' and 8 other classes

Load Test Data

Load the Japanese Vowels test data. XTest is a cell array containing 370 sequences of varying length, each with 12 features. YTest is a categorical vector of labels "1","2",...,"9", which correspond to the nine speakers.

[XTest,YTest] = japaneseVowelsTestData;
X = XTest{94};            % select test sequence 94 (a 12-by-T matrix)
numTimeSteps = size(X,2); % number of time steps (columns) in the sequence
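Each element of XTest is a numeric matrix with one row per feature and one column per time step, so `size(X,2)` gives the sequence length. The following sketch reproduces that layout with a hypothetical two-element cell array `XDemo`, so it runs without the data set:

```matlab
% Hypothetical stand-in for XTest: 12 features, varying number of time steps
XDemo = {rand(12,18), rand(12,11)};

Xs = XDemo{2};              % pick one sequence (XTest{94} in the example)
numFeatures = size(Xs,1);   % 12 features per time step
numSteps = size(Xs,2);      % one column per time step
firstStep = Xs(:,1);        % 12-by-1 feature vector at time step 1
```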

Simulink Model for Predicting Responses

The Simulink model for predicting responses contains a Stateful Predict block, which predicts the scores and updates the network state, and MATLAB Function blocks that feed the input sequence to the network one time step at a time.
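The Stateful Predict block keeps the LSTM hidden and cell states between simulation time steps. Each call therefore receives one time step of input and updates the state for the next call. The following is a minimal plain-MATLAB sketch of that idea, assuming a single LSTM layer followed by a fully connected layer and a softmax, as listed in net.Layers. The weights are random placeholders, not the trained network's. The gate stacking order (input, forget, cell candidate, output) is also an assumption of this sketch:

```matlab
% Sizes matching the example network; weights below are random placeholders
numFeatures = 12; numHidden = 100; numClasses = 9; T = 20;
W   = 0.1*randn(4*numHidden, numFeatures);  % input weights, 4 gates stacked (assumed order)
R   = 0.1*randn(4*numHidden, numHidden);    % recurrent weights
b   = zeros(4*numHidden, 1);                % gate biases
Wfc = 0.1*randn(numClasses, numHidden);     % fully connected layer weights
bfc = zeros(numClasses, 1);                 % fully connected layer bias

sigmoid = @(z) 1./(1 + exp(-z));
xSeq = rand(numFeatures, T);                % placeholder input sequence

% Network state carried from one time step to the next
h = zeros(numHidden, 1);                    % hidden state
c = zeros(numHidden, 1);                    % cell state
scoresSketch = zeros(numClasses, T);

for t = 1:T
    g = W*xSeq(:,t) + R*h + b;                     % all gate pre-activations
    iGate = sigmoid(g(1:numHidden));               % input gate
    fGate = sigmoid(g(numHidden+1:2*numHidden));   % forget gate
    cCand = tanh(g(2*numHidden+1:3*numHidden));    % cell candidate
    oGate = sigmoid(g(3*numHidden+1:end));         % output gate
    c = fGate.*c + iGate.*cCand;                   % update cell state
    h = oGate.*tanh(c);                            % update hidden state
    z = Wfc*h + bfc;                               % fully connected layer
    scoresSketch(:,t) = exp(z - max(z)) / sum(exp(z - max(z)));  % softmax
end
```

Because `h` and `c` persist across loop iterations, the scores at step t depend on all earlier inputs. This is the behavior that distinguishes stateful prediction from classifying each time step in isolation.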

open_system('StatefulPredictExample');

Configure Model for Simulation

Set the parameters of the input blocks and the Stateful Predict block, and set the simulation mode of the model.

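% Use the workspace variable X as the input sequence
set_param('StatefulPredictExample/Input','Value','X');
% Set the upper limit of the Index counter from the sequence length
set_param('StatefulPredictExample/Index','uplimit','numTimeSteps-1');
% Point the Stateful Predict block at the pretrained network file
set_param('StatefulPredictExample/Stateful Predict','NetworkFilePath','JapaneseVowelsNet.mat');
% Simulate the model in normal mode
set_param('StatefulPredictExample','SimulationMode','Normal');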
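The Index block's upper limit is set to numTimeSteps-1. This suggests a counter that starts at 0, so counter value idx selects column idx+1 of X. The following sketch mimics that selection in plain MATLAB. The anonymous function `selectStep` is a hypothetical stand-in for the model's MATLAB Function block logic, and the assumption that the counter starts at 0 is also this sketch's, not taken from the model:

```matlab
% Hypothetical MATLAB Function block logic: pick one time step of the input
selectStep = @(Xin, idx) Xin(:, idx + 1);  % idx assumed to count from 0

seqDemo = reshape(1:60, 12, 5);            % placeholder 12-by-5 sequence
uplimit = size(seqDemo,2) - 1;             % same rule as 'numTimeSteps-1'
steps = zeros(12, uplimit + 1);
for idx = 0:uplimit                        % counter values 0, 1, ..., uplimit
    steps(:, idx + 1) = selectStep(seqDemo, idx);
end
```

Under this assumption, the counter visits every column exactly once before reaching its upper limit.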

Run the Simulation

To compute responses for the JapaneseVowelsNet network, run the simulation. The prediction scores are returned in the simulation output out as yPred.

out = sim('StatefulPredictExample');

Plot the prediction scores. The plot shows how the prediction scores change between time steps.

scores = squeeze(out.yPred.Data(:,:,1:numTimeSteps));
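The squeeze call removes singleton dimensions from the logged data. This sketch assumes each logged sample is a 9-element score vector stacked along the third dimension; the exact logged shape is an assumption here. Under that assumption, the result is a 9-by-numTimeSteps matrix with one row per class, which is why the plot uses scores' to draw one line per class. The example uses a random placeholder array:

```matlab
nSteps = 20;                                      % stand-in for numTimeSteps
loggedDemo = rand(1, 9, 25);                      % placeholder: 1-by-9 scores at 25 samples
scoresDemo = squeeze(loggedDemo(:,:,1:nSteps));   % keep nSteps samples -> 9-by-20

loggedDemo2 = rand(9, 1, 25);                     % a 9-by-1 layout gives the same shape
scoresDemo2 = squeeze(loggedDemo2(:,:,1:nSteps)); % 9-by-20
```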

classNames = string(net.Layers(end).Classes);
figure
lines = plot(scores');
xlim([1 numTimeSteps])
legend("Class " + classNames,'Location','northwest')
xlabel("Time Step")
ylabel("Score")
title("Prediction Scores Over Time Steps")

Highlight the prediction scores over time steps for the correct class.

trueLabel = YTest(94);
lines(trueLabel).LineWidth = 3;

Display the prediction scores at the final time step in a bar chart.

figure
bar(scores(:,end))
title("Final Prediction Scores")
xlabel("Class")
ylabel("Score")
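The predicted class for the whole sequence is the class with the highest score at the final time step. The following short sketch uses hypothetical placeholder scores. In the example, `scores(:,end)` and `classNames` would take their place; classNames there is a string array, so indexing would use parentheses rather than braces:

```matlab
% Placeholder final scores for 9 classes (stand-in for scores(:,end))
finalScores = [0.01 0.02 0.05 0.70 0.10 0.04 0.03 0.03 0.02]';
classNamesDemo = {'1','2','3','4','5','6','7','8','9'};

[maxScore, classIdx] = max(finalScores);    % index of the highest score
predictedLabel = classNamesDemo{classIdx};  % label of that class
```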

References

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, 1999, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
